client: removed the bunbun part

worker: removed the bunbun part
- extended with a struct for configuration
- added tls support (have not tested)
- removed the prefetch temporary
test: rewritten according to new api
This commit is contained in:
Barna Máté 2024-11-14 17:59:36 +01:00
parent de6c0619b3
commit 755cf6467d
5 changed files with 255 additions and 204 deletions

View file

@ -1,6 +1,6 @@
[package]
name = "bunbun-worker"
version = "0.1.1"
version = "0.2.0"
description = "An rpc/non-rpc rabbitmq worker library"
edition = "2021"
license = "AGPL-3.0"
@ -11,8 +11,9 @@ keywords = ["worker", "rpc", "rabbitmq"]
[dependencies]
async-trait = "0.1.83"
derive_setters = "0.1.6"
futures = "0.3.31"
lapin = "2.5.0"
lapin = { version = "2.5.0", features = ["rustls"] }
serde = "1.0.213"
serde_json = "1.0.132"
tokio = { version = "1.41.0", features = ["full"] }

123
README.md
View file

@ -49,130 +49,11 @@ bunbun-worker = { git = "https://git.4o1x5.dev/4o1x5/bunbun-worker", branch = "m
Here is a basic implementation of an RPC job in bunbun-worker
```rust
// server
// First let's create a state that will be used inside a job.
// Imagine this holding a database connection, some context that may need to be changed. Anything really
#[derive(Clone, Debug)]
pub struct State {
pub something: String,
}
/// Second, let's create a job, with field that can be serialized/deserialized into JSON
/// This is what the server will receive from a client and will do the job based on these properties
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct EmailJob {
send_to: String,
contents: String,
}
// We also create a result for it, since it's an RPC job
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
pub struct EmailJobResult {
contents: String,
}
// And an error type so know if the other end errored out, what to do.
#[derive(Deserialize, Serialize, Clone, Debug)]
pub enum EmailJobResultError {
Errored,
}
/// After all this we implement a Jobrunner/Taskrunner to the type, so when the listener receives it, it can run this piece of code.
impl RPCServerTask for EmailJob {
type ErroredResult = EmailJobResultError;
type Result = EmailJobResult;
type State = State;
fn run(
self,
state: Self::State,
) -> futures::prelude::future::BoxFuture<'static, Result<Self::Result, Self::ErroredResult>>
{
Box::pin(async move {
tracing::info!("Sent email to {}", self.send_to);
tokio::time::sleep(Duration::from_secs(2)).await;
return Ok(EmailJobResult {
contents: self.contents.clone(),
});
})
}
}
// Finally, we can define an async main function to run the listener in.
#[tokio::main]
async fn main(){
// Define a listener, with a hard-limit of 100 jobs at once.
let mut listener =
BunBunWorker::new(env::var("AMQP_SERVER_URL").unwrap(), 100.into()).await;
// Add the defined sturct to the worker
listener
.add_rpc_consumer::<EmailJob>(
"email-emailjob-v1.0.0", // queue name
State {
something: "test".into(), // putting our state into a Arc<Mutex<S>> for thread safety
},
)
.await;
tracing::debug!("Starting listener");
// Starting the listener
listener.start_all_listeners().await;
}
```
Instal `futures_util` for the client
```
cargo add futures_util
```
```rust
// client
// Define the same structs we did. These are DTO's after all..
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct EmailJob {
send_to: String,
contents: String,
}
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
pub struct EmailJobResult {
contents: String,
}
#[derive(Deserialize, Serialize, Clone, Debug)]
pub enum EmailJobResultError {
Errored,
}
// Now we implement the clientside task for it. This reduces generics when defining the calling..
impl RPCClientTask for EmailJob {
type ErroredResult = EmailJobResultError;
type Result = EmailJobResult;
}
#[tokio::main]
async fn main(){
// Define a client
let mut client = BunBunClient::new(env::var("AMQP_SERVER_URL").unwrap().as_str())
.await
.unwrap();
// Make a call
let result = client
.rpc_call::<EmailJob>(
// Define the job
EmailJob {
send_to: "someone".into(),
contents: "something".into(),
},
"email-emailjob-v1.0.0", // the queue name
)
.await
.unwrap();
}
```
# Limitations
1. Currently some `unwrap()`'s are called inside the code and may results in panics (not in the job-runner).
2. No TLS support
3. No settings, and very limited API
4. The rabbitmq RPC logic is very basic with no message-versioning (aside using different queue names (see [usage](#usage)) )
2. No settings, and very limited API
3. The rabbitmq RPC logic is very basic with no message-versioning (aside using different queue names (eg service-class-v1.0.0) )
# Bugs department

View file

@ -16,22 +16,21 @@ use crate::ResultHeader;
/// A client for the server part of `bunbun-worker`
#[derive(Debug)]
pub struct BunBunClient {
pub struct Client {
conn: Connection,
}
// TODO implement reconnect
// TODO implement tls
impl BunBunClient {
impl Client {
/// Creates an rpc client
///
/// # Examples
///
/// ```
/// // Create a client and send a message
/// use bunbun_worker::client::BunBunClient;
/// let client = BunBunClient::new("amqp://127.0.0.1:5672");
/// use bunbun_worker::client::Client;
/// let client = Client::new("amqp://127.0.0.1:5672");
/// ```
pub async fn new(address: &str) -> Result<Self, lapin::Error> {
let conn = Connection::connect(address, ConnectionProperties::default()).await?;
@ -39,6 +38,12 @@ impl BunBunClient {
Ok(Self { conn })
}
/// A method to call for a RPC
///
/// Arguments
/// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize
/// * `queue_name` The name of the queue to be sent to
// TODO if the queue is nonexistent return error
pub async fn rpc_call<T: RPCClientTask + Send + Debug>(
&self,
data: T,
@ -265,7 +270,7 @@ impl BunBunClient {
}
}
/// An error that the bunbunclient returns
/// An error that the client returns
#[derive(Debug)]
pub enum RpcClientError {
NoReply, // TODO timeout
@ -288,6 +293,28 @@ impl Display for ClientError {
}
}
/// A Client-side trait that needs to be implemented for a type in order for the client to know return types.
///
/// Examples
/// ```
/// #[derive(Deserialize, Serialize, Clone, Debug)]
/// pub struct EmailJob {
/// send_to: String,
/// contents: String,
/// }
/// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
/// pub struct EmailJobResult {
/// contents: String,
/// }
/// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
/// pub enum EmailJobResultError {
/// Errored,
/// }
/// impl RPCClientTask for EmailJob {
/// type ErroredResult = EmailJobResultError;
/// type Result = EmailJobResult;
/// }
/// ```
pub trait RPCClientTask: Sized + Debug + DeserializeOwned {
type Result: Serialize + DeserializeOwned + Debug;
type ErroredResult: Serialize + DeserializeOwned + Debug;

View file

@ -7,6 +7,7 @@ use lapin::{
BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions,
BasicQosOptions,
},
tcp::{OwnedIdentity, OwnedTLSConfig},
types::{DeliveryTag, FieldTable, ShortString},
BasicProperties, Channel, Connection, ConnectionProperties, Consumer,
};
@ -14,16 +15,18 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
fmt::{Debug, Display},
pin::Pin,
str::from_utf8,
str::{from_utf8, FromStr},
sync::Arc,
};
use tokio::task::JoinError;
/// The client module that interacts with the server part of `bunbun-worker`
pub mod client;
mod test;
pub struct BunBunWorker {
/// The worker object that contains all the threads and runners.
pub struct Worker {
channel: Channel,
/// A consumer for each rpc handler
rpc_consumers: Vec<Consumer>,
@ -49,16 +52,71 @@ pub struct BunBunWorker {
>,
}
// Example taken from https://github.com/amqp-rs/lapin/blob/main/examples/client-certificate.rs
/// Custom certificate configuration
pub struct TlsConfig {
cert_chain: String,
client_cert_and_key: String,
client_cert_and_key_password: String,
}
impl TlsConfig {
/// Create a custom TLS config
pub fn new(
cert_chain: String,
client_cert_and_key: String,
client_cert_and_key_password: String,
) -> Self {
Self {
cert_chain,
client_cert_and_key,
client_cert_and_key_password,
}
}
}
#[derive(Debug)]
/// A worker configuration
/// Enable tls here
pub struct WorkerConfig {
tls: Option<OwnedTLSConfig>,
}
impl WorkerConfig {
/// Creates a new worker config
pub fn default() -> Self {
Self { tls: None }
}
/// Enable secure connection to amqp.
///
/// # Arguments
/// * `custom_tls` - Optional TLSconfig (if none defaults to lapins choice)
pub fn enable_tls(&mut self, custom_tls: Option<TlsConfig>) {
match custom_tls {
Some(tls) => {
let tls = OwnedTLSConfig {
identity: Some(OwnedIdentity {
der: tls.client_cert_and_key.as_bytes().to_vec(),
password: tls.client_cert_and_key_password,
}),
cert_chain: Some(tls.cert_chain.to_string()),
};
self.tls = tls.into();
}
None => self.tls = OwnedTLSConfig::default().into(),
}
}
}
// TODO implement reconnect
// TODO implement tls
impl BunBunWorker {
impl Worker {
/// Create a new instance of `bunbun-worker`
/// # Arguments
/// * `amqp_server_url` - A string slice that holds the url of the amqp server (e.g. amqp://localhost:5672)
/// * `limit` - An optional u16 that holds the limit of the number of messages to prefetch 0 by default
pub async fn new(amqp_server_url: impl Into<String>, limit: Option<u16>) -> Self {
let channel = Self::create_channel(amqp_server_url.into(), limit).await;
BunBunWorker {
/// * `config` - A worker config, containing the TLS config for now.
pub async fn new(amqp_server_url: impl Into<String>, config: WorkerConfig) -> Self {
let channel = Self::create_channel(amqp_server_url.into(), config).await;
Worker {
channel,
handlers: Vec::new(),
consumers: Vec::new(),
@ -68,33 +126,42 @@ impl BunBunWorker {
}
}
async fn create_channel(amqp_server_url: String, limit: Option<u16>) -> Channel {
let conn = Connection::connect(&amqp_server_url, ConnectionProperties::default())
async fn create_channel(amqp_server_url: String, config: WorkerConfig) -> Channel {
// TODO handle unwraps
let channel = match config.tls {
None => Connection::connect(&amqp_server_url, ConnectionProperties::default())
.await
.expect("connection error")
.create_channel()
.await
.unwrap(),
Some(tls) => Connection::connect_uri_with_config(
lapin::uri::AMQPUri::from_str(&amqp_server_url).unwrap(),
ConnectionProperties::default(),
tls,
)
.await
.expect("connection error");
match limit {
Some(limit) => {
let channel = conn.create_channel().await.expect("create channel error");
channel
.basic_qos(limit, BasicQosOptions::default())
.await
.expect("Failed to set prefetch amount");
channel
}
None => conn.create_channel().await.expect("create channel error"),
}
.unwrap()
.create_channel()
.await
.unwrap(),
};
channel
// TODO set qos for channel
}
/// Add a non-rpc listener to the worker object
///
/// # Arguments
/// * `queue_name` - A string slice that holds the name of the queue to listen to (e.g. service-serviceJobName-v1.0.0)
/// * `state` - An Arc of the state object that will be passed to the listener
pub async fn add_non_rpc_consumer<J: NonRPCServerTask + 'static + Send>(
pub async fn add_non_rpc_consumer<J: Task + 'static + Send>(
&mut self,
queue_name: &str,
state: Arc<J::State>,
) where
<J as NonRPCServerTask>::State: std::marker::Send + Sync,
<J as Task>::State: std::marker::Send + Sync,
{
let consumer = self
.channel
@ -107,7 +174,6 @@ impl BunBunWorker {
.await
.expect("basic_consume error");
let channel = self.channel.clone();
let handler: Arc<
dyn Fn(
lapin::message::Delivery,
@ -116,14 +182,27 @@ impl BunBunWorker {
+ Sync,
> = Arc::new(move |delivery: lapin::message::Delivery| {
let state = Arc::clone(&state);
let channel = channel.clone();
Box::pin(async move {
if let Ok(job) = J::decode(delivery.data.clone()) {
// Running job
tracing::debug!("Running before job");
let job = match tokio::task::spawn(async move { job.before_job().await }).await
{
Err(error) => {
tracing::error!(
"The before_job function has failed for a job of type: {}, {}",
std::any::type_name::<J>(),
error
);
return ();
}
Ok(j) => j,
};
match tokio::task::spawn(async move { job.run(state).await }).await {
Err(error) => {
tracing::error!("Failed to run non-rpc job: {}", error);
tracing::error!("Failed to run task job: {}", error);
let _ = delivery.nack(BasicNackOptions::default()).await;
}
Ok(_) => {
@ -131,6 +210,7 @@ impl BunBunWorker {
let _ = delivery.ack(BasicAckOptions::default()).await;
}
};
// TODO run afterjob
} else {
delivery.nack(BasicNackOptions::default()).await.unwrap();
}
@ -141,27 +221,27 @@ impl BunBunWorker {
self.consumers.push(consumer);
}
/// Add an rpc job listener to the worker object
///
/// Make sure the type you pass in implements RPCTask
/// # Arguments
/// * `queue_name` - A string slice that holds the name of the queue to listen to (e.g. service-serviceJobName-v1.0.0)
/// * `state` - An Arc of the state object that will be passed to the listener
///
///
/// # Examples
///
/// ```
///
/// let server = BunBunWorker::new("amqp://localhost:5672", None).await;
/// server.add_rpc_consumer::<MyRPCServerTask>("service-serviceJobName-v1.0.0", SomeState{} )).await;
/// server.add_rpc_consumer::<MyRPCTask>("service-serviceJobName-v1.0.0", SomeState{} )).await;
/// server.start_all_listeners().await;
/// ```
pub async fn add_rpc_consumer<J: RPCServerTask + 'static + Send>(
pub async fn add_rpc_consumer<J: RPCTask + 'static + Send>(
&mut self,
queue_name: &str,
state: Arc<J::State>,
) where
<J as RPCServerTask>::State: std::marker::Send + Sync,
<J as RPCServerTask>::Result: std::marker::Send + Sync,
<J as RPCServerTask>::ErroredResult: std::marker::Send + Sync,
<J as RPCTask>::State: std::marker::Send + Sync,
<J as RPCTask>::Result: std::marker::Send + Sync,
<J as RPCTask>::ErroredResult: std::marker::Send + Sync,
{
let consumer = self
.channel
@ -195,7 +275,6 @@ impl BunBunWorker {
}
};
// TODO handle errors
let correlation_id = match delivery.properties.correlation_id().clone() {
None => {
tracing::warn!("received a job with no correlation id");
@ -207,10 +286,13 @@ impl BunBunWorker {
};
if let Ok(job) = J::decode(delivery.data.clone()) {
// Catching panics
let job = tokio::task::spawn(async move { job.before_job().await })
.await
.unwrap();
let outcome = tokio::task::spawn(async move { job.run(state).await }).await;
match outcome {
Err(error) => {
Err(ref error) => {
tracing::error!("Failed to start thread for worker {}", error);
let headers = create_header(ResultHeader::Panic);
let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery
@ -223,7 +305,7 @@ impl BunBunWorker {
)
.await
}
Ok(Ok(res)) => {
Ok(Ok(ref res)) => {
let headers = create_header(ResultHeader::Ok);
let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery
respond_to_rpc_queue(
@ -231,11 +313,11 @@ impl BunBunWorker {
routing_key,
headers,
correlation_id,
Some(res),
Some(res.clone()),
)
.await
}
Ok(Err(err)) => {
Ok(Err(ref err)) => {
//
let headers = create_header(ResultHeader::Ok);
let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery
@ -244,11 +326,13 @@ impl BunBunWorker {
routing_key,
headers,
correlation_id,
Some(err),
Some(err.clone()),
)
.await
}
}
};
tracing::debug!("Running after job");
let _ = tokio::task::spawn(async move { J::after_job(outcome).await }).await;
} else {
delivery.nack(BasicNackOptions::default()).await.unwrap();
}
@ -260,7 +344,7 @@ impl BunBunWorker {
}
/// Start all the listeners added to the worker object
// TODO implement reconnect
pub async fn start_all_listeners(&self) {
let mut listeners = vec![];
for (handler, consumer) in self.handlers.iter().zip(self.consumers.iter()) {
@ -320,28 +404,36 @@ impl BunBunWorker {
}
}
/// A trait that defines the structure of a task that can be run by the worker
/// The task must be deserializable and serializable
/// The task must have a result and an errored result
/// A trait that defines the structure of a task that can be run by the worker
/// BoxFuture is from a crate called `futures_util`
///
/// # Examples
/// ```
/// #[derive(Debug, Serialize, Deserialize)]
/// struct MyRPCServerTask {
/// struct MyRPCTask {
/// pub name: String,
/// }
/// impl RPCServerTask for MyRPCServerTask {
/// type Result = String;
/// type ErroredResult = String;
/// #[derive(Debug, Serialize, Deserialize)]
/// struct MyRPCTaskResult {
/// pub something: String,
/// }
/// #[derive(Debug, Serialize, Deserialize)]
/// struct MyRPCTaskErroredResult {
/// pub what_failed: String,
/// }
/// impl RPCTask for MyRPCTask {
/// type Result = MyRPCTaskResult;
/// type Error = MyRPCTaskErroredResult;
/// type State = SomeState;
///
/// fn run(self, state: Arc<Self::State>) -> BoxFuture<'static, Result<Self::Result, Self::ErroredResult>> {
/// fn run(self, state: Arc<Self::State>) -> BoxFuture<'static, Result<Self::Result, Self::Error>> {
/// async move {
/// Ok("Hello".to_string())
/// Ok(MyRPCTaskResult{ something: "Hello I ran ok!".into() })
/// }.boxed()
/// }
pub trait RPCServerTask: Sized + Debug + DeserializeOwned {
type Result: Serialize + DeserializeOwned + Debug;
type ErroredResult: Serialize + DeserializeOwned + Debug;
pub trait RPCTask: Sized + Debug + DeserializeOwned {
type Result: Serialize + DeserializeOwned + Debug + Clone;
type ErroredResult: Serialize + DeserializeOwned + Debug + Clone;
type State: Clone + Debug;
/// Decoding for the message. Overriding is possible.
@ -368,11 +460,25 @@ pub trait RPCServerTask: Sized + Debug + DeserializeOwned {
fn display(&self) -> String {
format!("{:?}", self)
}
// TODO add a function that runs after a job is ran
// TODO add a function that runs before a job is about to run
/// A function that runs before the job is ran by the worker. This allows you to modify any values inside it, add tracing ect.
fn before_job(self) -> impl std::future::Future<Output = Self> + Send
where
Self: Send,
{
async move { self }
}
/// A function that runs after the job has ran and the worker has responded to the callback queue. Any modifications made will not be reflected on the client side.
fn after_job(
res: Result<Result<Self::Result, Self::ErroredResult>, JoinError>,
) -> impl std::future::Future<Output = ()> + Send
where
Self: Send,
{
async move {}
}
}
/// A NonRPCServer task
/// A regular task
/// Implement this trait to any struct to make it a runnable `non-rpc` job.
///
/// Examples
@ -384,7 +490,7 @@ pub trait RPCServerTask: Sized + Debug + DeserializeOwned {
/// send_to: String,
/// contents: String,
/// }
/// impl NonRPCServerTask for EmailJob {
/// impl Task for EmailJob {
/// type State = State;
/// fn run(
/// self,
@ -397,7 +503,7 @@ pub trait RPCServerTask: Sized + Debug + DeserializeOwned {
/// }
/// }
/// ```
pub trait NonRPCServerTask: Sized + Debug + DeserializeOwned {
pub trait Task: Sized + Debug + DeserializeOwned {
type State: Clone + Debug;
fn decode(data: Vec<u8>) -> Result<Self, RabbitDecodeError> {
@ -423,8 +529,20 @@ pub trait NonRPCServerTask: Sized + Debug + DeserializeOwned {
format!("{:?}", self)
}
// TODO add a function that runs after a job is ran
// TODO add a function that runs before a job is about to run
/// A function that runs before the job is ran by the worker. This allows you to modify any values inside it, add tracing ect.
fn before_job(self) -> impl std::future::Future<Output = Self> + Send
where
Self: Send,
{
async move { self }
}
/// A function that runs after the job has finished and the worker has acked the request.
fn after_job(self) -> impl std::future::Future<Output = ()> + Send
where
Self: Sync + Send,
{
async move {}
}
}
#[derive(Debug)]
@ -501,7 +619,7 @@ fn create_header(header: ResultHeader) -> FieldTable {
/// A result header that is included in the header of the AMQP message.
/// It indicates the status of the returned message
#[derive(Debug, Serialize, Deserialize)]
pub enum ResultHeader {
enum ResultHeader {
Ok,
Error,
Panic,

View file

@ -15,8 +15,8 @@ mod test {
use tracing_test::traced_test;
use crate::{
client::{BunBunClient, RPCClientTask},
BunBunWorker, RPCServerTask,
client::{Client, RPCClientTask},
RPCTask, Worker, WorkerConfig,
};
#[derive(Clone, Debug)]
@ -43,7 +43,7 @@ mod test {
contents: String,
}
impl RPCServerTask for EmailJob {
impl RPCTask for EmailJob {
type ErroredResult = EmailJobResultError;
type Result = EmailJobResult;
type State = State;
@ -60,8 +60,26 @@ mod test {
});
})
}
async fn before_job(self) -> Self
where
Self: Send,
{
tracing::info!(
"hello i received a job of type: {}",
std::any::type_name::<Self>()
);
self
}
fn after_job(
res: Result<Result<Self::Result, Self::ErroredResult>, tokio::task::JoinError>,
) -> impl std::future::Future<Output = ()> + Send
where
Self: Send,
{
async move { tracing::info!("Job has ran with result: {:?}", res) }
}
}
impl RPCServerTask for PanickingEmailJob {
impl RPCTask for PanickingEmailJob {
type ErroredResult = EmailJobResultError;
type Result = EmailJobResult;
type State = State;
@ -81,10 +99,13 @@ mod test {
#[test(tokio::test)]
#[traced_test]
async fn rpc() {
async fn rpc_server() {
//
let mut listener =
BunBunWorker::new(env::var("AMQP_SERVER_URL").unwrap(), 100.into()).await;
let mut listener = Worker::new(
env::var("AMQP_SERVER_URL").unwrap(),
WorkerConfig::default(),
)
.await;
listener
.add_rpc_consumer::<EmailJob>(
"email-emailjob-v1.0.0",
@ -101,8 +122,11 @@ mod test {
#[traced_test]
async fn rpc_that_will_panic() {
//
let mut listener =
BunBunWorker::new(env::var("AMQP_SERVER_URL").unwrap(), 100.into()).await;
let mut listener = Worker::new(
env::var("AMQP_SERVER_URL").unwrap(),
WorkerConfig::default(),
)
.await;
listener
.add_rpc_consumer::<PanickingEmailJob>(
"email-emailjob-v1.0.0",
@ -119,7 +143,7 @@ mod test {
#[traced_test]
async fn rpc_client() {
//
let mut client = BunBunClient::new(env::var("AMQP_SERVER_URL").unwrap().as_str())
let mut client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str())
.await
.unwrap();
let result = client
@ -143,7 +167,7 @@ mod test {
#[test(tokio::test)]
#[traced_test]
async fn rpc_client_spam_multithread() {
let client = BunBunClient::new(env::var("AMQP_SERVER_URL").unwrap().as_str())
let client = Client::new(env::var("AMQP_SERVER_URL").unwrap().as_str())
.await
.unwrap();
let client = Arc::new(Mutex::new(client));