diff --git a/Cargo.toml b/Cargo.toml index fe9a797..fb4c6e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bunbun-worker" -version = "0.1.1" +version = "0.2.0" description = "An rpc/non-rpc rabbitmq worker library" edition = "2021" license = "AGPL-3.0" @@ -11,8 +11,9 @@ keywords = ["worker", "rpc", "rabbitmq"] [dependencies] async-trait = "0.1.83" +derive_setters = "0.1.6" futures = "0.3.31" -lapin = "2.5.0" +lapin = { version = "2.5.0", features = ["rustls"] } serde = "1.0.213" serde_json = "1.0.132" tokio = { version = "1.41.0", features = ["full"] } diff --git a/README.md b/README.md index cbd7f12..9cd1e30 100644 --- a/README.md +++ b/README.md @@ -49,130 +49,11 @@ bunbun-worker = { git = "https://git.4o1x5.dev/4o1x5/bunbun-worker", branch = "m Here is a basic implementation of an RPC job in bunbun-worker -```rust -// server - -// First let's create a state that will be used inside a job. -// Imagine this holding a database connection, some context that may need to be changed. Anything really -#[derive(Clone, Debug)] -pub struct State { - pub something: String, -} - -/// Second, let's create a job, with field that can be serialized/deserialized into JSON -/// This is what the server will receive from a client and will do the job based on these properties -#[derive(Deserialize, Serialize, Clone, Debug)] -pub struct EmailJob { - send_to: String, - contents: String, -} -// We also create a result for it, since it's an RPC job -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] -pub struct EmailJobResult { - contents: String, -} -// And an error type so know if the other end errored out, what to do. -#[derive(Deserialize, Serialize, Clone, Debug)] -pub enum EmailJobResultError { - Errored, -} - -/// After all this we implement a Jobrunner/Taskrunner to the type, so when the listener receives it, it can run this piece of code. -impl RPCServerTask for EmailJob { - type ErroredResult = EmailJobResultError; - type Result = EmailJobResult; - type State = State; - fn run( - self, - state: Self::State, - ) -> futures::prelude::future::BoxFuture<'static, Result> - { - Box::pin(async move { - tracing::info!("Sent email to {}", self.send_to); - tokio::time::sleep(Duration::from_secs(2)).await; - return Ok(EmailJobResult { - contents: self.contents.clone(), - }); - }) - } -} -// Finally, we can define an async main function to run the listener in. -#[tokio::main] -async fn main(){ - // Define a listener, with a hard-limit of 100 jobs at once. - let mut listener = - BunBunWorker::new(env::var("AMQP_SERVER_URL").unwrap(), 100.into()).await; - // Add the defined sturct to the worker - listener - .add_rpc_consumer::( - "email-emailjob-v1.0.0", // queue name - State { - something: "test".into(), // putting our state into a Arc> for thread safety - }, - ) - .await; - tracing::debug!("Starting listener"); - // Starting the listener - listener.start_all_listeners().await; -} -``` - -Instal `futures_util` for the client - -``` -cargo add futures_util -``` - -```rust -// client - -// Define the same structs we did. These are DTO's after all.. -#[derive(Deserialize, Serialize, Clone, Debug)] -pub struct EmailJob { - send_to: String, - contents: String, -} -#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] -pub struct EmailJobResult { - contents: String, -} -#[derive(Deserialize, Serialize, Clone, Debug)] -pub enum EmailJobResultError { - Errored, -} -// Now we implement the clientside task for it. This reduces generics when defining the calling.. -impl RPCClientTask for EmailJob { - type ErroredResult = EmailJobResultError; - type Result = EmailJobResult; -} - -#[tokio::main] -async fn main(){ -// Define a client - let mut client = BunBunClient::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) - .await - .unwrap(); - // Make a call - let result = client - .rpc_call::( - // Define the job - EmailJob { - send_to: "someone".into(), - contents: "something".into(), - }, - "email-emailjob-v1.0.0", // the queue name - ) - .await - .unwrap(); -} -``` - # Limitations 1. Currently some `unwrap()`'s are called inside the code and may results in panics (not in the job-runner). -2. No TLS support -3. No settings, and very limited API -4. The rabbitmq RPC logic is very basic with no message-versioning (aside using different queue names (see [usage](#usage)) ) +2. No settings, and very limited API +3. The rabbitmq RPC logic is very basic with no message-versioning (aside using different queue names (eg service-class-v1.0.0) ) # Bugs department diff --git a/src/client.rs b/src/client.rs index 88a9684..ff7d722 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,22 +16,21 @@ use crate::ResultHeader; /// A client for the server part of `bunbun-worker` #[derive(Debug)] -pub struct BunBunClient { +pub struct Client { conn: Connection, } -// TODO implement reconnect // TODO implement tls -impl BunBunClient { +impl Client { /// Creates an rpc client /// /// # Examples /// /// ``` /// // Create a client and send a message - /// use bunbun_worker::client::BunBunClient; - /// let client = BunBunClient::new("amqp://127.0.0.1:5672"); + /// use bunbun_worker::client::Client; + /// let client = Client::new("amqp://127.0.0.1:5672"); /// ``` pub async fn new(address: &str) -> Result { let conn = Connection::connect(address, ConnectionProperties::default()).await?; @@ -39,6 +38,12 @@ impl BunBunClient { Ok(Self { conn }) } + /// A method to call for a RPC + /// + /// Arguments + /// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize + /// * `queue_name` The name of the queue to be sent to + // TODO if the queue is nonexistent return error pub async fn rpc_call( &self, data: T, @@ -265,7 +270,7 @@ impl BunBunClient { } } -/// An error that the bunbunclient returns +/// An error that the client returns #[derive(Debug)] pub enum RpcClientError { NoReply, // TODO timeout @@ -288,6 +293,28 @@ impl Display for ClientError { } } +/// A Client-side trait that needs to be implemented for a type in order for the client to know return types. +/// +/// Examples +/// ``` +/// #[derive(Deserialize, Serialize, Clone, Debug)] +/// pub struct EmailJob { +/// send_to: String, +/// contents: String, +/// } +/// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +/// pub struct EmailJobResult { +/// contents: String, +/// } +/// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] +/// pub enum EmailJobResultError { +/// Errored, +/// } +/// impl RPCClientTask for EmailJob { +/// type ErroredResult = EmailJobResultError; +/// type Result = EmailJobResult; +/// } +/// ``` pub trait RPCClientTask: Sized + Debug + DeserializeOwned { type Result: Serialize + DeserializeOwned + Debug; type ErroredResult: Serialize + DeserializeOwned + Debug; diff --git a/src/lib.rs b/src/lib.rs index b197f24..11c227a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ use lapin::{ BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions, BasicQosOptions, }, + tcp::{OwnedIdentity, OwnedTLSConfig}, types::{DeliveryTag, FieldTable, ShortString}, BasicProperties, Channel, Connection, ConnectionProperties, Consumer, }; @@ -14,16 +15,18 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{ fmt::{Debug, Display}, pin::Pin, - str::from_utf8, + str::{from_utf8, FromStr}, sync::Arc, }; +use tokio::task::JoinError; /// The client module that interacts with the server part of `bunbun-worker` pub mod client; mod test; -pub struct BunBunWorker { +/// The worker object that contains all the threads and runners. +pub struct Worker { channel: Channel, /// A consumer for each rpc handler rpc_consumers: Vec, @@ -49,16 +52,71 @@ pub struct BunBunWorker { >, } +// Example taken from https://github.com/amqp-rs/lapin/blob/main/examples/client-certificate.rs +/// Custom certificate configuration +pub struct TlsConfig { + cert_chain: String, + client_cert_and_key: String, + client_cert_and_key_password: String, +} +impl TlsConfig { + /// Create a custom TLS config + pub fn new( + cert_chain: String, + client_cert_and_key: String, + client_cert_and_key_password: String, + ) -> Self { + Self { + cert_chain, + client_cert_and_key, + client_cert_and_key_password, + } + } +} + +#[derive(Debug)] +/// A worker configuration +/// Enable tls here +pub struct WorkerConfig { + tls: Option, +} +impl WorkerConfig { + /// Creates a new worker config + pub fn default() -> Self { + Self { tls: None } + } + + /// Enable secure connection to amqp. + /// + /// # Arguments + /// * `custom_tls` - Optional TLSconfig (if none defaults to lapins choice) + pub fn enable_tls(&mut self, custom_tls: Option) { + match custom_tls { + Some(tls) => { + let tls = OwnedTLSConfig { + identity: Some(OwnedIdentity { + der: tls.client_cert_and_key.as_bytes().to_vec(), + password: tls.client_cert_and_key_password, + }), + cert_chain: Some(tls.cert_chain.to_string()), + }; + self.tls = tls.into(); + } + None => self.tls = OwnedTLSConfig::default().into(), + } + } +} + // TODO implement reconnect -// TODO implement tls -impl BunBunWorker { +impl Worker { /// Create a new instance of `bunbun-worker` /// # Arguments /// * `amqp_server_url` - A string slice that holds the url of the amqp server (e.g. amqp://localhost:5672) - /// * `limit` - An optional u16 that holds the limit of the number of messages to prefetch 0 by default - pub async fn new(amqp_server_url: impl Into, limit: Option) -> Self { - let channel = Self::create_channel(amqp_server_url.into(), limit).await; - BunBunWorker { + /// * `config` - A worker config, containing the TLS config for now. + pub async fn new(amqp_server_url: impl Into, config: WorkerConfig) -> Self { + let channel = Self::create_channel(amqp_server_url.into(), config).await; + + Worker { channel, handlers: Vec::new(), consumers: Vec::new(), @@ -68,33 +126,42 @@ impl BunBunWorker { } } - async fn create_channel(amqp_server_url: String, limit: Option) -> Channel { - let conn = Connection::connect(&amqp_server_url, ConnectionProperties::default()) + async fn create_channel(amqp_server_url: String, config: WorkerConfig) -> Channel { + // TODO handle unwraps + let channel = match config.tls { + None => Connection::connect(&amqp_server_url, ConnectionProperties::default()) + .await + .expect("connection error") + .create_channel() + .await + .unwrap(), + Some(tls) => Connection::connect_uri_with_config( + lapin::uri::AMQPUri::from_str(&amqp_server_url).unwrap(), + ConnectionProperties::default(), + tls, + ) .await - .expect("connection error"); - match limit { - Some(limit) => { - let channel = conn.create_channel().await.expect("create channel error"); - channel - .basic_qos(limit, BasicQosOptions::default()) - .await - .expect("Failed to set prefetch amount"); - channel - } - None => conn.create_channel().await.expect("create channel error"), - } + .unwrap() + .create_channel() + .await + .unwrap(), + }; + channel + + // TODO set qos for channel } + /// Add a non-rpc listener to the worker object /// /// # Arguments /// * `queue_name` - A string slice that holds the name of the queue to listen to (e.g. service-serviceJobName-v1.0.0) /// * `state` - An Arc of the state object that will be passed to the listener - pub async fn add_non_rpc_consumer( + pub async fn add_non_rpc_consumer( &mut self, queue_name: &str, state: Arc, ) where - ::State: std::marker::Send + Sync, + ::State: std::marker::Send + Sync, { let consumer = self .channel @@ -107,7 +174,6 @@ impl BunBunWorker { .await .expect("basic_consume error"); - let channel = self.channel.clone(); let handler: Arc< dyn Fn( lapin::message::Delivery, @@ -116,14 +182,27 @@ impl BunBunWorker { + Sync, > = Arc::new(move |delivery: lapin::message::Delivery| { let state = Arc::clone(&state); - let channel = channel.clone(); Box::pin(async move { if let Ok(job) = J::decode(delivery.data.clone()) { // Running job + + tracing::debug!("Running before job"); + let job = match tokio::task::spawn(async move { job.before_job().await }).await + { + Err(error) => { + tracing::error!( + "The before_job function has failed for a job of type: {}, {}", + std::any::type_name::(), + error + ); + return (); + } + Ok(j) => j, + }; + match tokio::task::spawn(async move { job.run(state).await }).await { Err(error) => { - tracing::error!("Failed to run non-rpc job: {}", error); - + tracing::error!("Failed to run task job: {}", error); let _ = delivery.nack(BasicNackOptions::default()).await; } Ok(_) => { @@ -131,6 +210,7 @@ impl BunBunWorker { let _ = delivery.ack(BasicAckOptions::default()).await; } }; + // TODO run afterjob } else { delivery.nack(BasicNackOptions::default()).await.unwrap(); } @@ -141,27 +221,27 @@ impl BunBunWorker { self.consumers.push(consumer); } /// Add an rpc job listener to the worker object - /// + /// Make sure the type you pass in implements RPCTask /// # Arguments /// * `queue_name` - A string slice that holds the name of the queue to listen to (e.g. service-serviceJobName-v1.0.0) /// * `state` - An Arc of the state object that will be passed to the listener /// + /// /// # Examples /// /// ``` - /// /// let server = BunBunWorker::new("amqp://localhost:5672", None).await; - /// server.add_rpc_consumer::("service-serviceJobName-v1.0.0", SomeState{} )).await; + /// server.add_rpc_consumer::("service-serviceJobName-v1.0.0", SomeState{} )).await; /// server.start_all_listeners().await; /// ``` - pub async fn add_rpc_consumer( + pub async fn add_rpc_consumer( &mut self, queue_name: &str, state: Arc, ) where - ::State: std::marker::Send + Sync, - ::Result: std::marker::Send + Sync, - ::ErroredResult: std::marker::Send + Sync, + ::State: std::marker::Send + Sync, + ::Result: std::marker::Send + Sync, + ::ErroredResult: std::marker::Send + Sync, { let consumer = self .channel @@ -195,7 +275,6 @@ impl BunBunWorker { } }; - // TODO handle errors let correlation_id = match delivery.properties.correlation_id().clone() { None => { tracing::warn!("received a job with no correlation id"); @@ -207,10 +286,13 @@ impl BunBunWorker { }; if let Ok(job) = J::decode(delivery.data.clone()) { - // Catching panics + let job = tokio::task::spawn(async move { job.before_job().await }) + .await + .unwrap(); let outcome = tokio::task::spawn(async move { job.run(state).await }).await; + match outcome { - Err(error) => { + Err(ref error) => { tracing::error!("Failed to start thread for worker {}", error); let headers = create_header(ResultHeader::Panic); let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery @@ -223,7 +305,7 @@ impl BunBunWorker { ) .await } - Ok(Ok(res)) => { + Ok(Ok(ref res)) => { let headers = create_header(ResultHeader::Ok); let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery respond_to_rpc_queue( @@ -231,11 +313,11 @@ impl BunBunWorker { routing_key, headers, correlation_id, - Some(res), + Some(res.clone()), ) .await } - Ok(Err(err)) => { + Ok(Err(ref err)) => { // let headers = create_header(ResultHeader::Ok); let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery @@ -244,11 +326,13 @@ impl BunBunWorker { routing_key, headers, correlation_id, - Some(err), + Some(err.clone()), ) .await } - } + }; + tracing::debug!("Running after job"); + let _ = tokio::task::spawn(async move { J::after_job(outcome).await }).await; } else { delivery.nack(BasicNackOptions::default()).await.unwrap(); } @@ -260,7 +344,7 @@ impl BunBunWorker { } /// Start all the listeners added to the worker object - + // TODO implement reconnect pub async fn start_all_listeners(&self) { let mut listeners = vec![]; for (handler, consumer) in self.handlers.iter().zip(self.consumers.iter()) { @@ -320,28 +404,36 @@ impl BunBunWorker { } } -/// A trait that defines the structure of a task that can be run by the worker -/// The task must be deserializable and serializable -/// The task must have a result and an errored result +/// A trait that defines the structure of a task that can be run by the worker +/// BoxFuture is from a crate called `futures_util` +/// /// # Examples /// ``` /// #[derive(Debug, Serialize, Deserialize)] -/// struct MyRPCServerTask { +/// struct MyRPCTask { /// pub name: String, /// } -/// impl RPCServerTask for MyRPCServerTask { -/// type Result = String; -/// type ErroredResult = String; +/// #[derive(Debug, Serialize, Deserialize)] +/// struct MyRPCTaskResult { +/// pub something: String, +/// } +/// #[derive(Debug, Serialize, Deserialize)] +/// struct MyRPCTaskErroredResult { +/// pub what_failed: String, +/// } +/// impl RPCTask for MyRPCTask { +/// type Result = MyRPCTaskResult; +/// type Error = MyRPCTaskErroredResult; /// type State = SomeState; /// -/// fn run(self, state: Arc) -> BoxFuture<'static, Result> { +/// fn run(self, state: Arc) -> BoxFuture<'static, Result> { /// async move { -/// Ok("Hello".to_string()) +/// Ok(MyRPCTaskResult{ something: "Hello I ran ok!".into() }) /// }.boxed() /// } -pub trait RPCServerTask: Sized + Debug + DeserializeOwned { - type Result: Serialize + DeserializeOwned + Debug; - type ErroredResult: Serialize + DeserializeOwned + Debug; +pub trait RPCTask: Sized + Debug + DeserializeOwned { + type Result: Serialize + DeserializeOwned + Debug + Clone; + type ErroredResult: Serialize + DeserializeOwned + Debug + Clone; type State: Clone + Debug; /// Decoding for the message. Overriding is possible. @@ -368,11 +460,25 @@ pub trait RPCServerTask: Sized + Debug + DeserializeOwned { fn display(&self) -> String { format!("{:?}", self) } - // TODO add a function that runs after a job is ran - // TODO add a function that runs before a job is about to run + /// A function that runs before the job is ran by the worker. This allows you to modify any values inside it, add tracing ect. + fn before_job(self) -> impl std::future::Future + Send + where + Self: Send, + { + async move { self } + } + /// A function that runs after the job has ran and the worker has responded to the callback queue. Any modifications made will not be reflected on the client side. + fn after_job( + res: Result, JoinError>, + ) -> impl std::future::Future + Send + where + Self: Send, + { + async move {} + } } -/// A NonRPCServer task +/// A regular task /// Implement this trait to any struct to make it a runnable `non-rpc` job. /// /// Examples @@ -384,7 +490,7 @@ pub trait RPCServerTask: Sized + Debug + DeserializeOwned { /// send_to: String, /// contents: String, /// } -/// impl NonRPCServerTask for EmailJob { +/// impl Task for EmailJob { /// type State = State; /// fn run( /// self, @@ -397,7 +503,7 @@ pub trait RPCServerTask: Sized + Debug + DeserializeOwned { /// } /// } /// ``` -pub trait NonRPCServerTask: Sized + Debug + DeserializeOwned { +pub trait Task: Sized + Debug + DeserializeOwned { type State: Clone + Debug; fn decode(data: Vec) -> Result { @@ -423,8 +529,20 @@ pub trait NonRPCServerTask: Sized + Debug + DeserializeOwned { format!("{:?}", self) } - // TODO add a function that runs after a job is ran - // TODO add a function that runs before a job is about to run + /// A function that runs before the job is ran by the worker. This allows you to modify any values inside it, add tracing ect. + fn before_job(self) -> impl std::future::Future + Send + where + Self: Send, + { + async move { self } + } + /// A function that runs after the job has finished and the worker has acked the request. + fn after_job(self) -> impl std::future::Future + Send + where + Self: Sync + Send, + { + async move {} + } } #[derive(Debug)] @@ -501,7 +619,7 @@ fn create_header(header: ResultHeader) -> FieldTable { /// A result header that is included in the header of the AMQP message. /// It indicates the status of the returned message #[derive(Debug, Serialize, Deserialize)] -pub enum ResultHeader { +enum ResultHeader { Ok, Error, Panic, diff --git a/src/test/mod.rs b/src/test/mod.rs index 56ad479..3abea87 100644 --- a/src/test/mod.rs +++ b/src/test/mod.rs @@ -15,8 +15,8 @@ mod test { use tracing_test::traced_test; use crate::{ - client::{BunBunClient, RPCClientTask}, - BunBunWorker, RPCServerTask, + client::{Client, RPCClientTask}, + RPCTask, Worker, WorkerConfig, }; #[derive(Clone, Debug)] @@ -43,7 +43,7 @@ mod test { contents: String, } - impl RPCServerTask for EmailJob { + impl RPCTask for EmailJob { type ErroredResult = EmailJobResultError; type Result = EmailJobResult; type State = State; @@ -60,8 +60,26 @@ mod test { }); }) } + async fn before_job(self) -> Self + where + Self: Send, + { + tracing::info!( + "hello i received a job of type: {}", + std::any::type_name::() + ); + self + } + fn after_job( + res: Result, tokio::task::JoinError>, + ) -> impl std::future::Future + Send + where + Self: Send, + { + async move { tracing::info!("Job has ran with result: {:?}", res) } + } } - impl RPCServerTask for PanickingEmailJob { + impl RPCTask for PanickingEmailJob { type ErroredResult = EmailJobResultError; type Result = EmailJobResult; type State = State; @@ -81,10 +99,13 @@ mod test { #[test(tokio::test)] #[traced_test] - async fn rpc() { + async fn rpc_server() { // - let mut listener = - BunBunWorker::new(env::var("AMQP_SERVER_URL").unwrap(), 100.into()).await; + let mut listener = Worker::new( + env::var("AMQP_SERVER_URL").unwrap(), + WorkerConfig::default(), + ) + .await; listener .add_rpc_consumer::( "email-emailjob-v1.0.0", @@ -101,8 +122,11 @@ mod test { #[traced_test] async fn rpc_that_will_panic() { // - let mut listener = - BunBunWorker::new(env::var("AMQP_SERVER_URL").unwrap(), 100.into()).await; + let mut listener = Worker::new( + env::var("AMQP_SERVER_URL").unwrap(), + WorkerConfig::default(), + ) + .await; listener .add_rpc_consumer::( "email-emailjob-v1.0.0", @@ -119,7 +143,7 @@ mod test { #[traced_test] async fn rpc_client() { // - let mut client = BunBunClient::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) + let mut client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) .await .unwrap(); let result = client @@ -143,7 +167,7 @@ mod test { #[test(tokio::test)] #[traced_test] async fn rpc_client_spam_multithread() { - let client = BunBunClient::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) + let client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str()) .await .unwrap(); let client = Arc::new(Mutex::new(client));