diff --git a/src/client.rs b/src/client.rs index a2134eb..5ebc02a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -43,7 +43,7 @@ impl Client { /// /// Arguments /// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize - /// * `queue_name` The name of the queue to be sent to + /// * `options` BasicCallOptions, used to control the timeout and message version // TODO if the queue is nonexistent return error pub async fn rpc_call( &self, @@ -235,6 +235,10 @@ impl Client { } /// Sends a basic Task to the queue + /// + /// Arguments + /// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize + /// * `options` BasicCallOptions, used to control the timeout and message version pub async fn call(&self, data: T, options: BasicCallOptions) -> Result<(), ClientError> where T: Serialize + DeserializeOwned, diff --git a/src/lib.rs b/src/lib.rs index 9caa8b0..8a96f33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ +// TODO clean up code... use futures::{ future::{join_all, BoxFuture}, - StreamExt, + StreamExt, TryStreamExt, }; use lapin::{ options::{ @@ -29,7 +30,7 @@ mod test; pub struct Worker { channel: Channel, /// A consumer for each rpc handler - rpc_consumers: Vec, + rpc_consumers: Vec, rpc_handlers: Vec< Arc< dyn Fn( @@ -40,7 +41,7 @@ pub struct Worker { >, >, /// A consumer for each non-rpc handler - consumers: Vec, + consumers: Vec, handlers: Vec< Arc< dyn Fn( @@ -59,6 +60,7 @@ pub struct TlsConfig { client_cert_and_key: String, client_cert_and_key_password: String, } + impl TlsConfig { /// Create a custom TLS config pub fn new( @@ -76,7 +78,6 @@ impl TlsConfig { #[derive(Debug)] /// A worker configuration -/// Enable tls here pub struct WorkerConfig { tls: Option, } @@ -186,8 +187,6 @@ impl Worker { .unwrap(), }; channel - - // TODO set qos for channel } /// Add a non-rpc listener to the worker object @@ -195,28 +194,19 @@ impl Worker { /// # Arguments /// * `state` - An Arc of the state object that will be passed to the listener /// * `listener_config` - An Arc of the state object that will be passed to the listener - pub async fn add_non_rpc_consumer( + /// ``` + /// use bunbun_worker::{Worker, ListenerConfig, WorkerConfig}; + /// let server = Worker::new("amqp://localhost:5672", Workerconfig::default()).await; + /// server.add_non_rpc_consumer::(ListenerConfig::default("service-jobname").set_message_version("v2.0.0") )); + /// server.start_all_listeners().await; + /// ``` + pub fn add_non_rpc_consumer( &mut self, state: Arc, listener_config: ListenerConfig, ) where ::State: std::marker::Send + Sync, { - let consumer = self - .channel - .basic_consume( - format!( - "{}-{}", - listener_config.queue_name, listener_config.message_version - ) - .as_str(), - &listener_config.consumer_tag, - BasicConsumeOptions::default(), - FieldTable::default(), - ) - .await - .expect("basic_consume error"); - let handler: Arc< dyn Fn( lapin::message::Delivery, @@ -261,7 +251,7 @@ impl Worker { }); self.handlers.push(handler); - self.consumers.push(consumer); + self.consumers.push(listener_config); } /// Add an rpc job listener to the worker object /// Make sure the type you pass in implements RPCTask @@ -273,8 +263,9 @@ impl Worker { /// # Examples /// /// ``` - /// let server = BunBunWorker::new("amqp://localhost:5672", None).await; - /// server.add_rpc_consumer::(ListenerConfig::default("service-jobname-v1.0.0") )).await; + /// use bunbun_worker::{Worker, ListenerConfig, WorkerConfig}; + /// let server = Worker::new("amqp://localhost:5672", Workerconfig::default()).await; + /// server.add_rpc_consumer::(ListenerConfig::default("service-jobname").set_message_version("v2.0.0") )); /// server.start_all_listeners().await; /// ``` pub async fn add_rpc_consumer( @@ -286,22 +277,6 @@ impl Worker { ::Result: std::marker::Send + Sync, ::ErroredResult: std::marker::Send + Sync, { - let consumer = create_consumer( - self.channel.clone(), - format!( - "{}-{}", - listener_config.queue_name, listener_config.message_version - ) - .as_str(), - &listener_config.consumer_tag, - listener_config.prefetch_count, - ) - .await - .map_err(|e| { - tracing::error!("Failed to create consumer: {}", e); - }) - .expect("Failed to create consumer"); - let channel = self.channel.clone(); let handler: Arc< dyn Fn( @@ -388,21 +363,53 @@ impl Worker { }); self.rpc_handlers.push(handler); - self.rpc_consumers.push(consumer); + self.rpc_consumers.push(listener_config); } /// Start all the listeners added to the worker object // TODO implement reconnect - pub async fn start_all_listeners(&self) { + // TODO better error handling + pub async fn start_all_listeners(&self) -> Result<(), String> { let mut listeners = vec![]; - for (handler, consumer) in self.handlers.iter().zip(self.consumers.iter()) { - let consumer = consumer.clone(); + + // Start all the non-rpc listeners + for (handler, consumer_config) in self.handlers.iter().zip(self.consumers.iter()) { + // Clone channel + let mut channel = self.channel.clone(); + + // Set prefetch count + set_consumer_qos(&mut channel, consumer_config.prefetch_count) + .await + .map_err(|e| { + tracing::error!("Failed to set qos: {}", e); + "Failed to set qos".to_string() + })?; + + // Create a consumer with modified channel + let consumer = channel + .basic_consume( + format!( + "{}-{}", + consumer_config.queue_name, consumer_config.message_version + ) + .as_str(), + &consumer_config.consumer_tag, + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await + .map_err(|e| { + tracing::error!("Failed to start consumer: {}", e); + "Failed to start consumer".to_string() + })?; + let handler = Arc::clone(handler); tracing::info!( "Started listening for incoming messages on queue: {} | Non-rpc", consumer.queue().as_str() ); + listeners.push(tokio::spawn(async move { consumer .for_each_concurrent(None, move |delivery| { @@ -422,8 +429,33 @@ impl Worker { })); } - for (handler, consumer) in self.rpc_handlers.iter().zip(self.rpc_consumers.iter()) { - let consumer = consumer.clone(); + for (handler, consumer_config) in self.rpc_handlers.iter().zip(self.rpc_consumers.iter()) { + let mut channel = self.channel.clone(); + // Set prefetch count + set_consumer_qos(&mut channel, consumer_config.prefetch_count) + .await + .map_err(|e| { + tracing::error!("Failed to set qos: {}", e); + "Failed to set qos".to_string() + })?; + + // Create a consumer with modified channel + let consumer = channel + .basic_consume( + format!( + "{}-{}", + consumer_config.queue_name, consumer_config.message_version + ) + .as_str(), + &consumer_config.consumer_tag, + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await + .map_err(|e| { + tracing::error!("Failed to start consumer: {}", e); + "Failed to start consumer".to_string() + })?; let handler = Arc::clone(handler); tracing::debug!( @@ -448,7 +480,9 @@ impl Worker { .await; })); } + join_all(listeners).await; + Ok(()) } } @@ -570,6 +604,7 @@ pub trait Task: Sized + Debug + DeserializeOwned { Ok(job) } + /// The method that will be run by the worker fn run(self, state: Arc) -> BoxFuture<'static, Result<(), ()>>; /// A function to display the task @@ -702,3 +737,9 @@ async fn create_consumer( ) .await } + +async fn set_consumer_qos(channel: &mut Channel, prefetch_count: u16) -> Result<(), lapin::Error> { + channel + .basic_qos(prefetch_count, BasicQosOptions::default()) + .await +} diff --git a/src/test/mod.rs b/src/test/mod.rs index 35f367f..3f01bec 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -3,7 +3,7 @@ mod test { use std::{env, os::unix::thread, sync::Arc, time::Duration}; - use futures::future::join_all; + use futures::{future::join_all, FutureExt}; use serde::{Deserialize, Serialize}; use test_log::test; use tokio::{ @@ -52,13 +52,14 @@ mod test { state: Arc, ) -> futures::prelude::future::BoxFuture<'static, Result> { - Box::pin(async move { + async move { tracing::info!("Sent email to {}", self.send_to); tokio::time::sleep(Duration::from_secs(2)).await; return Ok(EmailJobResult { contents: self.contents.clone(), }); - }) + } + .boxed() } async fn before_job(self) -> Self where