diff --git a/src/lib.rs b/src/lib.rs index 766cd4c..603d71d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,7 @@ -use futures::{future::BoxFuture, FutureExt, StreamExt}; +use futures::{ + future::{join_all, BoxFuture}, + FutureExt, StreamExt, +}; use lapin::{ options::{ BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions, @@ -21,6 +24,16 @@ mod test; pub struct BunBunWorker { channel: Channel, + rpc_consumers: Vec, + rpc_handlers: Vec< + Arc< + dyn Fn( + lapin::message::Delivery, + ) -> Pin + Send>> + + Send + + Sync, + >, + >, consumers: Vec, handlers: Vec< Arc< @@ -40,6 +53,9 @@ impl BunBunWorker { channel, handlers: Vec::new(), consumers: Vec::new(), + + rpc_handlers: Vec::new(), + rpc_consumers: Vec::new(), } } @@ -59,7 +75,53 @@ impl BunBunWorker { None => conn.create_channel().await.expect("create channel error"), } } + pub async fn add_non_rpc_consumer( + &mut self, + queue_name: &str, + consumer_tag: &str, + state: Arc>, + ) where + ::State: std::marker::Send, + { + let consumer = self + .channel + .basic_consume( + queue_name, + consumer_tag, + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await + .expect("basic_consume error"); + let handler: Arc< + dyn Fn( + lapin::message::Delivery, + ) -> Pin + Send>> + + Send + + Sync, + > = Arc::new(move |delivery: lapin::message::Delivery| { + let state = Arc::clone(&state); + Box::pin(async move { + if let Ok(job) = J::decode(delivery.data.clone()) { + // Running job + match tokio::task::spawn(async move { job.run(state).await }).await { + Err(error) => { + tracing::error!("Failed to run non-rpc job: {}", error) + } + Ok(_) => { + tracing::info!("Non-rpc job has finished.") + } + }; + } else { + delivery.nack(BasicNackOptions::default()).await.unwrap(); + } + }) + }); + + self.handlers.push(handler); + self.consumers.push(consumer); + } pub async fn add_rpc_consumer( &mut self, queue_name: &str, @@ -163,11 +225,12 @@ impl BunBunWorker { }) }); - self.handlers.push(handler); - self.consumers.push(consumer); + self.rpc_handlers.push(handler); + self.rpc_consumers.push(consumer); } pub async fn start_all_listeners(&self) { + let mut listeners = vec![]; for (handler, consumer) in self.handlers.iter().zip(self.consumers.iter()) { let consumer = consumer.clone(); let handler = Arc::clone(handler); @@ -176,7 +239,7 @@ impl BunBunWorker { "Listening for incoming messages for queue: {}", consumer.queue().as_str() ); - tokio::spawn(async move { + listeners.push(tokio::spawn(async move { consumer .for_each_concurrent(None, move |delivery| { let handler = Arc::clone(&handler); @@ -187,11 +250,31 @@ impl BunBunWorker { } }) .await; - }); + })); } - signal::ctrl_c().await.expect("failed to listen for event"); - // TODO hand the program + for (handler, consumer) in self.rpc_handlers.iter().zip(self.rpc_consumers.iter()) { + let consumer = consumer.clone(); + let handler = Arc::clone(handler); + + tracing::debug!( + "Listening for incoming messages for queue: {}", + consumer.queue().as_str() + ); + listeners.push(tokio::spawn(async move { + consumer + .for_each_concurrent(None, move |delivery| { + let handler = Arc::clone(&handler); + // TODO handle unwrap of delivery + let delivery = delivery.unwrap(); + async move { + handler(delivery).await; + } + }) + .await; + })); + } + join_all(listeners).await; } } @@ -223,6 +306,29 @@ pub trait RPCServerTask: Sized + Debug + DeserializeOwned { format!("{:?}", self) } } +pub trait NonRPCServerTask: Sized + Debug + DeserializeOwned { + type State: Clone + Debug; + + fn decode(data: Vec) -> Result { + let job = match from_utf8(&data) { + Err(_) => { + return Err(RabbitDecodeError::NotUtf8); + } + Ok(data) => match serde_json::from_str::(data) { + Err(_) => return Err(RabbitDecodeError::NotJson), + Ok(data) => data, + }, + }; + Ok(job) + } + + fn run(self, state: Arc>) -> BoxFuture<'static, Result<(), ()>>; + + /// A function to display the task + fn display(&self) -> String { + format!("{:?}", self) + } +} #[derive(Debug)] pub enum RabbitDecodeError {