From ee1bb1abdfbe7370e2a13cdddae779ff4db77e5d Mon Sep 17 00:00:00 2001 From: 4o1x5 <4o1x5@4o1x5.dev> Date: Thu, 14 Nov 2024 19:28:10 +0100 Subject: [PATCH] client: added timeout - also created a new basiccalloptions that includes message versioning and such server: now message versioning works. --- src/client.rs | 85 ++++++++++++++++++++++++++++------------ src/lib.rs | 101 ++++++++++++++++++++++++++++++++++++++++-------- src/test/mod.rs | 38 ++++++++++++++---- 3 files changed, 176 insertions(+), 48 deletions(-) diff --git a/src/client.rs b/src/client.rs index ff7d722..8e52d8d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,7 +1,7 @@ use std::{ fmt::{Debug, Display}, str::from_utf8, - time::Instant, + time::{Duration, Instant}, }; use futures::StreamExt; @@ -10,6 +10,7 @@ use lapin::{ ConnectionProperties, }; use serde::{de::DeserializeOwned, Serialize}; +use tokio::time::timeout; use uuid::Uuid; use crate::ResultHeader; @@ -47,7 +48,7 @@ impl Client { pub async fn rpc_call( &self, data: T, - queue_name: &str, + options: BasicCallOptions, ) -> Result, RpcClientError> where T: Serialize + DeserializeOwned, @@ -94,7 +95,7 @@ impl Client { match channel .basic_publish( "", - queue_name, + format!("{}-{}", &options.queue_name, options.message_version).as_str(), BasicPublishOptions::default(), serde_json::to_string(&data).unwrap().as_bytes(), BasicProperties::default() @@ -112,18 +113,18 @@ impl Client { } Ok(confirmation) => { tracing::info!( - "Sent RPC job of type {} to channel {} Ack: {}", + "Sent RPC job of type {} to channel {} Ack: {} Ver: {}", std::any::type_name::(), - queue_name, - confirmation.is_ack() + options.queue_name, + confirmation.is_ack(), + options.message_version ); } }, } - // TODO implement timeout tracing::debug!("Awaiting response from callback queue"); - let del = loop { + let listen = async move { match consumer.next().await { None => { tracing::error!("Received empty data after {:?}", now.elapsed()); @@ -141,12 +142,26 @@ impl Client { } Ok(del) => { tracing::debug!("Received response after {:?}", now.elapsed()); - break del; + return Ok(del); } }, }; }; + let del = match options.timeout { + None => listen.await?, + Some(dur) => match timeout(dur, listen).await { + Err(elapsed) => { + tracing::warn!("RPC job has reached timeout after: {}", elapsed); + return Err(RpcClientError::TimeoutReached); + } + Ok(r) => match r { + Err(error) => return Err(error), + Ok(r) => r, + }, + }, + }; + // TODO better implementation of this tracing::debug!("Decoding headers"); let result_type = match del.properties.headers().to_owned() { @@ -181,6 +196,7 @@ impl Client { }, }, }; + tracing::debug!("Result type is: {result_type}, decoding..."); let utf8 = match from_utf8(&del.data) { Ok(r) => r, @@ -190,7 +206,7 @@ impl Client { } }; let _ = channel.close(0, "byebye").await; - // acking for idk reason + let _ = del.ack(BasicAckOptions::default()).await; match result_type { ResultHeader::Error => match serde_json::from_str::(utf8) { @@ -216,16 +232,9 @@ impl Client { } // ack message } - /// Sends a message to the queue - /// - /// # Examples - /// - /// ``` - /// use bunbun_worker::client::Client; - /// let client = Client::new("amqp://127.0.0.1:5672"); - /// let result = client.call(EmailJob::new("someone@example.com", "Hello there"), "email-emailjob-v1.0.0"); - /// ``` - pub async fn call(&self, data: T, queue_name: &str) -> Result<(), ClientError> + + /// Sends a basic Task to the queue + pub async fn call(&self, data: T, options: BasicCallOptions) -> Result<(), ClientError> where T: Serialize + DeserializeOwned, { @@ -233,7 +242,7 @@ impl Client { match channel .basic_publish( "", - queue_name, + format!("{}-{}", &options.queue_name, options.message_version).as_str(), BasicPublishOptions::default(), serde_json::to_string(&data).unwrap().as_bytes(), BasicProperties::default(), @@ -254,10 +263,11 @@ impl Client { Ok(confirmation) => { let _ = channel.close(0, "byebye").await; tracing::info!( - "Sent nonRPC job of type {} to channel {} Ack: {}", + "Sent nonRPC job of type {} to channel {} Ack: {} Ver: {}", std::any::type_name::(), - queue_name, - confirmation.is_ack() + options.queue_name, + confirmation.is_ack(), + options.message_version ); tracing::debug!( "AMQP confirmed dispatch of job | Acknowledged? {}", @@ -269,15 +279,40 @@ impl Client { Ok(()) } } +/// A call option class that is used to control how calls are handled +/// You can define the timeout, and the message versions +pub struct BasicCallOptions { + timeout: Option, + queue_name: String, + message_version: String, +} +impl BasicCallOptions { + pub fn default(queue_name: impl Into) -> Self { + Self { + timeout: None, + queue_name: queue_name.into(), + message_version: "v1.0.0".into(), + } + } + pub fn message_version(mut self, message_version: impl Into) -> Self { + self.message_version = message_version.into(); + self + } + pub fn timeout(mut self, timeout: Duration) -> Self { + self.timeout = Some(timeout); + self + } +} /// An error that the client returns #[derive(Debug)] pub enum RpcClientError { - NoReply, // TODO timeout + NoReply, FailedDecode, FailedToSend, InvalidResponse, ServerPanicked, + TimeoutReached, } /// An error for normal calls #[derive(Debug)] diff --git a/src/lib.rs b/src/lib.rs index 11c227a..9caa8b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,7 +90,7 @@ impl WorkerConfig { /// /// # Arguments /// * `custom_tls` - Optional TLSconfig (if none defaults to lapins choice) - pub fn enable_tls(&mut self, custom_tls: Option) { + pub fn enable_tls(mut self, custom_tls: Option) -> Self { match custom_tls { Some(tls) => { let tls = OwnedTLSConfig { @@ -104,6 +104,45 @@ impl WorkerConfig { } None => self.tls = OwnedTLSConfig::default().into(), } + self + } +} + +/// A worker configuration + +pub struct ListenerConfig { + prefetch_count: u16, + queue_name: String, + consumer_tag: String, + message_version: String, +} + +impl ListenerConfig { + /// Create a new listener config + /// # Arguments + /// * `queue_name` - The name of the queue to listen to (e.g. service-serviceJobName-v1.0.0) + pub fn default(queue_name: impl Into) -> Self { + Self { + prefetch_count: 0, + queue_name: queue_name.into(), + consumer_tag: "".into(), + message_version: "v1.0.0".into(), + } + } + /// Set the prefetch count for the listener + /// This serves as a maximum job count that can be processed at a time. (0 is unlimited) + pub fn set_prefetch_count(mut self, prefetch_count: u16) -> Self { + self.prefetch_count = prefetch_count; + self + } + /// Set the consumer tag for the listener + pub fn set_consumer_tag(mut self, consumer_tag: impl Into) -> Self { + self.consumer_tag = consumer_tag.into(); + self + } + pub fn set_message_version(mut self, version: impl Into) -> Self { + self.message_version = version.into(); + self } } @@ -154,20 +193,24 @@ impl Worker { /// Add a non-rpc listener to the worker object /// /// # Arguments - /// * `queue_name` - A string slice that holds the name of the queue to listen to (e.g. service-serviceJobName-v1.0.0) /// * `state` - An Arc of the state object that will be passed to the listener + /// * `listener_config` - An Arc of the state object that will be passed to the listener pub async fn add_non_rpc_consumer( &mut self, - queue_name: &str, state: Arc, + listener_config: ListenerConfig, ) where ::State: std::marker::Send + Sync, { let consumer = self .channel .basic_consume( - queue_name, - "", + format!( + "{}-{}", + listener_config.queue_name, listener_config.message_version + ) + .as_str(), + &listener_config.consumer_tag, BasicConsumeOptions::default(), FieldTable::default(), ) @@ -231,28 +274,33 @@ impl Worker { /// /// ``` /// let server = BunBunWorker::new("amqp://localhost:5672", None).await; - /// server.add_rpc_consumer::("service-serviceJobName-v1.0.0", SomeState{} )).await; + /// server.add_rpc_consumer::(ListenerConfig::default("service-jobname-v1.0.0") )).await; /// server.start_all_listeners().await; /// ``` pub async fn add_rpc_consumer( &mut self, - queue_name: &str, state: Arc, + listener_config: ListenerConfig, ) where ::State: std::marker::Send + Sync, ::Result: std::marker::Send + Sync, ::ErroredResult: std::marker::Send + Sync, { - let consumer = self - .channel - .basic_consume( - queue_name, - "", - BasicConsumeOptions::default(), - FieldTable::default(), + let consumer = create_consumer( + self.channel.clone(), + format!( + "{}-{}", + listener_config.queue_name, listener_config.message_version ) - .await - .expect("basic_consume error"); + .as_str(), + &listener_config.consumer_tag, + listener_config.prefetch_count, + ) + .await + .map_err(|e| { + tracing::error!("Failed to create consumer: {}", e); + }) + .expect("Failed to create consumer"); let channel = self.channel.clone(); let handler: Arc< @@ -633,3 +681,24 @@ impl Display for ResultHeader { } } } + +async fn create_consumer( + channel: Channel, + queue_name: &str, + consumer_tag: &str, + prefect_count: u16, +) -> Result { + let channel = channel.clone(); + channel + .basic_qos(prefect_count, BasicQosOptions::default()) + .await?; + + channel + .basic_consume( + queue_name, + consumer_tag, + BasicConsumeOptions::default(), + FieldTable::default(), + ) + .await +} diff --git a/src/test/mod.rs b/src/test/mod.rs index 3abea87..35f367f 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -15,8 +15,8 @@ mod test { use tracing_test::traced_test; use crate::{ - client::{Client, RPCClientTask}, - RPCTask, Worker, WorkerConfig, + client::{BasicCallOptions, Client, RPCClientTask}, + ListenerConfig, RPCTask, Worker, WorkerConfig, }; #[derive(Clone, Debug)] @@ -108,10 +108,10 @@ mod test { .await; listener .add_rpc_consumer::( - "email-emailjob-v1.0.0", Arc::new(State { something: "test".into(), }), + ListenerConfig::default("emailjob").set_prefetch_count(100), ) .await; tracing::debug!("Starting listener"); @@ -129,10 +129,10 @@ mod test { .await; listener .add_rpc_consumer::( - "email-emailjob-v1.0.0", Arc::new(State { something: "test".into(), }), + ListenerConfig::default("emailjob").set_prefetch_count(100), ) .await; tracing::debug!("Starting listener"); @@ -143,7 +143,7 @@ mod test { #[traced_test] async fn rpc_client() { // - let mut client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) + let client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) .await .unwrap(); let result = client @@ -152,7 +152,31 @@ mod test { send_to: "someone".into(), contents: "something".into(), }, - "email-emailjob-v1.0.0", + BasicCallOptions::default("emailjob"), + ) + .await + .unwrap(); + assert_eq!( + result, + Ok(EmailJobResult { + contents: "something".to_string() + }) + ) + } + #[test(tokio::test)] + #[traced_test] + async fn rpc_client_timeout() { + // + let client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) + .await + .unwrap(); + let result = client + .rpc_call::( + EmailJob { + send_to: "someone".into(), + contents: "something".into(), + }, + BasicCallOptions::default("emailjob").timeout(Duration::from_secs(3)), ) .await .unwrap(); @@ -184,7 +208,7 @@ mod test { send_to: "someone".into(), contents: "something".into(), }, - "email-emailjob-v1.0.0", + BasicCallOptions::default("emailjob"), ) .await }));