refactored the whole project, now its built up to event handlers with a commands folder for easy expendability
added a dm ignore feature
made a todo list
added a bit of logging
This commit is contained in:
2005 2024-05-08 22:37:29 +02:00
parent 9bf77013fd
commit fb251b1f3c
10 changed files with 319 additions and 115 deletions

154
Cargo.lock generated
View file

@ -380,6 +380,8 @@ name = "celestial"
version = "0.1.0"
dependencies = [
"anyhow",
"derive-getters",
"derive_builder",
"dotenv",
"env_logger",
"envy",
@ -557,6 +559,41 @@ dependencies = [
"syn 2.0.60",
]
[[package]]
name = "darling"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 2.0.60",
]
[[package]]
name = "darling_macro"
version = "0.20.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
dependencies = [
"darling_core",
"quote",
"syn 2.0.60",
]
[[package]]
name = "deadpool"
version = "0.10.0"
@ -633,6 +670,48 @@ dependencies = [
"syn 2.0.60",
]
[[package]]
name = "derive-getters"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2c35ab6e03642397cdda1dd58abbc05d418aef8e36297f336d5aba060fe8df"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "derive_builder"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7"
dependencies = [
"derive_builder_macro",
]
[[package]]
name = "derive_builder_core"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "derive_builder_macro"
version = "0.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
dependencies = [
"derive_builder_core",
"syn 2.0.60",
]
[[package]]
name = "digest"
version = "0.10.7"
@ -1222,6 +1301,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "idna"
version = "0.5.0"
@ -1435,6 +1520,16 @@ version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "lock_api"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
"autocfg",
"scopeguard",
]
[[package]]
name = "log"
version = "0.4.21"
@ -1823,6 +1918,8 @@ dependencies = [
"reqwest 0.12.4",
"serde",
"serde_json",
"tokio",
"tokio-stream",
]
[[package]]
@ -1893,6 +1990,29 @@ version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
[[package]]
name = "parking_lot"
version = "0.12.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb"
dependencies = [
"lock_api",
"parking_lot_core",
]
[[package]]
name = "parking_lot_core"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
"cfg-if",
"libc",
"redox_syscall",
"smallvec",
"windows-targets 0.52.5",
]
[[package]]
name = "paste"
version = "1.0.14"
@ -2127,6 +2247,15 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7b323e7196daa571c8584de958be19e92941c41f845776fe06babfe8fa280a2"
[[package]]
name = "redox_syscall"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e"
dependencies = [
"bitflags 2.5.0",
]
[[package]]
name = "regex"
version = "1.10.4"
@ -2229,10 +2358,12 @@ dependencies = [
"sync_wrapper",
"tokio",
"tokio-native-tls",
"tokio-util",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"winreg 0.52.0",
]
@ -2468,6 +2599,12 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.10.0"
@ -2602,6 +2739,15 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
dependencies = [
"libc",
]
[[package]]
name = "signature"
version = "2.2.0"
@ -2646,6 +2792,12 @@ dependencies = [
"der",
]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "subtle"
version = "2.5.0"
@ -2769,7 +2921,9 @@ dependencies = [
"libc",
"mio",
"num_cpus",
"parking_lot",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys 0.48.0",

View file

@ -8,12 +8,14 @@ license = "MIT"
[dependencies]
anyhow = "1.0.82"
derive-getters = "0.3.0"
derive_builder = "0.20.0"
dotenv = "0.15.0"
env_logger = "0.11.3"
envy = "0.4.2"
log = "0.4.21"
matrix-sdk = "0.7.1"
ollama-rs = "0.1.9"
ollama-rs = {version="0.1.9", features = ["stream"]}
serde = "1.0.200"
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
tracing-subscriber = "0.3.15"

View file

@ -2,3 +2,10 @@
This is just a funny project I slapped together in a span of around 30 minutes.
I plan on expanding it with other features in the future.
## Features
- [] History
- [] Per room history
- [] quality of life features
- [] Nix service and flake for ez install

59
src/commands/chatbot.rs Normal file
View file

@ -0,0 +1,59 @@
use std::time::Instant;
use anyhow::Ok;
use matrix_sdk::{
event_handler::Ctx,
ruma::events::room::message::{RoomMessageEventContent, SyncRoomMessageEvent},
Room,
};
use ollama_rs::generation::{
chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
completion::request::GenerationRequest,
};
use crate::structs::BotContext;
pub async fn chat(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
log::debug!("Prompting ollama");
// Send seen
room.send_single_receipt(
matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read,
matrix_sdk::ruma::events::receipt::ReceiptThread::Main,
ev.event_id().to_owned(),
)
.await;
room.typing_notice(true).await;
let prompt = format!(
"{} says: {}",
room.get_member(&ev.sender())
.await
.unwrap()
.unwrap()
.display_name()
.unwrap(),
ev.as_original()
.unwrap()
.content
.body()
.replace(context.config().prefix().as_str(), "")
);
let now = Instant::now();
let res = context
.ai()
.generate(GenerationRequest::new(
context.config().ollama_model().clone(),
prompt,
))
.await
.unwrap();
let mut msg = res.response.to_string();
msg.push_str(format!("\n prompt took {:?}", now.elapsed(),).as_str());
let content = RoomMessageEventContent::text_plain(msg);
room.send(content).await.unwrap();
room.typing_notice(false).await;
}

2
src/commands/mod.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod chatbot;
pub mod model;

3
src/commands/model.rs Normal file
View file

@ -0,0 +1,3 @@
pub fn model() {
//
}

27
src/events/message.rs Normal file
View file

@ -0,0 +1,27 @@
use matrix_sdk::{event_handler::Ctx, ruma::events::room::message::SyncRoomMessageEvent, Room};
use crate::{commands::chatbot::chat, structs::BotContext};
pub async fn process_message(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
log::debug!("Processing message");
// Ignore if dm's are disabled
if room.is_direct().await.unwrap() && context.config().enable_dms() == &false {
log::debug!("Message is DM, ignoring due to configuration");
return;
}
// ignore message if its from the bot
if &ev.as_original().unwrap().sender.to_string() == context.bot_id() {
return;
}
let msg = ev.as_original().unwrap().content.body().to_string();
if msg.contains(context.config().prefix()) {
let parameters: Vec<&str> = msg.split(" ").collect();
match parameters[0] {
//"model" =>
_ => chat(ev, room, context).await,
};
}
}

1
src/events/mod.rs Normal file
View file

@ -0,0 +1 @@
pub mod message;

View file

@ -1,152 +1,76 @@
use std::time::Instant;
use events::message::process_message;
use matrix_sdk::{
config::SyncSettings,
event_handler::Ctx,
ruma::events::room::{
member::StrippedRoomMemberEvent,
message::{RoomMessageEventContent, SyncRoomMessageEvent},
},
ruma::events::room::{member::StrippedRoomMemberEvent, message::SyncRoomMessageEvent},
Client, Room,
};
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
use serde::Deserialize;
use ollama_rs::Ollama;
use structs::BotContextBuilder;
use tokio::time::{sleep, Duration};
mod commands;
mod events;
mod structs;
use crate::structs::{BotContext, Config};
#[derive(Deserialize, Debug)]
struct Config {
homeserver_url: String,
username: String,
password: String,
ollama_host: String,
ollama_port: i16,
ollama_model: String,
prefix: String,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
let _ = dotenv::dotenv();
env_logger::init();
log::info!("Bot has started");
let c = envy::from_env::<Config>().expect("Please provide variables");
login_and_sync(
c.homeserver_url.to_string(),
&c.username,
&c.password,
c.prefix,
c.ollama_host,
c.ollama_port,
c.ollama_model,
)
.await?;
login_and_sync(c).await?;
Ok(())
}
// The core sync loop we have running.
async fn login_and_sync(
homeserver_url: String,
username: &str,
password: &str,
prefix: String,
ollama_host: String,
ollama_port: i16,
ollama_model: String,
) -> anyhow::Result<()> {
async fn login_and_sync(config: Config) -> anyhow::Result<()> {
log::info!("client init");
let client = Client::builder()
.homeserver_url(homeserver_url)
.homeserver_url(config.homeserver_url())
.build()
.await?;
client
.matrix_auth()
.login_username(username, password)
.login_username(config.username(), config.password())
.initial_device_display_name("getting started bot")
.await?;
println!("logged in as {username}");
let ollama = Ollama::new(
config.ollama_host().clone().to_owned(),
config.ollama_port().clone().to_owned() as u16,
);
let ollama = Ollama::new(ollama_host, ollama_port as u16);
let ctx = BotContextBuilder::default()
.ai(ollama)
.bot_id(client.clone().user_id().unwrap().to_string())
.config(config)
.build()
.unwrap();
client.add_event_handler_context(ctx);
log::info!("adding handlers");
client.add_event_handler(on_stripped_state_member);
// TODO this is something funny, i dont get why it doenst work this way
//client.add_event_handler(process_message);
client.add_event_handler(
|ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>| async move {
process_message(ev.clone(), room, context).await;
log::debug!(
"received message: {}",
ev.as_original().unwrap().content.body().to_string()
);
},
);
let sync_token = client
.sync_once(SyncSettings::default())
.await
.unwrap()
.next_batch;
#[derive(Debug, Clone)]
struct MyContext {
botId: String,
ai: Ollama,
}
client.add_event_handler_context(MyContext {
botId: client.clone().user_id().unwrap().to_string(),
ai: ollama.clone(),
});
client.add_event_handler(
|ev: SyncRoomMessageEvent, room: Room, context: Ctx<MyContext>| async move {
if ev.as_original().unwrap().sender.to_string() == context.botId {
return;
};
// add . prefix
if !ev
.as_original()
.unwrap()
.content
.body()
.to_string()
.contains(prefix.as_str())
{
return;
}
// Send seen
room.send_single_receipt(
matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read,
matrix_sdk::ruma::events::receipt::ReceiptThread::Main,
ev.event_id().to_owned(),
)
.await;
room.typing_notice(true).await;
let prompt = format!(
"{} says: {}",
room.get_member(&ev.sender())
.await
.unwrap()
.unwrap()
.display_name()
.unwrap(),
ev.as_original()
.unwrap()
.content
.body()
.replace(prefix.as_str(), "")
);
let now = Instant::now();
let res = ollama
.generate(GenerationRequest::new(ollama_model, prompt))
.await;
if let Ok(res) = res {
let mut asd = res.response.to_string();
asd.push_str(format!("\n prompt took {:?}", now.elapsed()).as_str());
let content = RoomMessageEventContent::text_plain(asd);
println!("Got res from ai: {}", res.clone().response);
room.send(content).await.unwrap();
room.typing_notice(false).await;
}
},
);
let settings = SyncSettings::default().token(sync_token);
client.sync(settings).await?;

25
src/structs/mod.rs Normal file
View file

@ -0,0 +1,25 @@
use derive_builder::Builder;
use derive_getters::Getters;
use ollama_rs::Ollama;
use serde::Deserialize;
#[derive(Deserialize, Debug, Clone, Getters)]
pub struct Config {
homeserver_url: String,
username: String,
password: String,
ollama_host: String,
ollama_port: i16,
ollama_model: String,
// bot settings
prefix: String,
enable_dms: bool,
}
#[derive(Debug, Clone, Getters, Builder)]
pub struct BotContext {
bot_id: String,
ai: Ollama,
config: Config,
}