diff --git a/appflowy-local-ai/src/chat_ops.rs b/appflowy-local-ai/src/ai_ops.rs similarity index 72% rename from appflowy-local-ai/src/chat_ops.rs rename to appflowy-local-ai/src/ai_ops.rs index e46c466..1b62736 100644 --- a/appflowy-local-ai/src/chat_ops.rs +++ b/appflowy-local-ai/src/ai_ops.rs @@ -3,20 +3,22 @@ use appflowy_plugin::core::parser::{DefaultResponseParser, ResponseParser}; use appflowy_plugin::core::plugin::Plugin; use appflowy_plugin::error::{PluginError, RemoteError}; use bytes::Bytes; +use serde::{Deserialize, Serialize}; use serde_json::json; use serde_json::Value as JsonValue; +use std::collections::HashMap; use std::fmt::Debug; use std::sync::Weak; use tokio_stream::wrappers::ReceiverStream; use tracing::instrument; -pub struct ChatPluginOperation { +pub struct AIPluginOperation { plugin: Weak, } -impl ChatPluginOperation { +impl AIPluginOperation { pub fn new(plugin: Weak) -> Self { - ChatPluginOperation { plugin } + AIPluginOperation { plugin } } fn get_plugin(&self) -> Result, PluginError> { @@ -118,6 +120,43 @@ impl ChatPluginOperation { }); plugin.stream_request::("handle", ¶ms) } + + #[instrument(level = "debug", skip(self), err)] + pub async fn summary_row(&self, row: HashMap) -> Result { + let params = json!({"params": row }); + self + .send_request::("database_summary", params) + .await + } + + #[instrument(level = "debug", skip(self), err)] + pub async fn translate_row( + &self, + data: LocalAITranslateRowData, + ) -> Result { + let params = json!({"params": data }); + self + .send_request::("database_translate", params) + .await + } +} + +#[derive(Clone, Debug, Serialize)] +pub struct LocalAITranslateRowData { + pub cells: Vec, + pub language: String, + pub include_header: bool, +} + +#[derive(Clone, Debug, Serialize)] +pub struct LocalAITranslateItem { + pub title: String, + pub content: String, +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct LocalAITranslateRowResponse { + pub items: Vec>, } pub struct ChatResponseParser; @@ -189,3 +228,28 @@ impl From for CompleteTextType { } } } + +pub struct DatabaseSummaryResponseParser; +impl ResponseParser for DatabaseSummaryResponseParser { + type ValueType = String; + + fn parse_json(json: JsonValue) -> Result { + json + .get("data") + .and_then(|data| data.as_str()) + .map(|s| s.to_string()) + .ok_or(RemoteError::ParseResponse(json)) + } +} + +pub struct DatabaseTranslateResponseParser; +impl ResponseParser for DatabaseTranslateResponseParser { + type ValueType = LocalAITranslateRowResponse; + + fn parse_json(json: JsonValue) -> Result { + json + .get("data") + .and_then(|data| LocalAITranslateRowResponse::deserialize(data).ok()) + .ok_or(RemoteError::ParseResponse(json)) + } +} diff --git a/appflowy-local-ai/src/chat_plugin.rs b/appflowy-local-ai/src/chat_plugin.rs index 6768e5b..5cd0a7d 100644 --- a/appflowy-local-ai/src/chat_plugin.rs +++ b/appflowy-local-ai/src/chat_plugin.rs @@ -1,4 +1,6 @@ -use crate::chat_ops::{ChatPluginOperation, CompleteTextType}; +use crate::ai_ops::{ + AIPluginOperation, CompleteTextType, LocalAITranslateRowData, LocalAITranslateRowResponse, +}; use anyhow::{anyhow, Result}; use appflowy_plugin::core::plugin::{ Plugin, PluginInfo, RunningState, RunningStateReceiver, RunningStateSender, @@ -8,6 +10,7 @@ use appflowy_plugin::manager::PluginManager; use appflowy_plugin::util::{get_operating_system, OperatingSystem}; use bytes::Bytes; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fmt::Debug; use std::path::PathBuf; use std::sync::{Arc, Weak}; @@ -34,7 +37,7 @@ impl LocalLLMSetting { } } -pub struct LocalChatLLMChat { +pub struct AppFlowyLocalAI { plugin_manager: Arc, plugin_config: RwLock>, running_state: RunningStateSender, @@ -43,7 +46,7 @@ pub struct LocalChatLLMChat { running_state_rx: RunningStateReceiver, } -impl LocalChatLLMChat { +impl AppFlowyLocalAI { pub fn new(plugin_manager: Arc) -> Self { let (running_state, rx) = tokio::sync::watch::channel(RunningState::Connecting); Self { @@ -67,8 +70,8 @@ impl LocalChatLLMChat { trace!("[AI Plugin] create chat: {}", chat_id); self.wait_until_plugin_ready().await?; - let plugin = self.get_chat_plugin().await?; - let operation = ChatPluginOperation::new(plugin); + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); operation.create_chat(chat_id, true).await?; Ok(()) } @@ -84,8 +87,8 @@ impl LocalChatLLMChat { /// A `Result<()>` indicating success or failure. pub async fn close_chat(&self, chat_id: &str) -> Result<()> { trace!("[AI Plugin] close chat: {}", chat_id); - let plugin = self.get_chat_plugin().await?; - let operation = ChatPluginOperation::new(plugin); + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); operation.close_chat(chat_id).await?; Ok(()) } @@ -115,16 +118,16 @@ impl LocalChatLLMChat { ) -> Result>, PluginError> { trace!("[AI Plugin] ask question: {}", message); self.wait_until_plugin_ready().await?; - let plugin = self.get_chat_plugin().await?; - let operation = ChatPluginOperation::new(plugin); + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); let stream = operation.stream_message(chat_id, message, true).await?; Ok(stream) } pub async fn get_related_question(&self, chat_id: &str) -> Result, PluginError> { self.wait_until_plugin_ready().await?; - let plugin = self.get_chat_plugin().await?; - let operation = ChatPluginOperation::new(plugin); + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); let values = operation.get_related_questions(chat_id).await?; Ok(values) } @@ -143,8 +146,8 @@ impl LocalChatLLMChat { )))?; self.wait_until_plugin_ready().await?; - let plugin = self.get_chat_plugin().await?; - let operation = ChatPluginOperation::new(plugin); + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); trace!("[AI Plugin] indexing file: {}", file_path); operation.index_file(chat_id, file_path).await?; Ok(()) @@ -162,8 +165,8 @@ impl LocalChatLLMChat { /// A `Result` containing the generated answer. pub async fn ask_question(&self, chat_id: &str, message: &str) -> Result { self.wait_until_plugin_ready().await?; - let plugin = self.get_chat_plugin().await?; - let operation = ChatPluginOperation::new(plugin); + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); let answer = operation.send_message(chat_id, message, true).await?; Ok(answer) } @@ -187,12 +190,36 @@ impl LocalChatLLMChat { ) -> Result>, PluginError> { trace!("[AI Plugin] complete text: {}", message); self.wait_until_plugin_ready().await?; - let plugin = self.get_chat_plugin().await?; - let operation = ChatPluginOperation::new(plugin); + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); let stream = operation.complete_text(message, complete_type).await?; Ok(stream) } + pub async fn summary_database_row( + &self, + row: HashMap, + ) -> Result { + trace!("[AI Plugin] summary database row: {:?}", row); + self.wait_until_plugin_ready().await?; + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); + let text = operation.summary_row(row).await?; + Ok(text) + } + + pub async fn translate_database_row( + &self, + row: LocalAITranslateRowData, + ) -> Result { + trace!("[AI Plugin] summary database row: {:?}", row); + self.wait_until_plugin_ready().await?; + let plugin = self.get_ai_plugin().await?; + let operation = AIPluginOperation::new(plugin); + let resp = operation.translate_row(row).await?; + Ok(resp) + } + #[instrument(skip_all, err)] pub async fn init_chat_plugin(&self, config: AIPluginConfig) -> Result<()> { let state = self.running_state.borrow().clone(); @@ -320,7 +347,7 @@ impl LocalChatLLMChat { /// # Returns /// /// A `Result>` containing a weak reference to the plugin. - pub async fn get_chat_plugin(&self) -> Result, PluginError> { + pub async fn get_ai_plugin(&self) -> Result, PluginError> { let plugin_id = self .running_state .borrow() diff --git a/appflowy-local-ai/src/lib.rs b/appflowy-local-ai/src/lib.rs index f756478..42857f7 100644 --- a/appflowy-local-ai/src/lib.rs +++ b/appflowy-local-ai/src/lib.rs @@ -1,4 +1,4 @@ -pub mod chat_ops; +pub mod ai_ops; pub mod chat_plugin; pub mod embedding_ops; pub mod embedding_plugin; diff --git a/appflowy-local-ai/tests/chat_test/mod.rs b/appflowy-local-ai/tests/chat_test/mod.rs index 4636676..6f84c05 100644 --- a/appflowy-local-ai/tests/chat_test/mod.rs +++ b/appflowy-local-ai/tests/chat_test/mod.rs @@ -1,8 +1,9 @@ use crate::util::{get_asset_path, setup_log, LocalAITest}; -use appflowy_local_ai::chat_plugin::{AIPluginConfig, LocalChatLLMChat}; +use appflowy_local_ai::chat_plugin::{AIPluginConfig, AppFlowyLocalAI}; use appflowy_local_ai::plugin_request::download_plugin; +use std::collections::HashMap; -use appflowy_local_ai::chat_ops::CompleteTextType; +use appflowy_local_ai::ai_ops::{CompleteTextType, LocalAITranslateItem, LocalAITranslateRowData}; use appflowy_plugin::manager::PluginManager; use std::env::temp_dir; use std::path::PathBuf; @@ -33,8 +34,8 @@ async fn ci_chat_stream_test() { test.init_embedding_plugin().await; let chat_plugin = test - .chat_manager - .get_chat_plugin() + .local_ai + .get_ai_plugin() .await .unwrap() .upgrade() @@ -60,11 +61,7 @@ async fn ci_chat_stream_test() { let score = test.calculate_similarity(&answer, expected).await; assert!(score > 0.7, "score: {}", score); - let questions = test - .chat_manager - .get_related_question(&chat_id) - .await - .unwrap(); + let questions = test.local_ai.get_related_question(&chat_id).await.unwrap(); assert_eq!(questions.len(), 3); println!("related questions: {:?}", questions) } @@ -76,8 +73,8 @@ async fn ci_completion_text_test() { test.init_embedding_plugin().await; let chat_plugin = test - .chat_manager - .get_chat_plugin() + .local_ai + .get_ai_plugin() .await .unwrap() .upgrade() @@ -90,7 +87,7 @@ async fn ci_completion_text_test() { }); let mut resp = test - .chat_manager + .local_ai .complete_text("tell me the book, atomic habits", CompleteTextType::AskAI) .await .unwrap(); @@ -114,10 +111,10 @@ async fn ci_chat_with_pdf() { test.init_embedding_plugin().await; let chat_id = uuid::Uuid::new_v4().to_string(); let pdf = get_asset_path("AppFlowy_Values.pdf"); - test.chat_manager.index_file(&chat_id, pdf).await.unwrap(); + test.local_ai.index_file(&chat_id, pdf).await.unwrap(); let resp = test - .chat_manager + .local_ai .ask_question( &chat_id, // "what is the meaning of Aim High and Iterate in AppFlowy?", @@ -139,11 +136,74 @@ async fn ci_chat_with_pdf() { assert!(score > 0.6, "score: {}", score); } +#[tokio::test] +async fn ci_database_row_test() { + let test = LocalAITest::new().unwrap(); + test.init_chat_plugin().await; + test.init_embedding_plugin().await; + + // summary + let mut params = HashMap::new(); + params.insert("book name".to_string(), "Atomic Habits".to_string()); + params.insert("finish reading at".to_string(), "2023-02-10".to_string()); + params.insert( + "notes".to_string(), + "An atomic habit is a regular practice or routine that is not + only small and easy to do but is also the source of incredible power; a + component of the system of compound growth. Bad habits repeat themselves + again and again not because you don’t want to change, but because you + have the wrong system for change. Changes that seem small and + unimportant at first will compound into remarkable results if you’re + willing to stick with them for years" + .to_string(), + ); + let resp = test.local_ai.summary_database_row(params).await.unwrap(); + let expected = r#" + Finished reading "Atomic Habits" on 2023-02-10. The book emphasizes that + small, regular practices can lead to significant growth over time. Bad + habits persist due to flawed systems, and minor, consistent changes can + yield impressive results when maintained over the long term. + "#; + let score = test.calculate_similarity(&resp, expected).await; + assert!(score > 0.8, "score: {}", score); + + // translate + let data = LocalAITranslateRowData { + cells: vec![ + LocalAITranslateItem { + title: "book name".to_string(), + content: "Atomic Habits".to_string(), + }, + LocalAITranslateItem { + title: "score".to_string(), + content: "8".to_string(), + }, + LocalAITranslateItem { + title: "finish reading at".to_string(), + content: "2023-02-10".to_string(), + }, + ], + language: "chinese".to_string(), + include_header: false, + }; + let resp = test.local_ai.translate_database_row(data).await.unwrap(); + let resp_str: String = resp + .items + .into_iter() + .flat_map(|map| map.into_iter().map(|(k, v)| format!("{}:{}", k, v))) + .collect::>() + .join(","); + + let expected = r#"书名:原子习惯,评分:8,完成阅读日期:2023-02-10"#; + let score = test.calculate_similarity(&resp_str, expected).await; + assert!(score > 0.8, "score: {}, actural: {}", score, resp_str); +} + #[tokio::test] async fn load_aws_chat_bin_test() { setup_log(); let plugin_manager = PluginManager::new(); - let llm_chat = LocalChatLLMChat::new(Arc::new(plugin_manager)); + let llm_chat = AppFlowyLocalAI::new(Arc::new(plugin_manager)); let chat_bin = chat_bin_path().await; // clear_extended_attributes(&chat_bin).await.unwrap(); @@ -163,7 +223,6 @@ async fn load_aws_chat_bin_test() { async fn chat_bin_path() -> PathBuf { let url = "https://appflowy-local-ai.s3.amazonaws.com/macos-latest/AppFlowyAI_release.zip?AWSAccessKeyId=AKIAVQA4ULIFKSXHI6PI&Signature=p8evDjdypl58nbGK8qJ%2F1l0Zs%2FU%3D&Expires=1721044152"; - // let url = ""; let temp_dir = temp_dir().join("download_plugin"); if !temp_dir.exists() { std::fs::create_dir(&temp_dir).unwrap(); diff --git a/appflowy-local-ai/tests/util.rs b/appflowy-local-ai/tests/util.rs index 197df60..4af3d26 100644 --- a/appflowy-local-ai/tests/util.rs +++ b/appflowy-local-ai/tests/util.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use appflowy_local_ai::chat_plugin::{AIPluginConfig, LocalChatLLMChat}; +use appflowy_local_ai::chat_plugin::{AIPluginConfig, AppFlowyLocalAI}; use appflowy_local_ai::embedding_plugin::{EmbeddingPluginConfig, LocalEmbedding}; use appflowy_plugin::error::PluginError; use appflowy_plugin::manager::PluginManager; @@ -15,7 +15,7 @@ use tracing_subscriber::EnvFilter; pub struct LocalAITest { config: LocalAIConfiguration, - pub chat_manager: LocalChatLLMChat, + pub local_ai: AppFlowyLocalAI, pub embedding_manager: LocalEmbedding, } @@ -23,11 +23,11 @@ impl LocalAITest { pub fn new() -> Result { let config = LocalAIConfiguration::new()?; let sidecar = Arc::new(PluginManager::new()); - let chat_manager = LocalChatLLMChat::new(sidecar.clone()); + let chat_manager = AppFlowyLocalAI::new(sidecar.clone()); let embedding_manager = LocalEmbedding::new(sidecar); Ok(Self { config, - chat_manager, + local_ai: chat_manager, embedding_manager, }) } @@ -49,7 +49,7 @@ impl LocalAITest { .set_rag_enabled(&self.config.embedding_model_absolute_path(), &persist_dir) .unwrap(); - self.chat_manager.init_chat_plugin(config).await.unwrap(); + self.local_ai.init_chat_plugin(config).await.unwrap(); } pub async fn init_embedding_plugin(&self) { @@ -68,11 +68,7 @@ impl LocalAITest { } pub async fn send_chat_message(&self, chat_id: &str, message: &str) -> String { - self - .chat_manager - .ask_question(chat_id, message) - .await - .unwrap() + self.local_ai.ask_question(chat_id, message).await.unwrap() } pub async fn stream_chat_message( @@ -81,7 +77,7 @@ impl LocalAITest { message: &str, ) -> ReceiverStream> { self - .chat_manager + .local_ai .stream_question(chat_id, message) .await .unwrap()