-made the add methods blocking,
- added back prefetch setter
- bit of docs and other shenanigans
This commit is contained in:
2005 2024-11-15 23:10:03 +01:00
parent 4012f84886
commit e68c88ad60
3 changed files with 97 additions and 51 deletions

View file

@ -43,7 +43,7 @@ impl Client {
/// ///
/// Arguments /// Arguments
/// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize /// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize
/// * `queue_name` The name of the queue to be sent to /// * `options` BasicCallOptions, used to control the timeout and message version
// TODO if the queue is nonexistent return error // TODO if the queue is nonexistent return error
pub async fn rpc_call<T: RPCClientTask + Send + Debug>( pub async fn rpc_call<T: RPCClientTask + Send + Debug>(
&self, &self,
@ -235,6 +235,10 @@ impl Client {
} }
/// Sends a basic Task to the queue /// Sends a basic Task to the queue
///
/// Arguments
/// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize
/// * `options` BasicCallOptions, used to control the timeout and message version
pub async fn call<T>(&self, data: T, options: BasicCallOptions) -> Result<(), ClientError> pub async fn call<T>(&self, data: T, options: BasicCallOptions) -> Result<(), ClientError>
where where
T: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned,

View file

@ -1,6 +1,7 @@
// TODO clean up code...
use futures::{ use futures::{
future::{join_all, BoxFuture}, future::{join_all, BoxFuture},
StreamExt, StreamExt, TryStreamExt,
}; };
use lapin::{ use lapin::{
options::{ options::{
@ -29,7 +30,7 @@ mod test;
pub struct Worker { pub struct Worker {
channel: Channel, channel: Channel,
/// A consumer for each rpc handler /// A consumer for each rpc handler
rpc_consumers: Vec<Consumer>, rpc_consumers: Vec<ListenerConfig>,
rpc_handlers: Vec< rpc_handlers: Vec<
Arc< Arc<
dyn Fn( dyn Fn(
@ -40,7 +41,7 @@ pub struct Worker {
>, >,
>, >,
/// A consumer for each non-rpc handler /// A consumer for each non-rpc handler
consumers: Vec<Consumer>, consumers: Vec<ListenerConfig>,
handlers: Vec< handlers: Vec<
Arc< Arc<
dyn Fn( dyn Fn(
@ -59,6 +60,7 @@ pub struct TlsConfig {
client_cert_and_key: String, client_cert_and_key: String,
client_cert_and_key_password: String, client_cert_and_key_password: String,
} }
impl TlsConfig { impl TlsConfig {
/// Create a custom TLS config /// Create a custom TLS config
pub fn new( pub fn new(
@ -76,7 +78,6 @@ impl TlsConfig {
#[derive(Debug)] #[derive(Debug)]
/// A worker configuration /// A worker configuration
/// Enable tls here
pub struct WorkerConfig { pub struct WorkerConfig {
tls: Option<OwnedTLSConfig>, tls: Option<OwnedTLSConfig>,
} }
@ -186,8 +187,6 @@ impl Worker {
.unwrap(), .unwrap(),
}; };
channel channel
// TODO set qos for channel
} }
/// Add a non-rpc listener to the worker object /// Add a non-rpc listener to the worker object
@ -195,28 +194,19 @@ impl Worker {
/// # Arguments /// # Arguments
/// * `state` - An Arc of the state object that will be passed to the listener /// * `state` - An Arc of the state object that will be passed to the listener
/// * `listener_config` - An Arc of the state object that will be passed to the listener /// * `listener_config` - An Arc of the state object that will be passed to the listener
pub async fn add_non_rpc_consumer<J: Task + 'static + Send>( /// ```
/// use bunbun_worker::{Worker, ListenerConfig, WorkerConfig};
/// let server = Worker::new("amqp://localhost:5672", Workerconfig::default()).await;
/// server.add_non_rpc_consumer::<MyTask>(ListenerConfig::default("service-jobname").set_message_version("v2.0.0") ));
/// server.start_all_listeners().await;
/// ```
pub fn add_non_rpc_consumer<J: Task + 'static + Send>(
&mut self, &mut self,
state: Arc<J::State>, state: Arc<J::State>,
listener_config: ListenerConfig, listener_config: ListenerConfig,
) where ) where
<J as Task>::State: std::marker::Send + Sync, <J as Task>::State: std::marker::Send + Sync,
{ {
let consumer = self
.channel
.basic_consume(
format!(
"{}-{}",
listener_config.queue_name, listener_config.message_version
)
.as_str(),
&listener_config.consumer_tag,
BasicConsumeOptions::default(),
FieldTable::default(),
)
.await
.expect("basic_consume error");
let handler: Arc< let handler: Arc<
dyn Fn( dyn Fn(
lapin::message::Delivery, lapin::message::Delivery,
@ -261,7 +251,7 @@ impl Worker {
}); });
self.handlers.push(handler); self.handlers.push(handler);
self.consumers.push(consumer); self.consumers.push(listener_config);
} }
/// Add an rpc job listener to the worker object /// Add an rpc job listener to the worker object
/// Make sure the type you pass in implements RPCTask /// Make sure the type you pass in implements RPCTask
@ -273,8 +263,9 @@ impl Worker {
/// # Examples /// # Examples
/// ///
/// ``` /// ```
/// let server = BunBunWorker::new("amqp://localhost:5672", None).await; /// use bunbun_worker::{Worker, ListenerConfig, WorkerConfig};
/// server.add_rpc_consumer::<MyRPCTask>(ListenerConfig::default("service-jobname-v1.0.0") )).await; /// let server = Worker::new("amqp://localhost:5672", Workerconfig::default()).await;
/// server.add_rpc_consumer::<MyRPCTask>(ListenerConfig::default("service-jobname").set_message_version("v2.0.0") ));
/// server.start_all_listeners().await; /// server.start_all_listeners().await;
/// ``` /// ```
pub async fn add_rpc_consumer<J: RPCTask + 'static + Send>( pub async fn add_rpc_consumer<J: RPCTask + 'static + Send>(
@ -286,22 +277,6 @@ impl Worker {
<J as RPCTask>::Result: std::marker::Send + Sync, <J as RPCTask>::Result: std::marker::Send + Sync,
<J as RPCTask>::ErroredResult: std::marker::Send + Sync, <J as RPCTask>::ErroredResult: std::marker::Send + Sync,
{ {
let consumer = create_consumer(
self.channel.clone(),
format!(
"{}-{}",
listener_config.queue_name, listener_config.message_version
)
.as_str(),
&listener_config.consumer_tag,
listener_config.prefetch_count,
)
.await
.map_err(|e| {
tracing::error!("Failed to create consumer: {}", e);
})
.expect("Failed to create consumer");
let channel = self.channel.clone(); let channel = self.channel.clone();
let handler: Arc< let handler: Arc<
dyn Fn( dyn Fn(
@ -388,21 +363,53 @@ impl Worker {
}); });
self.rpc_handlers.push(handler); self.rpc_handlers.push(handler);
self.rpc_consumers.push(consumer); self.rpc_consumers.push(listener_config);
} }
/// Start all the listeners added to the worker object /// Start all the listeners added to the worker object
// TODO implement reconnect // TODO implement reconnect
pub async fn start_all_listeners(&self) { // TODO better error handling
pub async fn start_all_listeners(&self) -> Result<(), String> {
let mut listeners = vec![]; let mut listeners = vec![];
for (handler, consumer) in self.handlers.iter().zip(self.consumers.iter()) {
let consumer = consumer.clone(); // Start all the non-rpc listeners
for (handler, consumer_config) in self.handlers.iter().zip(self.consumers.iter()) {
// Clone channel
let mut channel = self.channel.clone();
// Set prefetch count
set_consumer_qos(&mut channel, consumer_config.prefetch_count)
.await
.map_err(|e| {
tracing::error!("Failed to set qos: {}", e);
"Failed to set qos".to_string()
})?;
// Create a consumer with modified channel
let consumer = channel
.basic_consume(
format!(
"{}-{}",
consumer_config.queue_name, consumer_config.message_version
)
.as_str(),
&consumer_config.consumer_tag,
BasicConsumeOptions::default(),
FieldTable::default(),
)
.await
.map_err(|e| {
tracing::error!("Failed to start consumer: {}", e);
"Failed to start consumer".to_string()
})?;
let handler = Arc::clone(handler); let handler = Arc::clone(handler);
tracing::info!( tracing::info!(
"Started listening for incoming messages on queue: {} | Non-rpc", "Started listening for incoming messages on queue: {} | Non-rpc",
consumer.queue().as_str() consumer.queue().as_str()
); );
listeners.push(tokio::spawn(async move { listeners.push(tokio::spawn(async move {
consumer consumer
.for_each_concurrent(None, move |delivery| { .for_each_concurrent(None, move |delivery| {
@ -422,8 +429,33 @@ impl Worker {
})); }));
} }
for (handler, consumer) in self.rpc_handlers.iter().zip(self.rpc_consumers.iter()) { for (handler, consumer_config) in self.rpc_handlers.iter().zip(self.rpc_consumers.iter()) {
let consumer = consumer.clone(); let mut channel = self.channel.clone();
// Set prefetch count
set_consumer_qos(&mut channel, consumer_config.prefetch_count)
.await
.map_err(|e| {
tracing::error!("Failed to set qos: {}", e);
"Failed to set qos".to_string()
})?;
// Create a consumer with modified channel
let consumer = channel
.basic_consume(
format!(
"{}-{}",
consumer_config.queue_name, consumer_config.message_version
)
.as_str(),
&consumer_config.consumer_tag,
BasicConsumeOptions::default(),
FieldTable::default(),
)
.await
.map_err(|e| {
tracing::error!("Failed to start consumer: {}", e);
"Failed to start consumer".to_string()
})?;
let handler = Arc::clone(handler); let handler = Arc::clone(handler);
tracing::debug!( tracing::debug!(
@ -448,7 +480,9 @@ impl Worker {
.await; .await;
})); }));
} }
join_all(listeners).await; join_all(listeners).await;
Ok(())
} }
} }
@ -570,6 +604,7 @@ pub trait Task: Sized + Debug + DeserializeOwned {
Ok(job) Ok(job)
} }
/// The method that will be run by the worker
fn run(self, state: Arc<Self::State>) -> BoxFuture<'static, Result<(), ()>>; fn run(self, state: Arc<Self::State>) -> BoxFuture<'static, Result<(), ()>>;
/// A function to display the task /// A function to display the task
@ -702,3 +737,9 @@ async fn create_consumer(
) )
.await .await
} }
async fn set_consumer_qos(channel: &mut Channel, prefetch_count: u16) -> Result<(), lapin::Error> {
channel
.basic_qos(prefetch_count, BasicQosOptions::default())
.await
}

View file

@ -3,7 +3,7 @@ mod test {
use std::{env, os::unix::thread, sync::Arc, time::Duration}; use std::{env, os::unix::thread, sync::Arc, time::Duration};
use futures::future::join_all; use futures::{future::join_all, FutureExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use test_log::test; use test_log::test;
use tokio::{ use tokio::{
@ -52,13 +52,14 @@ mod test {
state: Arc<Self::State>, state: Arc<Self::State>,
) -> futures::prelude::future::BoxFuture<'static, Result<Self::Result, Self::ErroredResult>> ) -> futures::prelude::future::BoxFuture<'static, Result<Self::Result, Self::ErroredResult>>
{ {
Box::pin(async move { async move {
tracing::info!("Sent email to {}", self.send_to); tracing::info!("Sent email to {}", self.send_to);
tokio::time::sleep(Duration::from_secs(2)).await; tokio::time::sleep(Duration::from_secs(2)).await;
return Ok(EmailJobResult { return Ok(EmailJobResult {
contents: self.contents.clone(), contents: self.contents.clone(),
}); });
}) }
.boxed()
} }
async fn before_job(self) -> Self async fn before_job(self) -> Self
where where