refactor
refactored the whole project, now its built up to event handlers with a commands folder for easy expendability added a dm ignore feature made a todo list added a bit of logging
This commit is contained in:
parent
9bf77013fd
commit
fb251b1f3c
154
Cargo.lock
generated
154
Cargo.lock
generated
|
@ -380,6 +380,8 @@ name = "celestial"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"derive-getters",
|
||||
"derive_builder",
|
||||
"dotenv",
|
||||
"env_logger",
|
||||
"envy",
|
||||
|
@ -557,6 +559,41 @@ dependencies = [
|
|||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.20.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deadpool"
|
||||
version = "0.10.0"
|
||||
|
@ -633,6 +670,48 @@ dependencies = [
|
|||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive-getters"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a2c35ab6e03642397cdda1dd58abbc05d418aef8e36297f336d5aba060fe8df"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7"
|
||||
dependencies = [
|
||||
"derive_builder_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_core"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_builder_macro"
|
||||
version = "0.20.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
|
||||
dependencies = [
|
||||
"derive_builder_core",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
|
@ -1222,6 +1301,12 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "0.5.0"
|
||||
|
@ -1435,6 +1520,16 @@ version = "0.4.13"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.21"
|
||||
|
@ -1823,6 +1918,8 @@ dependencies = [
|
|||
"reqwest 0.12.4",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"tokio",
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1893,6 +1990,29 @@ version = "2.2.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot"
|
||||
version = "0.12.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
"parking_lot_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "parking_lot_core"
|
||||
version = "0.9.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall",
|
||||
"smallvec",
|
||||
"windows-targets 0.52.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "paste"
|
||||
version = "1.0.14"
|
||||
|
@ -2127,6 +2247,15 @@ version = "0.1.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d7b323e7196daa571c8584de958be19e92941c41f845776fe06babfe8fa280a2"
|
||||
|
||||
[[package]]
|
||||
name = "redox_syscall"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.4"
|
||||
|
@ -2229,10 +2358,12 @@ dependencies = [
|
|||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-native-tls",
|
||||
"tokio-util",
|
||||
"tower-service",
|
||||
"url",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"wasm-streams",
|
||||
"web-sys",
|
||||
"winreg 0.52.0",
|
||||
]
|
||||
|
@ -2468,6 +2599,12 @@ dependencies = [
|
|||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "security-framework"
|
||||
version = "2.10.0"
|
||||
|
@ -2602,6 +2739,15 @@ dependencies = [
|
|||
"lazy_static",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "signature"
|
||||
version = "2.2.0"
|
||||
|
@ -2646,6 +2792,12 @@ dependencies = [
|
|||
"der",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||
|
||||
[[package]]
|
||||
name = "subtle"
|
||||
version = "2.5.0"
|
||||
|
@ -2769,7 +2921,9 @@ dependencies = [
|
|||
"libc",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"windows-sys 0.48.0",
|
||||
|
|
|
@ -8,12 +8,14 @@ license = "MIT"
|
|||
|
||||
[dependencies]
|
||||
anyhow = "1.0.82"
|
||||
derive-getters = "0.3.0"
|
||||
derive_builder = "0.20.0"
|
||||
dotenv = "0.15.0"
|
||||
env_logger = "0.11.3"
|
||||
envy = "0.4.2"
|
||||
log = "0.4.21"
|
||||
matrix-sdk = "0.7.1"
|
||||
ollama-rs = "0.1.9"
|
||||
ollama-rs = {version="0.1.9", features = ["stream"]}
|
||||
serde = "1.0.200"
|
||||
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
|
||||
tracing-subscriber = "0.3.15"
|
||||
|
|
|
@ -2,3 +2,10 @@
|
|||
|
||||
This is just a funny project I slapped together in a span of around 30 minutes.
|
||||
I plan on expanding it with other features in the future.
|
||||
|
||||
## Features
|
||||
|
||||
- [] History
|
||||
- [] Per room history
|
||||
- [] quality of life features
|
||||
- [] Nix service and flake for ez install
|
||||
|
|
59
src/commands/chatbot.rs
Normal file
59
src/commands/chatbot.rs
Normal file
|
@ -0,0 +1,59 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use anyhow::Ok;
|
||||
use matrix_sdk::{
|
||||
event_handler::Ctx,
|
||||
ruma::events::room::message::{RoomMessageEventContent, SyncRoomMessageEvent},
|
||||
Room,
|
||||
};
|
||||
use ollama_rs::generation::{
|
||||
chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
|
||||
completion::request::GenerationRequest,
|
||||
};
|
||||
|
||||
use crate::structs::BotContext;
|
||||
|
||||
pub async fn chat(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
|
||||
log::debug!("Prompting ollama");
|
||||
// Send seen
|
||||
room.send_single_receipt(
|
||||
matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read,
|
||||
matrix_sdk::ruma::events::receipt::ReceiptThread::Main,
|
||||
ev.event_id().to_owned(),
|
||||
)
|
||||
.await;
|
||||
|
||||
room.typing_notice(true).await;
|
||||
|
||||
let prompt = format!(
|
||||
"{} says: {}",
|
||||
room.get_member(&ev.sender())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.display_name()
|
||||
.unwrap(),
|
||||
ev.as_original()
|
||||
.unwrap()
|
||||
.content
|
||||
.body()
|
||||
.replace(context.config().prefix().as_str(), "")
|
||||
);
|
||||
let now = Instant::now();
|
||||
|
||||
let res = context
|
||||
.ai()
|
||||
.generate(GenerationRequest::new(
|
||||
context.config().ollama_model().clone(),
|
||||
prompt,
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut msg = res.response.to_string();
|
||||
msg.push_str(format!("\n prompt took {:?}", now.elapsed(),).as_str());
|
||||
|
||||
let content = RoomMessageEventContent::text_plain(msg);
|
||||
room.send(content).await.unwrap();
|
||||
room.typing_notice(false).await;
|
||||
}
|
2
src/commands/mod.rs
Normal file
2
src/commands/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
|||
pub mod chatbot;
|
||||
pub mod model;
|
3
src/commands/model.rs
Normal file
3
src/commands/model.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
pub fn model() {
|
||||
//
|
||||
}
|
27
src/events/message.rs
Normal file
27
src/events/message.rs
Normal file
|
@ -0,0 +1,27 @@
|
|||
use matrix_sdk::{event_handler::Ctx, ruma::events::room::message::SyncRoomMessageEvent, Room};
|
||||
|
||||
use crate::{commands::chatbot::chat, structs::BotContext};
|
||||
|
||||
pub async fn process_message(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
|
||||
log::debug!("Processing message");
|
||||
|
||||
// Ignore if dm's are disabled
|
||||
if room.is_direct().await.unwrap() && context.config().enable_dms() == &false {
|
||||
log::debug!("Message is DM, ignoring due to configuration");
|
||||
return;
|
||||
}
|
||||
|
||||
// ignore message if its from the bot
|
||||
if &ev.as_original().unwrap().sender.to_string() == context.bot_id() {
|
||||
return;
|
||||
}
|
||||
|
||||
let msg = ev.as_original().unwrap().content.body().to_string();
|
||||
if msg.contains(context.config().prefix()) {
|
||||
let parameters: Vec<&str> = msg.split(" ").collect();
|
||||
match parameters[0] {
|
||||
//"model" =>
|
||||
_ => chat(ev, room, context).await,
|
||||
};
|
||||
}
|
||||
}
|
1
src/events/mod.rs
Normal file
1
src/events/mod.rs
Normal file
|
@ -0,0 +1 @@
|
|||
pub mod message;
|
152
src/main.rs
152
src/main.rs
|
@ -1,152 +1,76 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use events::message::process_message;
|
||||
use matrix_sdk::{
|
||||
config::SyncSettings,
|
||||
event_handler::Ctx,
|
||||
ruma::events::room::{
|
||||
member::StrippedRoomMemberEvent,
|
||||
message::{RoomMessageEventContent, SyncRoomMessageEvent},
|
||||
},
|
||||
ruma::events::room::{member::StrippedRoomMemberEvent, message::SyncRoomMessageEvent},
|
||||
Client, Room,
|
||||
};
|
||||
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
|
||||
use serde::Deserialize;
|
||||
use ollama_rs::Ollama;
|
||||
use structs::BotContextBuilder;
|
||||
use tokio::time::{sleep, Duration};
|
||||
mod commands;
|
||||
mod events;
|
||||
mod structs;
|
||||
use crate::structs::{BotContext, Config};
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
struct Config {
|
||||
homeserver_url: String,
|
||||
username: String,
|
||||
password: String,
|
||||
|
||||
ollama_host: String,
|
||||
ollama_port: i16,
|
||||
ollama_model: String,
|
||||
|
||||
prefix: String,
|
||||
}
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let _ = dotenv::dotenv();
|
||||
env_logger::init();
|
||||
log::info!("Bot has started");
|
||||
let c = envy::from_env::<Config>().expect("Please provide variables");
|
||||
|
||||
login_and_sync(
|
||||
c.homeserver_url.to_string(),
|
||||
&c.username,
|
||||
&c.password,
|
||||
c.prefix,
|
||||
c.ollama_host,
|
||||
c.ollama_port,
|
||||
c.ollama_model,
|
||||
)
|
||||
.await?;
|
||||
login_and_sync(c).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// The core sync loop we have running.
|
||||
async fn login_and_sync(
|
||||
homeserver_url: String,
|
||||
username: &str,
|
||||
password: &str,
|
||||
prefix: String,
|
||||
ollama_host: String,
|
||||
ollama_port: i16,
|
||||
ollama_model: String,
|
||||
) -> anyhow::Result<()> {
|
||||
async fn login_and_sync(config: Config) -> anyhow::Result<()> {
|
||||
log::info!("client init");
|
||||
let client = Client::builder()
|
||||
.homeserver_url(homeserver_url)
|
||||
.homeserver_url(config.homeserver_url())
|
||||
.build()
|
||||
.await?;
|
||||
|
||||
client
|
||||
.matrix_auth()
|
||||
.login_username(username, password)
|
||||
.login_username(config.username(), config.password())
|
||||
.initial_device_display_name("getting started bot")
|
||||
.await?;
|
||||
|
||||
println!("logged in as {username}");
|
||||
let ollama = Ollama::new(
|
||||
config.ollama_host().clone().to_owned(),
|
||||
config.ollama_port().clone().to_owned() as u16,
|
||||
);
|
||||
|
||||
let ollama = Ollama::new(ollama_host, ollama_port as u16);
|
||||
let ctx = BotContextBuilder::default()
|
||||
.ai(ollama)
|
||||
.bot_id(client.clone().user_id().unwrap().to_string())
|
||||
.config(config)
|
||||
.build()
|
||||
.unwrap();
|
||||
client.add_event_handler_context(ctx);
|
||||
|
||||
log::info!("adding handlers");
|
||||
client.add_event_handler(on_stripped_state_member);
|
||||
// TODO this is something funny, i dont get why it doenst work this way
|
||||
//client.add_event_handler(process_message);
|
||||
client.add_event_handler(
|
||||
|ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>| async move {
|
||||
process_message(ev.clone(), room, context).await;
|
||||
log::debug!(
|
||||
"received message: {}",
|
||||
ev.as_original().unwrap().content.body().to_string()
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
let sync_token = client
|
||||
.sync_once(SyncSettings::default())
|
||||
.await
|
||||
.unwrap()
|
||||
.next_batch;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct MyContext {
|
||||
botId: String,
|
||||
ai: Ollama,
|
||||
}
|
||||
client.add_event_handler_context(MyContext {
|
||||
botId: client.clone().user_id().unwrap().to_string(),
|
||||
ai: ollama.clone(),
|
||||
});
|
||||
|
||||
client.add_event_handler(
|
||||
|ev: SyncRoomMessageEvent, room: Room, context: Ctx<MyContext>| async move {
|
||||
if ev.as_original().unwrap().sender.to_string() == context.botId {
|
||||
return;
|
||||
};
|
||||
|
||||
// add . prefix
|
||||
if !ev
|
||||
.as_original()
|
||||
.unwrap()
|
||||
.content
|
||||
.body()
|
||||
.to_string()
|
||||
.contains(prefix.as_str())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Send seen
|
||||
room.send_single_receipt(
|
||||
matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read,
|
||||
matrix_sdk::ruma::events::receipt::ReceiptThread::Main,
|
||||
ev.event_id().to_owned(),
|
||||
)
|
||||
.await;
|
||||
|
||||
room.typing_notice(true).await;
|
||||
|
||||
let prompt = format!(
|
||||
"{} says: {}",
|
||||
room.get_member(&ev.sender())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.display_name()
|
||||
.unwrap(),
|
||||
ev.as_original()
|
||||
.unwrap()
|
||||
.content
|
||||
.body()
|
||||
.replace(prefix.as_str(), "")
|
||||
);
|
||||
let now = Instant::now();
|
||||
let res = ollama
|
||||
.generate(GenerationRequest::new(ollama_model, prompt))
|
||||
.await;
|
||||
|
||||
if let Ok(res) = res {
|
||||
let mut asd = res.response.to_string();
|
||||
asd.push_str(format!("\n prompt took {:?}", now.elapsed()).as_str());
|
||||
|
||||
let content = RoomMessageEventContent::text_plain(asd);
|
||||
println!("Got res from ai: {}", res.clone().response);
|
||||
room.send(content).await.unwrap();
|
||||
|
||||
room.typing_notice(false).await;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let settings = SyncSettings::default().token(sync_token);
|
||||
client.sync(settings).await?;
|
||||
|
||||
|
|
25
src/structs/mod.rs
Normal file
25
src/structs/mod.rs
Normal file
|
@ -0,0 +1,25 @@
|
|||
use derive_builder::Builder;
|
||||
use derive_getters::Getters;
|
||||
use ollama_rs::Ollama;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize, Debug, Clone, Getters)]
|
||||
pub struct Config {
|
||||
homeserver_url: String,
|
||||
username: String,
|
||||
password: String,
|
||||
|
||||
ollama_host: String,
|
||||
ollama_port: i16,
|
||||
ollama_model: String,
|
||||
// bot settings
|
||||
prefix: String,
|
||||
enable_dms: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Getters, Builder)]
|
||||
pub struct BotContext {
|
||||
bot_id: String,
|
||||
ai: Ollama,
|
||||
config: Config,
|
||||
}
|
Loading…
Reference in a new issue