refactor
refactored the whole project, now its built up to event handlers with a commands folder for easy expendability added a dm ignore feature made a todo list added a bit of logging
This commit is contained in:
parent
9bf77013fd
commit
fb251b1f3c
154
Cargo.lock
generated
154
Cargo.lock
generated
|
@ -380,6 +380,8 @@ name = "celestial"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"derive-getters",
|
||||||
|
"derive_builder",
|
||||||
"dotenv",
|
"dotenv",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"envy",
|
"envy",
|
||||||
|
@ -557,6 +559,41 @@ dependencies = [
|
||||||
"syn 2.0.60",
|
"syn 2.0.60",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling"
|
||||||
|
version = "0.20.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "54e36fcd13ed84ffdfda6f5be89b31287cbb80c439841fe69e04841435464391"
|
||||||
|
dependencies = [
|
||||||
|
"darling_core",
|
||||||
|
"darling_macro",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling_core"
|
||||||
|
version = "0.20.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9c2cf1c23a687a1feeb728783b993c4e1ad83d99f351801977dd809b48d0a70f"
|
||||||
|
dependencies = [
|
||||||
|
"fnv",
|
||||||
|
"ident_case",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"strsim",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "darling_macro"
|
||||||
|
version = "0.20.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f"
|
||||||
|
dependencies = [
|
||||||
|
"darling_core",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "deadpool"
|
name = "deadpool"
|
||||||
version = "0.10.0"
|
version = "0.10.0"
|
||||||
|
@ -633,6 +670,48 @@ dependencies = [
|
||||||
"syn 2.0.60",
|
"syn 2.0.60",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive-getters"
|
||||||
|
version = "0.3.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7a2c35ab6e03642397cdda1dd58abbc05d418aef8e36297f336d5aba060fe8df"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_builder"
|
||||||
|
version = "0.20.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7"
|
||||||
|
dependencies = [
|
||||||
|
"derive_builder_macro",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_builder_core"
|
||||||
|
version = "0.20.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d"
|
||||||
|
dependencies = [
|
||||||
|
"darling",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_builder_macro"
|
||||||
|
version = "0.20.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b"
|
||||||
|
dependencies = [
|
||||||
|
"derive_builder_core",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "digest"
|
name = "digest"
|
||||||
version = "0.10.7"
|
version = "0.10.7"
|
||||||
|
@ -1222,6 +1301,12 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ident_case"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "idna"
|
name = "idna"
|
||||||
version = "0.5.0"
|
version = "0.5.0"
|
||||||
|
@ -1435,6 +1520,16 @@ version = "0.4.13"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
|
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "lock_api"
|
||||||
|
version = "0.4.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"scopeguard",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "log"
|
name = "log"
|
||||||
version = "0.4.21"
|
version = "0.4.21"
|
||||||
|
@ -1823,6 +1918,8 @@ dependencies = [
|
||||||
"reqwest 0.12.4",
|
"reqwest 0.12.4",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"tokio",
|
||||||
|
"tokio-stream",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -1893,6 +1990,29 @@ version = "2.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
|
checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "parking_lot"
|
||||||
|
version = "0.12.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb"
|
||||||
|
dependencies = [
|
||||||
|
"lock_api",
|
||||||
|
"parking_lot_core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "parking_lot_core"
|
||||||
|
version = "0.9.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"libc",
|
||||||
|
"redox_syscall",
|
||||||
|
"smallvec",
|
||||||
|
"windows-targets 0.52.5",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "paste"
|
name = "paste"
|
||||||
version = "1.0.14"
|
version = "1.0.14"
|
||||||
|
@ -2127,6 +2247,15 @@ version = "0.1.7"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d7b323e7196daa571c8584de958be19e92941c41f845776fe06babfe8fa280a2"
|
checksum = "d7b323e7196daa571c8584de958be19e92941c41f845776fe06babfe8fa280a2"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "redox_syscall"
|
||||||
|
version = "0.5.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.5.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "regex"
|
name = "regex"
|
||||||
version = "1.10.4"
|
version = "1.10.4"
|
||||||
|
@ -2229,10 +2358,12 @@ dependencies = [
|
||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-native-tls",
|
"tokio-native-tls",
|
||||||
|
"tokio-util",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"url",
|
"url",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"wasm-bindgen-futures",
|
"wasm-bindgen-futures",
|
||||||
|
"wasm-streams",
|
||||||
"web-sys",
|
"web-sys",
|
||||||
"winreg 0.52.0",
|
"winreg 0.52.0",
|
||||||
]
|
]
|
||||||
|
@ -2468,6 +2599,12 @@ dependencies = [
|
||||||
"windows-sys 0.52.0",
|
"windows-sys 0.52.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "scopeguard"
|
||||||
|
version = "1.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "security-framework"
|
name = "security-framework"
|
||||||
version = "2.10.0"
|
version = "2.10.0"
|
||||||
|
@ -2602,6 +2739,15 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "signal-hook-registry"
|
||||||
|
version = "1.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "signature"
|
name = "signature"
|
||||||
version = "2.2.0"
|
version = "2.2.0"
|
||||||
|
@ -2646,6 +2792,12 @@ dependencies = [
|
||||||
"der",
|
"der",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "strsim"
|
||||||
|
version = "0.10.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "subtle"
|
||||||
version = "2.5.0"
|
version = "2.5.0"
|
||||||
|
@ -2769,7 +2921,9 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"mio",
|
"mio",
|
||||||
"num_cpus",
|
"num_cpus",
|
||||||
|
"parking_lot",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
|
"signal-hook-registry",
|
||||||
"socket2",
|
"socket2",
|
||||||
"tokio-macros",
|
"tokio-macros",
|
||||||
"windows-sys 0.48.0",
|
"windows-sys 0.48.0",
|
||||||
|
|
|
@ -8,12 +8,14 @@ license = "MIT"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.82"
|
anyhow = "1.0.82"
|
||||||
|
derive-getters = "0.3.0"
|
||||||
|
derive_builder = "0.20.0"
|
||||||
dotenv = "0.15.0"
|
dotenv = "0.15.0"
|
||||||
env_logger = "0.11.3"
|
env_logger = "0.11.3"
|
||||||
envy = "0.4.2"
|
envy = "0.4.2"
|
||||||
log = "0.4.21"
|
log = "0.4.21"
|
||||||
matrix-sdk = "0.7.1"
|
matrix-sdk = "0.7.1"
|
||||||
ollama-rs = "0.1.9"
|
ollama-rs = {version="0.1.9", features = ["stream"]}
|
||||||
serde = "1.0.200"
|
serde = "1.0.200"
|
||||||
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
|
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
|
||||||
tracing-subscriber = "0.3.15"
|
tracing-subscriber = "0.3.15"
|
||||||
|
|
|
@ -2,3 +2,10 @@
|
||||||
|
|
||||||
This is just a funny project I slapped together in a span of around 30 minutes.
|
This is just a funny project I slapped together in a span of around 30 minutes.
|
||||||
I plan on expanding it with other features in the future.
|
I plan on expanding it with other features in the future.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- [] History
|
||||||
|
- [] Per room history
|
||||||
|
- [] quality of life features
|
||||||
|
- [] Nix service and flake for ez install
|
||||||
|
|
59
src/commands/chatbot.rs
Normal file
59
src/commands/chatbot.rs
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use anyhow::Ok;
|
||||||
|
use matrix_sdk::{
|
||||||
|
event_handler::Ctx,
|
||||||
|
ruma::events::room::message::{RoomMessageEventContent, SyncRoomMessageEvent},
|
||||||
|
Room,
|
||||||
|
};
|
||||||
|
use ollama_rs::generation::{
|
||||||
|
chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
|
||||||
|
completion::request::GenerationRequest,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::structs::BotContext;
|
||||||
|
|
||||||
|
pub async fn chat(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
|
||||||
|
log::debug!("Prompting ollama");
|
||||||
|
// Send seen
|
||||||
|
room.send_single_receipt(
|
||||||
|
matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read,
|
||||||
|
matrix_sdk::ruma::events::receipt::ReceiptThread::Main,
|
||||||
|
ev.event_id().to_owned(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
room.typing_notice(true).await;
|
||||||
|
|
||||||
|
let prompt = format!(
|
||||||
|
"{} says: {}",
|
||||||
|
room.get_member(&ev.sender())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap()
|
||||||
|
.display_name()
|
||||||
|
.unwrap(),
|
||||||
|
ev.as_original()
|
||||||
|
.unwrap()
|
||||||
|
.content
|
||||||
|
.body()
|
||||||
|
.replace(context.config().prefix().as_str(), "")
|
||||||
|
);
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let res = context
|
||||||
|
.ai()
|
||||||
|
.generate(GenerationRequest::new(
|
||||||
|
context.config().ollama_model().clone(),
|
||||||
|
prompt,
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut msg = res.response.to_string();
|
||||||
|
msg.push_str(format!("\n prompt took {:?}", now.elapsed(),).as_str());
|
||||||
|
|
||||||
|
let content = RoomMessageEventContent::text_plain(msg);
|
||||||
|
room.send(content).await.unwrap();
|
||||||
|
room.typing_notice(false).await;
|
||||||
|
}
|
2
src/commands/mod.rs
Normal file
2
src/commands/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
pub mod chatbot;
|
||||||
|
pub mod model;
|
3
src/commands/model.rs
Normal file
3
src/commands/model.rs
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
pub fn model() {
|
||||||
|
//
|
||||||
|
}
|
27
src/events/message.rs
Normal file
27
src/events/message.rs
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
use matrix_sdk::{event_handler::Ctx, ruma::events::room::message::SyncRoomMessageEvent, Room};
|
||||||
|
|
||||||
|
use crate::{commands::chatbot::chat, structs::BotContext};
|
||||||
|
|
||||||
|
pub async fn process_message(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
|
||||||
|
log::debug!("Processing message");
|
||||||
|
|
||||||
|
// Ignore if dm's are disabled
|
||||||
|
if room.is_direct().await.unwrap() && context.config().enable_dms() == &false {
|
||||||
|
log::debug!("Message is DM, ignoring due to configuration");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ignore message if its from the bot
|
||||||
|
if &ev.as_original().unwrap().sender.to_string() == context.bot_id() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let msg = ev.as_original().unwrap().content.body().to_string();
|
||||||
|
if msg.contains(context.config().prefix()) {
|
||||||
|
let parameters: Vec<&str> = msg.split(" ").collect();
|
||||||
|
match parameters[0] {
|
||||||
|
//"model" =>
|
||||||
|
_ => chat(ev, room, context).await,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
1
src/events/mod.rs
Normal file
1
src/events/mod.rs
Normal file
|
@ -0,0 +1 @@
|
||||||
|
pub mod message;
|
152
src/main.rs
152
src/main.rs
|
@ -1,152 +1,76 @@
|
||||||
use std::time::Instant;
|
use events::message::process_message;
|
||||||
|
|
||||||
use matrix_sdk::{
|
use matrix_sdk::{
|
||||||
config::SyncSettings,
|
config::SyncSettings,
|
||||||
event_handler::Ctx,
|
event_handler::Ctx,
|
||||||
ruma::events::room::{
|
ruma::events::room::{member::StrippedRoomMemberEvent, message::SyncRoomMessageEvent},
|
||||||
member::StrippedRoomMemberEvent,
|
|
||||||
message::{RoomMessageEventContent, SyncRoomMessageEvent},
|
|
||||||
},
|
|
||||||
Client, Room,
|
Client, Room,
|
||||||
};
|
};
|
||||||
use ollama_rs::{generation::completion::request::GenerationRequest, Ollama};
|
use ollama_rs::Ollama;
|
||||||
use serde::Deserialize;
|
use structs::BotContextBuilder;
|
||||||
use tokio::time::{sleep, Duration};
|
use tokio::time::{sleep, Duration};
|
||||||
|
mod commands;
|
||||||
|
mod events;
|
||||||
|
mod structs;
|
||||||
|
use crate::structs::{BotContext, Config};
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
|
||||||
struct Config {
|
|
||||||
homeserver_url: String,
|
|
||||||
username: String,
|
|
||||||
password: String,
|
|
||||||
|
|
||||||
ollama_host: String,
|
|
||||||
ollama_port: i16,
|
|
||||||
ollama_model: String,
|
|
||||||
|
|
||||||
prefix: String,
|
|
||||||
}
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> anyhow::Result<()> {
|
async fn main() -> anyhow::Result<()> {
|
||||||
tracing_subscriber::fmt::init();
|
|
||||||
let _ = dotenv::dotenv();
|
let _ = dotenv::dotenv();
|
||||||
|
env_logger::init();
|
||||||
log::info!("Bot has started");
|
log::info!("Bot has started");
|
||||||
let c = envy::from_env::<Config>().expect("Please provide variables");
|
let c = envy::from_env::<Config>().expect("Please provide variables");
|
||||||
|
|
||||||
login_and_sync(
|
login_and_sync(c).await?;
|
||||||
c.homeserver_url.to_string(),
|
|
||||||
&c.username,
|
|
||||||
&c.password,
|
|
||||||
c.prefix,
|
|
||||||
c.ollama_host,
|
|
||||||
c.ollama_port,
|
|
||||||
c.ollama_model,
|
|
||||||
)
|
|
||||||
.await?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// The core sync loop we have running.
|
// The core sync loop we have running.
|
||||||
async fn login_and_sync(
|
async fn login_and_sync(config: Config) -> anyhow::Result<()> {
|
||||||
homeserver_url: String,
|
log::info!("client init");
|
||||||
username: &str,
|
|
||||||
password: &str,
|
|
||||||
prefix: String,
|
|
||||||
ollama_host: String,
|
|
||||||
ollama_port: i16,
|
|
||||||
ollama_model: String,
|
|
||||||
) -> anyhow::Result<()> {
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.homeserver_url(homeserver_url)
|
.homeserver_url(config.homeserver_url())
|
||||||
.build()
|
.build()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
client
|
client
|
||||||
.matrix_auth()
|
.matrix_auth()
|
||||||
.login_username(username, password)
|
.login_username(config.username(), config.password())
|
||||||
.initial_device_display_name("getting started bot")
|
.initial_device_display_name("getting started bot")
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
println!("logged in as {username}");
|
let ollama = Ollama::new(
|
||||||
|
config.ollama_host().clone().to_owned(),
|
||||||
|
config.ollama_port().clone().to_owned() as u16,
|
||||||
|
);
|
||||||
|
|
||||||
let ollama = Ollama::new(ollama_host, ollama_port as u16);
|
let ctx = BotContextBuilder::default()
|
||||||
|
.ai(ollama)
|
||||||
|
.bot_id(client.clone().user_id().unwrap().to_string())
|
||||||
|
.config(config)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
client.add_event_handler_context(ctx);
|
||||||
|
|
||||||
|
log::info!("adding handlers");
|
||||||
client.add_event_handler(on_stripped_state_member);
|
client.add_event_handler(on_stripped_state_member);
|
||||||
|
// TODO this is something funny, i dont get why it doenst work this way
|
||||||
|
//client.add_event_handler(process_message);
|
||||||
|
client.add_event_handler(
|
||||||
|
|ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>| async move {
|
||||||
|
process_message(ev.clone(), room, context).await;
|
||||||
|
log::debug!(
|
||||||
|
"received message: {}",
|
||||||
|
ev.as_original().unwrap().content.body().to_string()
|
||||||
|
);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
let sync_token = client
|
let sync_token = client
|
||||||
.sync_once(SyncSettings::default())
|
.sync_once(SyncSettings::default())
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.next_batch;
|
.next_batch;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct MyContext {
|
|
||||||
botId: String,
|
|
||||||
ai: Ollama,
|
|
||||||
}
|
|
||||||
client.add_event_handler_context(MyContext {
|
|
||||||
botId: client.clone().user_id().unwrap().to_string(),
|
|
||||||
ai: ollama.clone(),
|
|
||||||
});
|
|
||||||
|
|
||||||
client.add_event_handler(
|
|
||||||
|ev: SyncRoomMessageEvent, room: Room, context: Ctx<MyContext>| async move {
|
|
||||||
if ev.as_original().unwrap().sender.to_string() == context.botId {
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
// add . prefix
|
|
||||||
if !ev
|
|
||||||
.as_original()
|
|
||||||
.unwrap()
|
|
||||||
.content
|
|
||||||
.body()
|
|
||||||
.to_string()
|
|
||||||
.contains(prefix.as_str())
|
|
||||||
{
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send seen
|
|
||||||
room.send_single_receipt(
|
|
||||||
matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read,
|
|
||||||
matrix_sdk::ruma::events::receipt::ReceiptThread::Main,
|
|
||||||
ev.event_id().to_owned(),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
room.typing_notice(true).await;
|
|
||||||
|
|
||||||
let prompt = format!(
|
|
||||||
"{} says: {}",
|
|
||||||
room.get_member(&ev.sender())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap()
|
|
||||||
.display_name()
|
|
||||||
.unwrap(),
|
|
||||||
ev.as_original()
|
|
||||||
.unwrap()
|
|
||||||
.content
|
|
||||||
.body()
|
|
||||||
.replace(prefix.as_str(), "")
|
|
||||||
);
|
|
||||||
let now = Instant::now();
|
|
||||||
let res = ollama
|
|
||||||
.generate(GenerationRequest::new(ollama_model, prompt))
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if let Ok(res) = res {
|
|
||||||
let mut asd = res.response.to_string();
|
|
||||||
asd.push_str(format!("\n prompt took {:?}", now.elapsed()).as_str());
|
|
||||||
|
|
||||||
let content = RoomMessageEventContent::text_plain(asd);
|
|
||||||
println!("Got res from ai: {}", res.clone().response);
|
|
||||||
room.send(content).await.unwrap();
|
|
||||||
|
|
||||||
room.typing_notice(false).await;
|
|
||||||
}
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
let settings = SyncSettings::default().token(sync_token);
|
let settings = SyncSettings::default().token(sync_token);
|
||||||
client.sync(settings).await?;
|
client.sync(settings).await?;
|
||||||
|
|
||||||
|
|
25
src/structs/mod.rs
Normal file
25
src/structs/mod.rs
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
use derive_builder::Builder;
|
||||||
|
use derive_getters::Getters;
|
||||||
|
use ollama_rs::Ollama;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Deserialize, Debug, Clone, Getters)]
|
||||||
|
pub struct Config {
|
||||||
|
homeserver_url: String,
|
||||||
|
username: String,
|
||||||
|
password: String,
|
||||||
|
|
||||||
|
ollama_host: String,
|
||||||
|
ollama_port: i16,
|
||||||
|
ollama_model: String,
|
||||||
|
// bot settings
|
||||||
|
prefix: String,
|
||||||
|
enable_dms: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Getters, Builder)]
|
||||||
|
pub struct BotContext {
|
||||||
|
bot_id: String,
|
||||||
|
ai: Ollama,
|
||||||
|
config: Config,
|
||||||
|
}
|
Loading…
Reference in a new issue