From 84850fc1ea0511bb13d1ae3fdc1aa70825138bc3 Mon Sep 17 00:00:00 2001
From: 4o1x5 <4o1x5@4o1x5.dev>
Date: Fri, 10 May 2024 03:50:26 +0200
Subject: [PATCH] new features: streaming

streaming is now an available option
made a command template
wanted to implement history, but there is sadly no streaming support for it yet
---
 .env.example                     |  9 +++-
 Cargo.lock                       |  1 +
 Cargo.toml                       |  3 +-
 readme.md                        |  6 ++-
 src/commands/chatbot.rs          | 84 +++++++++++++++++++++++++++-----
 src/commands/command.rs.template |  7 +++
 src/commands/mod.rs              |  1 +
 src/commands/pfp.rs              | 19 ++++++++
 src/events/message.rs            |  7 ++-
 src/main.rs                      | 13 ++---
 src/structs/mod.rs               |  2 +
 11 files changed, 125 insertions(+), 27 deletions(-)
 create mode 100644 src/commands/command.rs.template
 create mode 100644 src/commands/pfp.rs

diff --git a/.env.example b/.env.example
index 931426e..7776e00 100644
--- a/.env.example
+++ b/.env.example
@@ -6,4 +6,11 @@
 ollama_host="http://localhost"
 ollama_port=1111
 ollama_model="neural-chat:latest"
-prefix=".ask"
\ No newline at end of file
+prefix=".ask"
+enable_dms=false
+
+# streaming means that the bot will edit the message as each token arrives.
+# this results in excessive editing, and if you are not running the bot on
+# your own server I recommend turning this off.
+# it also produces lots of notifications on some clients
+enable_streaming=true
\ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
index cd2abf0..b91c992 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -390,6 +390,7 @@ dependencies = [
  "ollama-rs",
  "serde",
  "tokio",
+ "tokio-stream",
  "tracing-subscriber",
 ]
 
diff --git a/Cargo.toml b/Cargo.toml
index 0553b5e..6186c8d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,8 @@ env_logger = "0.11.3"
 envy = "0.4.2"
 log = "0.4.21"
 matrix-sdk = "0.7.1"
-ollama-rs = {version="0.1.9", features = ["stream"]}
+ollama-rs = {version="0.1.9", features = ["stream", "chat-history"]}
 serde = "1.0.200"
 tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
+tokio-stream = "0.1.15"
 tracing-subscriber = "0.3.15"
diff --git a/readme.md b/readme.md
index bf57f66..9026b64 100644
--- a/readme.md
+++ b/readme.md
@@ -6,6 +6,10 @@ I plan on expanding it with other features in the future.
 
 ## Features
 
 - [ ] History
+  - [ ] Per room history
-- [ ] quality of life features
+  - [ ] Ollama feedback (like, dislike)
+
+- [x] Streaming text (via replacing messages)
+  - [x] Option to toggle
 - [ ] Nix service and flake for ez install
diff --git a/src/commands/chatbot.rs b/src/commands/chatbot.rs
index 38da8ac..2818f85 100644
--- a/src/commands/chatbot.rs
+++ b/src/commands/chatbot.rs
@@ -1,30 +1,49 @@
-use std::time::Instant;
+use std::{
+    any::Any,
+    borrow::Borrow,
+    thread,
+    time::{Duration, Instant},
+};
 
 use anyhow::Ok;
 use matrix_sdk::{
     event_handler::Ctx,
-    ruma::events::room::message::{RoomMessageEventContent, SyncRoomMessageEvent},
+    room::MessagesOptions,
+    ruma::{
+        events::{
+            relation::Replacement,
+            room::message::{
+                ReplacementMetadata, RoomMessageEventContent,
+                RoomMessageEventContentWithoutRelation, SyncRoomMessageEvent,
+            },
+            StateEventContent,
+        },
+        OwnedEventId, TransactionId,
+    },
     Room,
 };
 use ollama_rs::generation::{
    chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
    completion::request::GenerationRequest,
 };
+use tokio_stream::StreamExt;
 
 use crate::structs::BotContext;
 
 pub async fn chat(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
     log::debug!("Prompting ollama");
-    // Send seen
+    // Send seen and start typing
     room.send_single_receipt(
         matrix_sdk::ruma::api::client::receipt::create_receipt::v3::ReceiptType::Read,
         matrix_sdk::ruma::events::receipt::ReceiptThread::Main,
         ev.event_id().to_owned(),
     )
     .await;
-
     room.typing_notice(true).await;
+
+    // Starting timer
+    let now = Instant::now();
+
+    // Formatting the prompt so the AI knows who's prompting (by username)
     let prompt = format!(
         "{} says: {}",
         room.get_member(&ev.sender())
@@ -33,27 +52,66 @@ pub async fn chat(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
             .await
             .unwrap()
             .display_name()
             .unwrap(),
-        ev.as_original()
+        ev.clone()
+            .as_original()
             .unwrap()
             .content
             .body()
             .replace(context.config().prefix().as_str(), "")
     );
 
-    let now = Instant::now();
-    let res = context
+    // Sending initial msg which the bot will then edit as the tokens arrive
+    let init_msg = RoomMessageEventContent::text_plain("...").make_reply_to(
+        ev.clone()
+            .borrow()
+            .as_original()
+            .unwrap()
+            .clone()
+            .into_full_event(room.room_id().to_owned())
+            .borrow(),
+        matrix_sdk::ruma::events::room::message::ForwardThread::No,
+        matrix_sdk::ruma::events::room::message::AddMentions::No,
+    );
+    let init_id = room.send(init_msg).await.unwrap();
+
+    // Ollama stream
+    // https://github.com/pepperoni21/ollama-rs/issues/41 no history yet
+    let mut stream = context
         .ai()
-        .generate(GenerationRequest::new(
-            context.config().ollama_model().clone(),
+        .generate_stream(GenerationRequest::new(
+            context.config().ollama_model().clone().to_owned(),
             prompt,
         ))
         .await
         .unwrap();
 
-    let mut msg = res.response.to_string();
-    msg.push_str(format!("\n prompt took {:?}", now.elapsed(),).as_str());
+    // Constructing the tokens into a whole string
+    let mut response = String::new();
+    while let Some(res) = stream.next().await {
+        let responses = res.unwrap();
+        for resp in responses {
+            response += &resp.response;
+
+            // Replacing old msg
+            if context.config().enable_streaming().to_owned() {
+                let replacement_msg = RoomMessageEventContent::text_plain(response.clone())
+                    .make_replacement(
+                        ReplacementMetadata::new(init_id.event_id.clone(), None),
+                        None,
+                    );
+                room.send(replacement_msg).await;
+            }
+        }
+    }
+
+    if !context.config().enable_streaming().to_owned() {
+        let replacement_msg = RoomMessageEventContent::text_plain(response.clone())
+            .make_replacement(
+                ReplacementMetadata::new(init_id.event_id.clone(), None),
+                None,
+            );
+        room.send(replacement_msg).await;
+    }
 
-    let content = RoomMessageEventContent::text_plain(msg);
-    room.send(content).await.unwrap();
     room.typing_notice(false).await;
 }
diff --git a/src/commands/command.rs.template b/src/commands/command.rs.template
new file mode 100644
index 0000000..af2fc83
--- /dev/null
+++ b/src/commands/command.rs.template
@@ -0,0 +1,7 @@
+use matrix_sdk::{event_handler::Ctx, ruma::events::room::message::SyncRoomMessageEvent, Room};
+
+use crate::structs::BotContext;
+
+pub async fn get_pfp(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
+    log::debug!("something");
+}
diff --git a/src/commands/mod.rs b/src/commands/mod.rs
index 85f1d58..125b128 100644
--- a/src/commands/mod.rs
+++ b/src/commands/mod.rs
@@ -1,2 +1,3 @@
 pub mod chatbot;
 pub mod model;
+pub mod pfp;
diff --git a/src/commands/pfp.rs b/src/commands/pfp.rs
new file mode 100644
index 0000000..9b7d439
--- /dev/null
+++ b/src/commands/pfp.rs
@@ -0,0 +1,19 @@
+use matrix_sdk::{
+    event_handler::Ctx,
+    ruma::{
+        api::client::{profile, search::search_events::v3::UserProfile},
+        events::room::message::SyncRoomMessageEvent,
+    },
+    Room,
+};
+
+use crate::structs::BotContext;
+
+pub async fn get_pfp(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
+    let asd = ev.as_original().unwrap();
+    let mentions = asd.clone().content.mentions.unwrap().user_ids;
+    // TODO get pfp from client
+    let user = profile::get_profile::v3::Request::new(mentions.first().unwrap().to_owned());
+
+    log::debug!("stuff: {:?}", mentions);
+}
diff --git a/src/events/message.rs b/src/events/message.rs
index 6f8009d..cd79da4 100644
--- a/src/events/message.rs
+++ b/src/events/message.rs
@@ -1,6 +1,9 @@
 use matrix_sdk::{event_handler::Ctx, ruma::events::room::message::SyncRoomMessageEvent, Room};
 
-use crate::{commands::chatbot::chat, structs::BotContext};
+use crate::{
+    commands::{chatbot::chat, pfp::get_pfp},
+    structs::BotContext,
+};
 
 pub async fn process_message(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
     log::debug!("Processing message");
@@ -20,7 +23,7 @@ pub async fn process_message(ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>) {
     if msg.contains(context.config().prefix()) {
         let parameters: Vec<&str> = msg.split(" ").collect();
         match parameters[0] {
-            //"model" =>
+            "pfp" => get_pfp(ev, room, context).await,
             _ => chat(ev, room, context).await,
         };
     }
diff --git a/src/main.rs b/src/main.rs
index 70e05ef..f20cf6b 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -35,16 +35,17 @@ async fn login_and_sync(config: Config) -> anyhow::Result<()> {
     client
         .matrix_auth()
         .login_username(config.username(), config.password())
-        .initial_device_display_name("getting started bot")
+        .initial_device_display_name("celestial")
         .await?;
 
-    let ollama = Ollama::new(
+    let ollama_history = Ollama::new_with_history(
         config.ollama_host().clone().to_owned(),
         config.ollama_port().clone().to_owned() as u16,
+        99,
     );
 
     let ctx = BotContextBuilder::default()
-        .ai(ollama)
+        .ai(ollama_history)
         .bot_id(client.clone().user_id().unwrap().to_string())
         .config(config)
         .build()
@@ -53,15 +54,9 @@ async fn login_and_sync(config: Config) -> anyhow::Result<()> {
     log::info!("adding handlers");
 
     client.add_event_handler(on_stripped_state_member);
-    // TODO this is something funny, i dont get why it doenst work this way
-    //client.add_event_handler(process_message);
     client.add_event_handler(
         |ev: SyncRoomMessageEvent, room: Room, context: Ctx<BotContext>| async move {
             process_message(ev.clone(), room, context).await;
-            log::debug!(
-                "received message: {}",
-                ev.as_original().unwrap().content.body().to_string()
-            );
         },
     );
 
diff --git a/src/structs/mod.rs b/src/structs/mod.rs
index 2f88cac..c6e3d2f 100644
--- a/src/structs/mod.rs
+++ b/src/structs/mod.rs
@@ -1,5 +1,6 @@
 use derive_builder::Builder;
 use derive_getters::Getters;
+use matrix_sdk::Client;
 use ollama_rs::Ollama;
 use serde::Deserialize;
 
@@ -15,6 +16,7 @@ pub struct Config {
     // bot settings
     prefix: String,
     enable_dms: bool,
+    enable_streaming: bool,
 }
 
 #[derive(Debug, Clone, Getters, Builder)]
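
---

A possible follow-up for the "excessive editing" problem noted in .env.example:
throttle the edits instead of sending one m.replace per token. Below is a minimal,
self-contained sketch, not part of this patch — `send_edit` is a hypothetical
stand-in for building the ReplacementMetadata edit and calling room.send(), the
token stream is a plain iterator rather than the ollama-rs generate_stream, and
the 500 ms interval is an arbitrary choice:

    use std::time::Duration;
    use tokio::time::Instant;
    use tokio_stream::StreamExt;

    // Hypothetical stand-in for sending an m.replace edit to the room.
    async fn send_edit(body: &str) {
        println!("edit -> {body}");
    }

    #[tokio::main]
    async fn main() {
        // Stand-in token stream; the bot would use generate_stream() here.
        let mut tokens = tokio_stream::iter(["Hel", "lo", ", ", "wor", "ld", "!"]);

        let mut response = String::new();
        let mut last_edit = Instant::now();
        let min_gap = Duration::from_millis(500); // arbitrary; tune as needed

        while let Some(tok) = tokens.next().await {
            response.push_str(tok);
            // Edit at most once per min_gap instead of once per token.
            if last_edit.elapsed() >= min_gap {
                send_edit(&response).await;
                last_edit = Instant::now();
            }
        }
        // Always flush the complete final text, since the last tokens may
        // have arrived inside the throttle window.
        send_edit(&response).await;
    }

This keeps the streaming feel while bounding the number of edit events (and
client notifications) per response.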