diff --git a/appflowy-local-ai/src/chat_plugin.rs b/appflowy-local-ai/src/chat_plugin.rs
index c38c4ea..1f47b47 100644
--- a/appflowy-local-ai/src/chat_plugin.rs
+++ b/appflowy-local-ai/src/chat_plugin.rs
@@ -83,6 +83,16 @@ impl ChatPluginOperation {
     )
     .await
   }
+
+  pub async fn index_file(&self, chat_id: &str, file_path: &str) -> Result<(), PluginError> {
+    let params = json!({ "file_path": file_path, "metadatas": [{"chat_id": chat_id}] });
+    self
+      .send_request::<DefaultResponseParser>(
+        "index_file",
+        json!({ "chat_id": chat_id, "params": params }),
+      )
+      .await
+  }
 }
 
 pub struct ChatResponseParser;
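For illustration (not part of the patch): the request body that `index_file` assembles above, as a standalone sketch with hypothetical `chat_id` and `file_path` values. The `metadatas` array tags every indexed chunk with the owning chat id.

    use serde_json::json;

    fn main() {
      // Hypothetical values; the real ones come from the caller.
      let (chat_id, file_path) = ("chat-42", "/tmp/AppFlowy_Values.pdf");
      // Mirrors ChatPluginOperation::index_file: the file path plus per-chunk metadata.
      let params = json!({ "file_path": file_path, "metadatas": [{ "chat_id": chat_id }] });
      let request = json!({ "chat_id": chat_id, "params": params });
      println!("{}", serde_json::to_string_pretty(&request).unwrap());
    }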
diff --git a/appflowy-local-ai/src/llm_chat.rs b/appflowy-local-ai/src/llm_chat.rs
index 4c7ba66..907629f 100644
--- a/appflowy-local-ai/src/llm_chat.rs
+++ b/appflowy-local-ai/src/llm_chat.rs
@@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
 use std::path::PathBuf;
 use std::sync::{Arc, Weak};
 use std::time::Duration;
+use tokio::io;
 use tokio::sync::RwLock;
 use tokio::time::timeout;
 use tokio_stream::wrappers::ReceiverStream;
@@ -104,7 +105,7 @@ impl LocalChatLLMChat {
   /// # Returns
   ///
   /// A `Result<ReceiverStream<anyhow::Result<Bytes, SidecarError>>>` containing a stream of responses.
-  pub async fn ask_question(
+  pub async fn stream_question(
     &self,
     chat_id: &str,
     message: &str,
@@ -125,6 +126,27 @@ impl LocalChatLLMChat {
     Ok(values)
   }
 
+  pub async fn index_file(&self, chat_id: &str, file_path: PathBuf) -> Result<(), PluginError> {
+    if !file_path.exists() {
+      return Err(PluginError::Io(io::Error::new(
+        io::ErrorKind::NotFound,
+        "file not found",
+      )));
+    }
+
+    let file_path = file_path.to_str().ok_or(PluginError::Io(io::Error::new(
+      io::ErrorKind::NotFound,
+      "file path invalid",
+    )))?;
+
+    self.wait_until_plugin_ready().await?;
+    let plugin = self.get_chat_plugin().await?;
+    let operation = ChatPluginOperation::new(plugin);
+    trace!("[Chat Plugin] indexing file: {}", file_path);
+    operation.index_file(chat_id, file_path).await?;
+    Ok(())
+  }
+
   /// Generates a complete answer for a given message.
   ///
   /// # Arguments
@@ -135,7 +157,7 @@ impl LocalChatLLMChat {
   /// # Returns
   ///
   /// A `Result<String, PluginError>` containing the generated answer.
-  pub async fn generate_answer(&self, chat_id: &str, message: &str) -> Result<String, PluginError> {
+  pub async fn ask_question(&self, chat_id: &str, message: &str) -> Result<String, PluginError> {
     self.wait_until_plugin_ready().await?;
     let plugin = self.get_chat_plugin().await?;
     let operation = ChatPluginOperation::new(plugin);
@@ -226,10 +248,13 @@ impl LocalChatLLMChat {
       params["absolute_related_model_path"] = serde_json::json!(related_model_path);
     }
 
-    if let Some(embedding_model_path) = config.embedding_model_path.clone() {
+    if let (Some(embedding_model_path), Some(persist_directory)) = (
+      config.embedding_model_path.clone(),
+      config.persist_directory.clone(),
+    ) {
       params["vectorstore_config"] = serde_json::json!({
         "absolute_model_path": embedding_model_path,
-        "persist_directory": "./",
+        "persist_directory": persist_directory,
       });
     }
 
@@ -303,9 +328,9 @@ pub struct ChatPluginConfig {
   chat_model_path: PathBuf,
   related_model_path: Option<PathBuf>,
   embedding_model_path: Option<PathBuf>,
+  persist_directory: Option<PathBuf>,
   device: String,
   verbose: bool,
-  rag_enabled: bool,
 }
 
 impl ChatPluginConfig {
@@ -338,9 +363,9 @@
       chat_model_path,
       related_model_path: None,
       embedding_model_path: None,
+      persist_directory: None,
       device: "cpu".to_string(),
       verbose: false,
-      rag_enabled: false,
     })
   }
@@ -353,18 +378,35 @@
     self.verbose = verbose;
     self
   }
-  pub fn with_rag_enabled(mut self, rag_enabled: bool) -> Self {
-    self.rag_enabled = rag_enabled;
-    self
+  pub fn with_rag_enabled(
+    mut self,
+    embedding_model_path: PathBuf,
+    persist_directory: PathBuf,
+  ) -> Result<Self> {
+    if !embedding_model_path.exists() {
+      return Err(anyhow!(
+        "embedding model path does not exist: {:?}",
+        embedding_model_path
+      ));
+    }
+    if !embedding_model_path.is_file() {
+      return Err(anyhow!(
+        "embedding model is not a file: {:?}",
+        embedding_model_path
+      ));
+    }
+
+    if !persist_directory.exists() {
+      std::fs::create_dir_all(&persist_directory)?;
+    }
+
+    self.embedding_model_path = Some(embedding_model_path);
+    self.persist_directory = Some(persist_directory);
+    Ok(self)
   }
 
   pub fn with_related_model_path<T: Into<PathBuf>>(mut self, related_model_path: T) -> Self {
     self.related_model_path = Some(related_model_path.into());
     self
   }
-
-  pub fn with_embedding_model_path<T: Into<PathBuf>>(mut self, embedding_model_path: T) -> Self {
-    self.embedding_model_path = Some(embedding_model_path.into());
-    self
-  }
 }
diff --git a/appflowy-local-ai/tests/asset/AppFlowy_Values.pdf b/appflowy-local-ai/tests/asset/AppFlowy_Values.pdf
new file mode 100644
index 0000000..7030336
Binary files /dev/null and b/appflowy-local-ai/tests/asset/AppFlowy_Values.pdf differ
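To sketch how the reworked builder is called (assuming `ChatPluginConfig` is in scope; both paths are hypothetical): `with_rag_enabled` now takes the embedding model and the persist directory together, and returns a `Result` because it validates the model file and creates the directory before RAG is enabled.

    use anyhow::Result;
    use std::path::PathBuf;

    fn enable_rag(config: ChatPluginConfig) -> Result<ChatPluginConfig> {
      config.with_rag_enabled(
        PathBuf::from("models/all-MiniLM-L12-v2.Q4_0.gguf"), // must be an existing file
        PathBuf::from("./vectorstore"),                      // created with create_dir_all if missing
      )
    }

This replaces the old `with_rag_enabled(bool)` flag plus separate `with_embedding_model_path(...)` call, so a caller can no longer enable RAG without also supplying a usable vector store location.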
diff --git a/appflowy-local-ai/tests/chat_test/mod.rs b/appflowy-local-ai/tests/chat_test/mod.rs
index 25d6693..0543156 100644
--- a/appflowy-local-ai/tests/chat_test/mod.rs
+++ b/appflowy-local-ai/tests/chat_test/mod.rs
@@ -1,5 +1,4 @@
-use crate::util::LocalAITest;
-use std::time::Duration;
+use crate::util::{get_asset_path, LocalAITest};
 use tokio_stream::StreamExt;
 
 #[tokio::test]
@@ -37,13 +36,36 @@ async fn ci_chat_stream_test() {
   let expected = r#"banana is a fruit that belongs to the genus _______, which also includes other fruits such as apple and pear. It has several varieties with different shapes, colors, and flavors depending on where it grows. Bananas are typically green or yellow in color and have smooth skin that peels off easily when ripe. They are sweet and juicy, often eaten raw or roasted, and can also be used for cooking and baking.
 In some cultures, banana is considered a symbol of good luck, fertility, and prosperity. Bananas originated in Southeast Asia, where they were cultivated by early humans thousands of years ago. They are now grown around the world as a major crop, with significant production in many countries including the United States, Brazil, India, and China#"#;
   let score = test.calculate_similarity(&answer, expected).await;
   assert!(score > 0.7, "score: {}", score);
+}
+
+#[tokio::test]
+async fn ci_chat_with_pdf() {
+  let test = LocalAITest::new().unwrap();
+  test.init_chat_plugin().await;
+  test.init_embedding_plugin().await;
+  let chat_id = uuid::Uuid::new_v4().to_string();
+  let pdf = get_asset_path("AppFlowy_Values.pdf");
+  test.chat_manager.index_file(&chat_id, pdf).await.unwrap();
+
+  let resp = test
+    .chat_manager
+    .ask_question(
+      &chat_id,
+      // "what is the meaning of Aim High and Iterate in AppFlowy?",
+      "what is AppFlowy Values?",
+    )
+    .await
+    .unwrap();
 
-  // let questions = test
-  //   .chat_manager
-  //   .get_related_question(&chat_id)
-  //   .await
-  //   .unwrap();
-  // println!("related questions: {:?}", questions);
+  println!("chat with pdf response: {}", resp);
 
-  tokio::time::sleep(Duration::from_secs(5)).await;
+  let expected = r#"
+1. **Mission Driven**: Our mission is to enable everyone to unleash their potential and achieve more with secure workplace tools.
+2. **Collaboration**: We pride ourselves on being a great team. We foster collaboration, value diversity and inclusion, and encourage sharing.
+3. **Honesty**: We are honest with ourselves. We admit mistakes freely and openly. We provide candid, helpful, timely feedback to colleagues with respect, regardless of their status or whether they disagree with us.
+4. **Aim High and Iterate**: We strive for excellence with a growth mindset. We dream big, start small, and move fast. We take smaller steps and ship smaller, simpler features.
+5. **Transparency**: We make information about AppFlowy public by default unless there is a compelling reason not to. We are straightforward and kind with ourselves and each other.
+"#;
+  let score = test.calculate_similarity(&resp, expected).await;
+  assert!(score > 0.8, "score: {}", score);
 }
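The new test exercises the full RAG path end to end. As a minimal sketch of that flow outside the test harness (assuming an initialized `LocalChatLLMChat` named `chat`, with `LocalChatLLMChat` and `PluginError` from the crate patched above and the `uuid` crate available; the PDF path is the asset added in this patch):

    use std::path::PathBuf;

    async fn index_then_ask(chat: &LocalChatLLMChat) -> Result<String, PluginError> {
      let chat_id = uuid::Uuid::new_v4().to_string();
      // Index the PDF first so its chunks are stored under this chat id...
      chat
        .index_file(&chat_id, PathBuf::from("tests/asset/AppFlowy_Values.pdf"))
        .await?;
      // ...then a question in the same chat can use those chunks as context.
      chat.ask_question(&chat_id, "what is AppFlowy Values?").await
    }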
+"#; + let score = test.calculate_similarity(&resp, expected).await; + assert!(score > 0.8, "score: {}", score); } diff --git a/appflowy-local-ai/tests/util.rs b/appflowy-local-ai/tests/util.rs index d780339..de18d0f 100644 --- a/appflowy-local-ai/tests/util.rs +++ b/appflowy-local-ai/tests/util.rs @@ -6,7 +6,7 @@ use appflowy_plugin::manager::PluginManager; use bytes::Bytes; use simsimd::SpatialSimilarity; use std::f64; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::{Arc, Once}; use tokio_stream::wrappers::ReceiverStream; use tracing_subscriber::fmt::Subscriber; @@ -44,7 +44,10 @@ impl LocalAITest { config = config.with_related_model_path(related_question_model); } - config = config.with_embedding_model_path(self.config.embedding_model_absolute_path()); + let persist_dir = tempfile::tempdir().unwrap().path().to_path_buf(); + config = config + .with_rag_enabled(self.config.embedding_model_absolute_path(), persist_dir) + .unwrap(); self.chat_manager.init_chat_plugin(config).await.unwrap(); } @@ -67,7 +70,7 @@ impl LocalAITest { pub async fn send_chat_message(&self, chat_id: &str, message: &str) -> String { self .chat_manager - .generate_answer(chat_id, message) + .ask_question(chat_id, message) .await .unwrap() } @@ -79,7 +82,7 @@ impl LocalAITest { ) -> ReceiverStream> { self .chat_manager - .ask_question(chat_id, message) + .stream_question(chat_id, message) .await .unwrap() } @@ -92,15 +95,15 @@ impl LocalAITest { .unwrap() } - pub async fn calculate_similarity(&self, message1: &str, message2: &str) -> f64 { + pub async fn calculate_similarity(&self, input: &str, expected: &str) -> f64 { let left = self .embedding_manager - .generate_embedding(message1) + .generate_embedding(input) .await .unwrap(); let right = self .embedding_manager - .generate_embedding(message2) + .generate_embedding(expected) .await .unwrap(); @@ -185,3 +188,9 @@ pub fn setup_log() { subscriber.try_init().unwrap(); }); } + +pub fn get_asset_path(name: &str) -> PathBuf { + let file = format!("tests/asset/{name}"); + let absolute_path = std::env::current_dir().unwrap().join(Path::new(&file)); + absolute_path +} diff --git a/dev.env b/dev.env index 8857630..4aa93bc 100644 --- a/dev.env +++ b/dev.env @@ -1,8 +1,7 @@ CHAT_BIN_PATH= EMBEDDING_BIN_PATH= -LOCAL_AI_MODEL_DIR= -LOCAL_AI_CHAT_MODEL_NAME= -LOCAL_AI_EMBEDDING_MODEL_NAME= -RUN_ALL_TEST=false +LOCAL_AI_MODEL_DIR='The parent directory of the model files' +LOCAL_AI_CHAT_MODEL_NAME='Meta-Llama-3-8B-Instruct.Q4_0.gguf' +LOCAL_AI_EMBEDDING_MODEL_NAME='all-MiniLM-L12-v2.Q4_0.gguf' LOCAL_AI_APPLE_DEVICE=cpu \ No newline at end of file