diff --git a/Cargo.lock b/Cargo.lock index 7a664b9..cd2abf0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -380,6 +380,8 @@ name = "celestial" version = "0.1.0" dependencies = [ "anyhow", + "derive-getters", + "derive_builder", "dotenv", "env_logger", "envy", @@ -557,6 +559,41 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "darling" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.60", +] + +[[package]] +name = "darling_macro" +version = "0.20.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.60", +] + [[package]] name = "deadpool" version = "0.10.0" @@ -633,6 +670,48 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "derive-getters" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2c35ab6e03642397cdda1dd58abbc05d418aef8e36297f336d5aba060fe8df" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.60", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +dependencies = [ + "derive_builder_core", + "syn 2.0.60", +] + [[package]] name = "digest" version = "0.10.7" @@ -1222,6 +1301,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "idna" version = "0.5.0" @@ -1435,6 +1520,16 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +[[package]] +name = "lock_api" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +dependencies = [ + "autocfg", + "scopeguard", +] + [[package]] name = "log" version = "0.4.21" @@ -1823,6 +1918,8 @@ dependencies = [ "reqwest 0.12.4", "serde", "serde_json", + "tokio", + "tokio-stream", ] [[package]] @@ -1893,6 +1990,29 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" +[[package]] +name = "parking_lot" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets 0.52.5", +] + [[package]] name = "paste" version = "1.0.14" @@ -2127,6 +2247,15 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7b323e7196daa571c8584de958be19e92941c41f845776fe06babfe8fa280a2" +[[package]] +name = "redox_syscall" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +dependencies = [ + "bitflags 2.5.0", +] + [[package]] name = "regex" version = "1.10.4" @@ -2229,10 +2358,12 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg 0.52.0", ] @@ -2468,6 +2599,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + [[package]] name = "security-framework" version = "2.10.0" @@ -2602,6 +2739,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "signature" version = "2.2.0" @@ -2646,6 +2792,12 @@ dependencies = [ "der", ] +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + [[package]] name = "subtle" version = "2.5.0" @@ -2769,7 +2921,9 @@ dependencies = [ "libc", "mio", "num_cpus", + "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.48.0", diff --git a/Cargo.toml b/Cargo.toml index 6311663..0553b5e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,12 +8,14 @@ license = "MIT" [dependencies] anyhow = "1.0.82" +derive-getters = "0.3.0" +derive_builder = "0.20.0" dotenv = "0.15.0" env_logger = "0.11.3" envy = "0.4.2" log = "0.4.21" matrix-sdk = "0.7.1" -ollama-rs = "0.1.9" +ollama-rs = {version="0.1.9", features = ["stream"]} serde = "1.0.200" tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] } tracing-subscriber = "0.3.15" diff --git a/readme.md b/readme.md index b3bf890..08e4ddb 100644 --- a/readme.md +++ b/readme.md @@ -2,3 +2,10 @@ This is just a funny project I slapped together in a span of around 30 minutes. I plan on expanding it with other features in the future. + +## Features + +- [] History + - [] Per room history +- [] quality of life features +- [] Nix service and flake for ez install diff --git a/src/commands/chatbot.rs b/src/commands/chatbot.rs new file mode 100644 index 0000000..38da8ac --- /dev/null +++ b/src/commands/chatbot.rs @@ -0,0 +1,59 @@ +use std::time::Instant; + +use anyhow::Ok; +use matrix_sdk::{ + event_handler::Ctx, + ruma::events::room::message::{RoomMessageEventContent, SyncRoomMessageEvent}, + Room, +}; +use ollama_rs::generation::{ + chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream}, + completion::request::GenerationRequest, +}; + +use crate::structs::BotContext; + +pub async fn chat(ev: SyncRoomMessageEvent, room: Room, context: Ctx) { + log::debug!("Prompting ollama"); + // Send seen + room.send_single_receipt( + matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read, + matrix_sdk::ruma::events::receipt::ReceiptThread::Main, + ev.event_id().to_owned(), + ) + .await; + + room.typing_notice(true).await; + + let prompt = format!( + "{} says: {}", + room.get_member(&ev.sender()) + .await + .unwrap() + .unwrap() + .display_name() + .unwrap(), + ev.as_original() + .unwrap() + .content + .body() + .replace(context.config().prefix().as_str(), "") + ); + let now = Instant::now(); + + let res = context + .ai() + .generate(GenerationRequest::new( + context.config().ollama_model().clone(), + prompt, + )) + .await + .unwrap(); + + let mut msg = res.response.to_string(); + msg.push_str(format!("\n prompt took {:?}", now.elapsed(),).as_str()); + + let content = RoomMessageEventContent::text_plain(msg); + room.send(content).await.unwrap(); + room.typing_notice(false).await; +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs new file mode 100644 index 0000000..85f1d58 --- /dev/null +++ b/src/commands/mod.rs @@ -0,0 +1,2 @@ +pub mod chatbot; +pub mod model; diff --git a/src/commands/model.rs b/src/commands/model.rs new file mode 100644 index 0000000..7eb97ab --- /dev/null +++ b/src/commands/model.rs @@ -0,0 +1,3 @@ +pub fn model() { + // +} diff --git a/src/events/message.rs b/src/events/message.rs new file mode 100644 index 0000000..6f8009d --- /dev/null +++ b/src/events/message.rs @@ -0,0 +1,27 @@ +use matrix_sdk::{event_handler::Ctx, ruma::events::room::message::SyncRoomMessageEvent, Room}; + +use crate::{commands::chatbot::chat, structs::BotContext}; + +pub async fn process_message(ev: SyncRoomMessageEvent, room: Room, context: Ctx) { + log::debug!("Processing message"); + + // Ignore if dm's are disabled + if room.is_direct().await.unwrap() && context.config().enable_dms() == &false { + log::debug!("Message is DM, ignoring due to configuration"); + return; + } + + // ignore message if its from the bot + if &ev.as_original().unwrap().sender.to_string() == context.bot_id() { + return; + } + + let msg = ev.as_original().unwrap().content.body().to_string(); + if msg.contains(context.config().prefix()) { + let parameters: Vec<&str> = msg.split(" ").collect(); + match parameters[0] { + //"model" => + _ => chat(ev, room, context).await, + }; + } +} diff --git a/src/events/mod.rs b/src/events/mod.rs new file mode 100644 index 0000000..e216a50 --- /dev/null +++ b/src/events/mod.rs @@ -0,0 +1 @@ +pub mod message; diff --git a/src/main.rs b/src/main.rs index a695645..70e05ef 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,152 +1,76 @@ -use std::time::Instant; - +use events::message::process_message; use matrix_sdk::{ config::SyncSettings, event_handler::Ctx, - ruma::events::room::{ - member::StrippedRoomMemberEvent, - message::{RoomMessageEventContent, SyncRoomMessageEvent}, - }, + ruma::events::room::{member::StrippedRoomMemberEvent, message::SyncRoomMessageEvent}, Client, Room, }; -use ollama_rs::{generation::completion::request::GenerationRequest, Ollama}; -use serde::Deserialize; +use ollama_rs::Ollama; +use structs::BotContextBuilder; use tokio::time::{sleep, Duration}; +mod commands; +mod events; +mod structs; +use crate::structs::{BotContext, Config}; -#[derive(Deserialize, Debug)] -struct Config { - homeserver_url: String, - username: String, - password: String, - - ollama_host: String, - ollama_port: i16, - ollama_model: String, - - prefix: String, -} #[tokio::main] async fn main() -> anyhow::Result<()> { - tracing_subscriber::fmt::init(); let _ = dotenv::dotenv(); + env_logger::init(); log::info!("Bot has started"); let c = envy::from_env::().expect("Please provide variables"); - login_and_sync( - c.homeserver_url.to_string(), - &c.username, - &c.password, - c.prefix, - c.ollama_host, - c.ollama_port, - c.ollama_model, - ) - .await?; + login_and_sync(c).await?; Ok(()) } // The core sync loop we have running. -async fn login_and_sync( - homeserver_url: String, - username: &str, - password: &str, - prefix: String, - ollama_host: String, - ollama_port: i16, - ollama_model: String, -) -> anyhow::Result<()> { +async fn login_and_sync(config: Config) -> anyhow::Result<()> { + log::info!("client init"); let client = Client::builder() - .homeserver_url(homeserver_url) + .homeserver_url(config.homeserver_url()) .build() .await?; client .matrix_auth() - .login_username(username, password) + .login_username(config.username(), config.password()) .initial_device_display_name("getting started bot") .await?; - println!("logged in as {username}"); + let ollama = Ollama::new( + config.ollama_host().clone().to_owned(), + config.ollama_port().clone().to_owned() as u16, + ); - let ollama = Ollama::new(ollama_host, ollama_port as u16); + let ctx = BotContextBuilder::default() + .ai(ollama) + .bot_id(client.clone().user_id().unwrap().to_string()) + .config(config) + .build() + .unwrap(); + client.add_event_handler_context(ctx); + log::info!("adding handlers"); client.add_event_handler(on_stripped_state_member); + // TODO this is something funny, i dont get why it doenst work this way + //client.add_event_handler(process_message); + client.add_event_handler( + |ev: SyncRoomMessageEvent, room: Room, context: Ctx| async move { + process_message(ev.clone(), room, context).await; + log::debug!( + "received message: {}", + ev.as_original().unwrap().content.body().to_string() + ); + }, + ); + let sync_token = client .sync_once(SyncSettings::default()) .await .unwrap() .next_batch; - #[derive(Debug, Clone)] - struct MyContext { - botId: String, - ai: Ollama, - } - client.add_event_handler_context(MyContext { - botId: client.clone().user_id().unwrap().to_string(), - ai: ollama.clone(), - }); - - client.add_event_handler( - |ev: SyncRoomMessageEvent, room: Room, context: Ctx| async move { - if ev.as_original().unwrap().sender.to_string() == context.botId { - return; - }; - - // add . prefix - if !ev - .as_original() - .unwrap() - .content - .body() - .to_string() - .contains(prefix.as_str()) - { - return; - } - - // Send seen - room.send_single_receipt( - matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read, - matrix_sdk::ruma::events::receipt::ReceiptThread::Main, - ev.event_id().to_owned(), - ) - .await; - - room.typing_notice(true).await; - - let prompt = format!( - "{} says: {}", - room.get_member(&ev.sender()) - .await - .unwrap() - .unwrap() - .display_name() - .unwrap(), - ev.as_original() - .unwrap() - .content - .body() - .replace(prefix.as_str(), "") - ); - let now = Instant::now(); - let res = ollama - .generate(GenerationRequest::new(ollama_model, prompt)) - .await; - - if let Ok(res) = res { - let mut asd = res.response.to_string(); - asd.push_str(format!("\n prompt took {:?}", now.elapsed()).as_str()); - - let content = RoomMessageEventContent::text_plain(asd); - println!("Got res from ai: {}", res.clone().response); - room.send(content).await.unwrap(); - - room.typing_notice(false).await; - } - }, - ); - let settings = SyncSettings::default().token(sync_token); client.sync(settings).await?; diff --git a/src/structs/mod.rs b/src/structs/mod.rs new file mode 100644 index 0000000..2f88cac --- /dev/null +++ b/src/structs/mod.rs @@ -0,0 +1,25 @@ +use derive_builder::Builder; +use derive_getters::Getters; +use ollama_rs::Ollama; +use serde::Deserialize; + +#[derive(Deserialize, Debug, Clone, Getters)] +pub struct Config { + homeserver_url: String, + username: String, + password: String, + + ollama_host: String, + ollama_port: i16, + ollama_model: String, + // bot settings + prefix: String, + enable_dms: bool, +} + +#[derive(Debug, Clone, Getters, Builder)] +pub struct BotContext { + bot_id: String, + ai: Ollama, + config: Config, +}