Compare commits

..

No commits in common. "4012f84886aa023d78facdab3ffe2f8fc43be1ff" and "755cf6467d9b3b2ec36d9aff00df39fe040005ef" have entirely different histories.

4 changed files with 77 additions and 209 deletions

View file

@ -47,30 +47,13 @@ bunbun-worker = { git = "https://git.4o1x5.dev/4o1x5/bunbun-worker", branch = "m
## Usage ## Usage
### Message versioning Here is a basic implementation of an RPC job in bunbun-worker
In this crate message versioning is done by including `v1.0.0` or such on the end of the queue name, instead of including it in the headers of a message. This reduces the amount of redelivered messages.
The following example will send a job to a queue named `emailjob-v1.0.0`.
```rust
let result = client
.rpc_call::<EmailJob>(
EmailJob {
send_to: "someone".into(),
contents: "something".into(),
},
BasicCallOptions::default("emailjob")
.timeout(Duration::from_secs(3))
.message_version("v1.0.0")
)
.await
.unwrap();
```
# Limitations # Limitations
1. Currently some `unwrap()`'s are called inside the code and may results in panics (not in the job-runner). 1. Currently some `unwrap()`'s are called inside the code and may results in panics (not in the job-runner).
2. limited API 2. No settings, and very limited API
3. The rabbitmq RPC logic is very basic with no message-versioning (aside using different queue names (eg service-class-v1.0.0) )
# Bugs department # Bugs department

View file

@ -1,7 +1,7 @@
use std::{ use std::{
fmt::{Debug, Display}, fmt::{Debug, Display},
str::from_utf8, str::from_utf8,
time::{Duration, Instant}, time::Instant,
}; };
use futures::StreamExt; use futures::StreamExt;
@ -10,7 +10,6 @@ use lapin::{
ConnectionProperties, ConnectionProperties,
}; };
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use tokio::time::timeout;
use uuid::Uuid; use uuid::Uuid;
use crate::ResultHeader; use crate::ResultHeader;
@ -48,8 +47,8 @@ impl Client {
pub async fn rpc_call<T: RPCClientTask + Send + Debug>( pub async fn rpc_call<T: RPCClientTask + Send + Debug>(
&self, &self,
data: T, data: T,
options: BasicCallOptions, queue_name: &str,
) -> Result<Result<T::Result, T::ErroredResult>, ClientError> ) -> Result<Result<T::Result, T::ErroredResult>, RpcClientError>
where where
T: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned,
{ {
@ -95,7 +94,7 @@ impl Client {
match channel match channel
.basic_publish( .basic_publish(
"", "",
format!("{}-{}", &options.queue_name, options.message_version).as_str(), queue_name,
BasicPublishOptions::default(), BasicPublishOptions::default(),
serde_json::to_string(&data).unwrap().as_bytes(), serde_json::to_string(&data).unwrap().as_bytes(),
BasicProperties::default() BasicProperties::default()
@ -113,22 +112,22 @@ impl Client {
} }
Ok(confirmation) => { Ok(confirmation) => {
tracing::info!( tracing::info!(
"Sent RPC job of type {} to channel {} Ack: {} Ver: {}", "Sent RPC job of type {} to channel {} Ack: {}",
std::any::type_name::<T>(), std::any::type_name::<T>(),
options.queue_name, queue_name,
confirmation.is_ack(), confirmation.is_ack()
options.message_version
); );
} }
}, },
} }
// TODO implement timeout
tracing::debug!("Awaiting response from callback queue"); tracing::debug!("Awaiting response from callback queue");
let listen = async move { let del = loop {
match consumer.next().await { match consumer.next().await {
None => { None => {
tracing::error!("Received empty data after {:?}", now.elapsed()); tracing::error!("Received empty data after {:?}", now.elapsed());
return Err(ClientError::InvalidResponse); return Err(RpcClientError::NoReply);
} }
Some(del) => match del { Some(del) => match del {
Err(error) => { Err(error) => {
@ -138,30 +137,16 @@ impl Client {
now.elapsed() now.elapsed()
); );
// Idk if i should nack it? // Idk if i should nack it?
return Err(ClientError::FailedDecode); return Err(RpcClientError::FailedDecode);
} }
Ok(del) => { Ok(del) => {
tracing::debug!("Received response after {:?}", now.elapsed()); tracing::debug!("Received response after {:?}", now.elapsed());
return Ok(del); break del;
} }
}, },
}; };
}; };
let del = match options.timeout {
None => listen.await?,
Some(dur) => match timeout(dur, listen).await {
Err(elapsed) => {
tracing::warn!("RPC job has reached timeout after: {}", elapsed);
return Err(ClientError::TimeoutReached);
}
Ok(r) => match r {
Err(error) => return Err(error),
Ok(r) => r,
},
},
};
// TODO better implementation of this // TODO better implementation of this
tracing::debug!("Decoding headers"); tracing::debug!("Decoding headers");
let result_type = match del.properties.headers().to_owned() { let result_type = match del.properties.headers().to_owned() {
@ -169,17 +154,17 @@ impl Client {
tracing::error!( tracing::error!(
"Got a response with no headers, this might be an issue with version mismatch" "Got a response with no headers, this might be an issue with version mismatch"
); );
return Err(ClientError::InvalidResponse); return Err(RpcClientError::InvalidResponse);
} }
Some(headers) => match headers.inner().get("outcome") { Some(headers) => match headers.inner().get("outcome") {
None => { None => {
tracing::error!("Got a response with no outcome header"); tracing::error!("Got a response with no outcome header");
return Err(ClientError::InvalidResponse); return Err(RpcClientError::InvalidResponse);
} }
Some(res) => match res.as_long_string() { Some(res) => match res.as_long_string() {
None => { None => {
tracing::error!("Got a response with no headers"); tracing::error!("Got a response with no headers");
return Err(ClientError::InvalidResponse); return Err(RpcClientError::InvalidResponse);
} }
Some(outcome) => { Some(outcome) => {
match serde_json::from_str::<ResultHeader>(outcome.to_string().as_str()) { match serde_json::from_str::<ResultHeader>(outcome.to_string().as_str()) {
@ -189,53 +174,58 @@ impl Client {
} }
Err(_) => { Err(_) => {
tracing::warn!("Received a result header but it's not a type that can be deserailized "); tracing::warn!("Received a result header but it's not a type that can be deserailized ");
return Err(ClientError::InvalidResponse); return Err(RpcClientError::InvalidResponse);
} }
} }
} }
}, },
}, },
}; };
tracing::debug!("Result type is: {result_type}, decoding..."); tracing::debug!("Result type is: {result_type}, decoding...");
let utf8 = match from_utf8(&del.data) { let utf8 = match from_utf8(&del.data) {
Ok(r) => r, Ok(r) => r,
Err(error) => { Err(error) => {
tracing::error!("Failed to decode response message to utf8 {error}"); tracing::error!("Failed to decode response message to utf8 {error}");
return Err(ClientError::FailedDecode); return Err(RpcClientError::FailedDecode);
} }
}; };
let _ = channel.close(0, "byebye").await; let _ = channel.close(0, "byebye").await;
// acking for idk reason
// ack message
let _ = del.ack(BasicAckOptions::default()).await; let _ = del.ack(BasicAckOptions::default()).await;
match result_type { match result_type {
ResultHeader::Error => match serde_json::from_str::<T::ErroredResult>(utf8) { ResultHeader::Error => match serde_json::from_str::<T::ErroredResult>(utf8) {
// get result header // get result header
Err(_) => { Err(_) => {
tracing::error!("Failed to decode response message to E"); tracing::error!("Failed to decode response message to E");
return Err(ClientError::FailedDecode); return Err(RpcClientError::FailedDecode);
} }
Ok(res) => return Ok(Err(res)), Ok(res) => return Ok(Err(res)),
}, },
ResultHeader::Panic => return Err(ClientError::ServerPanicked), ResultHeader::Panic => return Err(RpcClientError::ServerPanicked),
ResultHeader::Ok => ResultHeader::Ok =>
// get result // get result
{ {
match serde_json::from_str::<T::Result>(utf8) { match serde_json::from_str::<T::Result>(utf8) {
Err(_) => { Err(_) => {
tracing::error!("Failed to decode response message to R"); tracing::error!("Failed to decode response message to R");
return Err(ClientError::FailedDecode); return Err(RpcClientError::FailedDecode);
} }
Ok(res) => return Ok(Ok(res)), Ok(res) => return Ok(Ok(res)),
} }
} }
} }
// ack message
} }
/// Sends a message to the queue
/// Sends a basic Task to the queue ///
pub async fn call<T>(&self, data: T, options: BasicCallOptions) -> Result<(), ClientError> /// # Examples
///
/// ```
/// use bunbun_worker::client::Client;
/// let client = Client::new("amqp://127.0.0.1:5672");
/// let result = client.call(EmailJob::new("someone@example.com", "Hello there"), "email-emailjob-v1.0.0");
/// ```
pub async fn call<T>(&self, data: T, queue_name: &str) -> Result<(), ClientError>
where where
T: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned,
{ {
@ -243,7 +233,7 @@ impl Client {
match channel match channel
.basic_publish( .basic_publish(
"", "",
format!("{}-{}", &options.queue_name, options.message_version).as_str(), queue_name,
BasicPublishOptions::default(), BasicPublishOptions::default(),
serde_json::to_string(&data).unwrap().as_bytes(), serde_json::to_string(&data).unwrap().as_bytes(),
BasicProperties::default(), BasicProperties::default(),
@ -264,11 +254,10 @@ impl Client {
Ok(confirmation) => { Ok(confirmation) => {
let _ = channel.close(0, "byebye").await; let _ = channel.close(0, "byebye").await;
tracing::info!( tracing::info!(
"Sent nonRPC job of type {} to channel {} Ack: {} Ver: {}", "Sent nonRPC job of type {} to channel {} Ack: {}",
std::any::type_name::<T>(), std::any::type_name::<T>(),
options.queue_name, queue_name,
confirmation.is_ack(), confirmation.is_ack()
options.message_version
); );
tracing::debug!( tracing::debug!(
"AMQP confirmed dispatch of job | Acknowledged? {}", "AMQP confirmed dispatch of job | Acknowledged? {}",
@ -280,39 +269,28 @@ impl Client {
Ok(()) Ok(())
} }
} }
/// A call option class that is used to control how calls are handled
/// You can define the timeout, and the message versions
pub struct BasicCallOptions {
timeout: Option<Duration>,
queue_name: String,
message_version: String,
}
impl BasicCallOptions {
pub fn default(queue_name: impl Into<String>) -> Self {
Self {
timeout: None,
queue_name: queue_name.into(),
message_version: "v1.0.0".into(),
}
}
pub fn message_version(mut self, message_version: impl Into<String>) -> Self {
self.message_version = message_version.into();
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
}
/// An error that the client returns /// An error that the client returns
#[derive(Debug)] #[derive(Debug)]
pub enum ClientError { pub enum RpcClientError {
NoReply, // TODO timeout
FailedDecode, FailedDecode,
FailedToSend, FailedToSend,
InvalidResponse, InvalidResponse,
ServerPanicked, ServerPanicked,
TimeoutReached, }
/// An error for normal calls
#[derive(Debug)]
pub enum ClientError {
FailedToSend,
}
impl Display for ClientError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::FailedToSend => write!(f, "Failed to send to queue"),
}
}
} }
/// A Client-side trait that needs to be implemented for a type in order for the client to know return types. /// A Client-side trait that needs to be implemented for a type in order for the client to know return types.

View file

@ -90,7 +90,7 @@ impl WorkerConfig {
/// ///
/// # Arguments /// # Arguments
/// * `custom_tls` - Optional TLSconfig (if none defaults to lapins choice) /// * `custom_tls` - Optional TLSconfig (if none defaults to lapins choice)
pub fn enable_tls(mut self, custom_tls: Option<TlsConfig>) -> Self { pub fn enable_tls(&mut self, custom_tls: Option<TlsConfig>) {
match custom_tls { match custom_tls {
Some(tls) => { Some(tls) => {
let tls = OwnedTLSConfig { let tls = OwnedTLSConfig {
@ -104,45 +104,6 @@ impl WorkerConfig {
} }
None => self.tls = OwnedTLSConfig::default().into(), None => self.tls = OwnedTLSConfig::default().into(),
} }
self
}
}
/// A worker configuration
pub struct ListenerConfig {
prefetch_count: u16,
queue_name: String,
consumer_tag: String,
message_version: String,
}
impl ListenerConfig {
/// Create a new listener config
/// # Arguments
/// * `queue_name` - The name of the queue to listen to (e.g. service-serviceJobName-v1.0.0)
pub fn default(queue_name: impl Into<String>) -> Self {
Self {
prefetch_count: 0,
queue_name: queue_name.into(),
consumer_tag: "".into(),
message_version: "v1.0.0".into(),
}
}
/// Set the prefetch count for the listener
/// This serves as a maximum job count that can be processed at a time. (0 is unlimited)
pub fn set_prefetch_count(mut self, prefetch_count: u16) -> Self {
self.prefetch_count = prefetch_count;
self
}
/// Set the consumer tag for the listener
pub fn set_consumer_tag(mut self, consumer_tag: impl Into<String>) -> Self {
self.consumer_tag = consumer_tag.into();
self
}
pub fn set_message_version(mut self, version: impl Into<String>) -> Self {
self.message_version = version.into();
self
} }
} }
@ -193,24 +154,20 @@ impl Worker {
/// Add a non-rpc listener to the worker object /// Add a non-rpc listener to the worker object
/// ///
/// # Arguments /// # Arguments
/// * `queue_name` - A string slice that holds the name of the queue to listen to (e.g. service-serviceJobName-v1.0.0)
/// * `state` - An Arc of the state object that will be passed to the listener /// * `state` - An Arc of the state object that will be passed to the listener
/// * `listener_config` - An Arc of the state object that will be passed to the listener
pub async fn add_non_rpc_consumer<J: Task + 'static + Send>( pub async fn add_non_rpc_consumer<J: Task + 'static + Send>(
&mut self, &mut self,
queue_name: &str,
state: Arc<J::State>, state: Arc<J::State>,
listener_config: ListenerConfig,
) where ) where
<J as Task>::State: std::marker::Send + Sync, <J as Task>::State: std::marker::Send + Sync,
{ {
let consumer = self let consumer = self
.channel .channel
.basic_consume( .basic_consume(
format!( queue_name,
"{}-{}", "",
listener_config.queue_name, listener_config.message_version
)
.as_str(),
&listener_config.consumer_tag,
BasicConsumeOptions::default(), BasicConsumeOptions::default(),
FieldTable::default(), FieldTable::default(),
) )
@ -274,33 +231,28 @@ impl Worker {
/// ///
/// ``` /// ```
/// let server = BunBunWorker::new("amqp://localhost:5672", None).await; /// let server = BunBunWorker::new("amqp://localhost:5672", None).await;
/// server.add_rpc_consumer::<MyRPCTask>(ListenerConfig::default("service-jobname-v1.0.0") )).await; /// server.add_rpc_consumer::<MyRPCTask>("service-serviceJobName-v1.0.0", SomeState{} )).await;
/// server.start_all_listeners().await; /// server.start_all_listeners().await;
/// ``` /// ```
pub async fn add_rpc_consumer<J: RPCTask + 'static + Send>( pub async fn add_rpc_consumer<J: RPCTask + 'static + Send>(
&mut self, &mut self,
queue_name: &str,
state: Arc<J::State>, state: Arc<J::State>,
listener_config: ListenerConfig,
) where ) where
<J as RPCTask>::State: std::marker::Send + Sync, <J as RPCTask>::State: std::marker::Send + Sync,
<J as RPCTask>::Result: std::marker::Send + Sync, <J as RPCTask>::Result: std::marker::Send + Sync,
<J as RPCTask>::ErroredResult: std::marker::Send + Sync, <J as RPCTask>::ErroredResult: std::marker::Send + Sync,
{ {
let consumer = create_consumer( let consumer = self
self.channel.clone(), .channel
format!( .basic_consume(
"{}-{}", queue_name,
listener_config.queue_name, listener_config.message_version "",
) BasicConsumeOptions::default(),
.as_str(), FieldTable::default(),
&listener_config.consumer_tag,
listener_config.prefetch_count,
) )
.await .await
.map_err(|e| { .expect("basic_consume error");
tracing::error!("Failed to create consumer: {}", e);
})
.expect("Failed to create consumer");
let channel = self.channel.clone(); let channel = self.channel.clone();
let handler: Arc< let handler: Arc<
@ -681,24 +633,3 @@ impl Display for ResultHeader {
} }
} }
} }
async fn create_consumer(
channel: Channel,
queue_name: &str,
consumer_tag: &str,
prefect_count: u16,
) -> Result<Consumer, lapin::Error> {
let channel = channel.clone();
channel
.basic_qos(prefect_count, BasicQosOptions::default())
.await?;
channel
.basic_consume(
queue_name,
consumer_tag,
BasicConsumeOptions::default(),
FieldTable::default(),
)
.await
}

View file

@ -15,8 +15,8 @@ mod test {
use tracing_test::traced_test; use tracing_test::traced_test;
use crate::{ use crate::{
client::{BasicCallOptions, Client, RPCClientTask}, client::{Client, RPCClientTask},
ListenerConfig, RPCTask, Worker, WorkerConfig, RPCTask, Worker, WorkerConfig,
}; };
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -108,10 +108,10 @@ mod test {
.await; .await;
listener listener
.add_rpc_consumer::<EmailJob>( .add_rpc_consumer::<EmailJob>(
"email-emailjob-v1.0.0",
Arc::new(State { Arc::new(State {
something: "test".into(), something: "test".into(),
}), }),
ListenerConfig::default("emailjob").set_prefetch_count(100),
) )
.await; .await;
tracing::debug!("Starting listener"); tracing::debug!("Starting listener");
@ -129,10 +129,10 @@ mod test {
.await; .await;
listener listener
.add_rpc_consumer::<PanickingEmailJob>( .add_rpc_consumer::<PanickingEmailJob>(
"email-emailjob-v1.0.0",
Arc::new(State { Arc::new(State {
something: "test".into(), something: "test".into(),
}), }),
ListenerConfig::default("emailjob").set_prefetch_count(100),
) )
.await; .await;
tracing::debug!("Starting listener"); tracing::debug!("Starting listener");
@ -143,7 +143,7 @@ mod test {
#[traced_test] #[traced_test]
async fn rpc_client() { async fn rpc_client() {
// //
let client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) let mut client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str())
.await .await
.unwrap(); .unwrap();
let result = client let result = client
@ -152,31 +152,7 @@ mod test {
send_to: "someone".into(), send_to: "someone".into(),
contents: "something".into(), contents: "something".into(),
}, },
BasicCallOptions::default("emailjob"), "email-emailjob-v1.0.0",
)
.await
.unwrap();
assert_eq!(
result,
Ok(EmailJobResult {
contents: "something".to_string()
})
)
}
#[test(tokio::test)]
#[traced_test]
async fn rpc_client_timeout() {
//
let client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str())
.await
.unwrap();
let result = client
.rpc_call::<EmailJob>(
EmailJob {
send_to: "someone".into(),
contents: "something".into(),
},
BasicCallOptions::default("emailjob").timeout(Duration::from_secs(3)),
) )
.await .await
.unwrap(); .unwrap();
@ -208,7 +184,7 @@ mod test {
send_to: "someone".into(), send_to: "someone".into(),
contents: "something".into(), contents: "something".into(),
}, },
BasicCallOptions::default("emailjob"), "email-emailjob-v1.0.0",
) )
.await .await
})); }));