chore: support chat with pdf
appflowy committed Jul 8, 2024
1 parent 05749c1 commit 0820a0d
Showing 6 changed files with 116 additions and 34 deletions.
10 changes: 10 additions & 0 deletions appflowy-local-ai/src/chat_plugin.rs
@@ -83,6 +83,16 @@ impl ChatPluginOperation {
)
.await
}

pub async fn index_file(&self, chat_id: &str, file_path: &str) -> Result<(), PluginError> {
let params = json!({ "file_path": file_path, "metadatas": [{"chat_id": chat_id}] });
self
.send_request::<DefaultResponseParser>(
"index_file",
json!({ "chat_id": chat_id, "params": params }),
)
.await
}
}

pub struct ChatResponseParser;
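
For context, a minimal usage sketch of the new index_file operation follows. It is not part of this commit; the module paths, file path, and chat id are assumptions.

// Hypothetical caller of the new index_file operation (illustration only).
// It forwards the request shape the plugin expects:
//   { "chat_id": ..., "params": { "file_path": ..., "metadatas": [{ "chat_id": ... }] } }
use appflowy_local_ai::chat_plugin::ChatPluginOperation; // assumed module path
use appflowy_plugin::error::PluginError;                 // assumed module path

async fn index_pdf(op: &ChatPluginOperation, chat_id: &str) -> Result<(), PluginError> {
  op.index_file(chat_id, "/path/to/AppFlowy_Values.pdf").await
}
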
70 changes: 56 additions & 14 deletions appflowy-local-ai/src/llm_chat.rs
@@ -10,6 +10,7 @@ use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::{Arc, Weak};
use std::time::Duration;
use tokio::io;
use tokio::sync::RwLock;
use tokio::time::timeout;
use tokio_stream::wrappers::ReceiverStream;
@@ -104,7 +105,7 @@ impl LocalChatLLMChat {
/// # Returns
///
/// A `Result<ReceiverStream<anyhow::Result<Bytes, SidecarError>>>` containing a stream of responses.
pub async fn ask_question(
pub async fn stream_question(
&self,
chat_id: &str,
message: &str,
@@ -125,6 +126,27 @@ impl LocalChatLLMChat {
Ok(values)
}

pub async fn index_file(&self, chat_id: &str, file_path: PathBuf) -> Result<(), PluginError> {
if !file_path.exists() {
return Err(PluginError::Io(io::Error::new(
io::ErrorKind::NotFound,
"file not found",
)));
}

let file_path = file_path.to_str().ok_or(PluginError::Io(io::Error::new(
io::ErrorKind::NotFound,
"file path invalid",
)))?;

self.wait_until_plugin_ready().await?;
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
trace!("[Chat Plugin] indexing file: {}", file_path);
operation.index_file(chat_id, file_path).await?;
Ok(())
}

/// Generates a complete answer for a given message.
///
/// # Arguments
@@ -135,7 +157,7 @@ impl LocalChatLLMChat {
/// # Returns
///
/// A `Result<String>` containing the generated answer.
pub async fn generate_answer(&self, chat_id: &str, message: &str) -> Result<String, PluginError> {
pub async fn ask_question(&self, chat_id: &str, message: &str) -> Result<String, PluginError> {
self.wait_until_plugin_ready().await?;
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
@@ -226,10 +248,13 @@ impl LocalChatLLMChat {
params["absolute_related_model_path"] = serde_json::json!(related_model_path);
}

if let Some(embedding_model_path) = config.embedding_model_path.clone() {
if let (Some(embedding_model_path), Some(persist_directory)) = (
config.embedding_model_path.clone(),
config.persist_directory.clone(),
) {
params["vectorstore_config"] = serde_json::json!({
"absolute_model_path": embedding_model_path,
"persist_directory": "./",
"persist_directory": persist_directory,
});
}

@@ -303,9 +328,9 @@ pub struct ChatPluginConfig {
chat_model_path: PathBuf,
related_model_path: Option<PathBuf>,
embedding_model_path: Option<PathBuf>,
persist_directory: Option<PathBuf>,
device: String,
verbose: bool,
rag_enabled: bool,
}

impl ChatPluginConfig {
@@ -338,9 +363,9 @@ impl ChatPluginConfig {
chat_model_path,
related_model_path: None,
embedding_model_path: None,
persist_directory: None,
device: "cpu".to_string(),
verbose: false,
rag_enabled: false,
})
}

@@ -353,18 +378,35 @@ impl ChatPluginConfig {
self.verbose = verbose;
self
}
pub fn with_rag_enabled(mut self, rag_enabled: bool) -> Self {
self.rag_enabled = rag_enabled;
self
pub fn with_rag_enabled(
mut self,
embedding_model_path: PathBuf,
persist_directory: PathBuf,
) -> Result<Self> {
if !embedding_model_path.exists() {
return Err(anyhow!(
"embedding model path does not exist: {:?}",
embedding_model_path
));
}
if !embedding_model_path.is_file() {
return Err(anyhow!(
"embedding model is not a file: {:?}",
embedding_model_path
));
}

if !persist_directory.exists() {
std::fs::create_dir_all(&persist_directory)?;
}

self.embedding_model_path = Some(embedding_model_path);
self.persist_directory = Some(persist_directory);
Ok(self)
}

pub fn with_related_model_path<T: Into<PathBuf>>(mut self, related_model_path: T) -> Self {
self.related_model_path = Some(related_model_path.into());
self
}

pub fn with_embedding_model_path<T: Into<PathBuf>>(mut self, embedding_model_path: T) -> Self {
self.embedding_model_path = Some(embedding_model_path.into());
self
}
}
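
The builder change above replaces the boolean rag_enabled flag with with_rag_enabled(embedding_model_path, persist_directory), which validates the embedding model file and creates the persist directory if needed. A hedged sketch of the new call site, with placeholder paths and an assumed module path:

// Illustrative only: enabling RAG on an existing ChatPluginConfig after this change.
use std::path::PathBuf;
use appflowy_local_ai::llm_chat::ChatPluginConfig; // assumed module path

fn enable_rag(config: ChatPluginConfig) -> anyhow::Result<ChatPluginConfig> {
  // Fails if the embedding model file is missing or not a regular file;
  // creates the persist directory when it does not exist yet.
  config.with_rag_enabled(
    PathBuf::from("/models/all-MiniLM-L12-v2.Q4_0.gguf"),
    PathBuf::from("/tmp/appflowy_vectorstore"),
  )
}
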
Binary file added appflowy-local-ai/tests/asset/AppFlowy_Values.pdf
40 changes: 31 additions & 9 deletions appflowy-local-ai/tests/chat_test/mod.rs
@@ -1,5 +1,4 @@
use crate::util::LocalAITest;
use std::time::Duration;
use crate::util::{get_asset_path, LocalAITest};
use tokio_stream::StreamExt;

#[tokio::test]
@@ -37,13 +36,36 @@ async fn ci_chat_stream_test() {
let expected = r#"banana is a fruit that belongs to the genus _______, which also includes other fruits such as apple and pear. It has several varieties with different shapes, colors, and flavors depending on where it grows. Bananas are typically green or yellow in color and have smooth skin that peels off easily when ripe. They are sweet and juicy, often eaten raw or roasted, and can also be used for cooking and baking. In some cultures, banana is considered a symbol of good luck, fertility, and prosperity. Bananas originated in Southeast Asia, where they were cultivated by early humans thousands of years ago. They are now grown around the world as a major crop, with significant production in many countries including the United States, Brazil, India, and China#"#;
let score = test.calculate_similarity(&answer, expected).await;
assert!(score > 0.7, "score: {}", score);
}

#[tokio::test]
async fn ci_chat_with_pdf() {
let test = LocalAITest::new().unwrap();
test.init_chat_plugin().await;
test.init_embedding_plugin().await;
let chat_id = uuid::Uuid::new_v4().to_string();
let pdf = get_asset_path("AppFlowy_Values.pdf");
test.chat_manager.index_file(&chat_id, pdf).await.unwrap();

let resp = test
.chat_manager
.ask_question(
&chat_id,
// "what is the meaning of Aim High and Iterate in AppFlowy?",
"what is AppFlowy Values?",
)
.await
.unwrap();

// let questions = test
// .chat_manager
// .get_related_question(&chat_id)
// .await
// .unwrap();
// println!("related questions: {:?}", questions);
println!("chat with pdf response: {}", resp);

tokio::time::sleep(Duration::from_secs(5)).await;
let expected = r#"
1. **Mission Driven**: Our mission is to enable everyone to unleash their potential and achieve more with secure workplace tools.
2. **Collaboration**: We pride ourselves on being a great team. We foster collaboration, value diversity and inclusion, and encourage sharing.
3. **Honesty**: We are honest with ourselves. We admit mistakes freely and openly. We provide candid, helpful, timely feedback to colleagues with respect, regardless of their status or whether they disagree with us.
4. **Aim High and Iterate**: We strive for excellence with a growth mindset. We dream big, start small, and move fast. We take smaller steps and ship smaller, simpler features.
5. **Transparency**: We make information about AppFlowy public by default unless there is a compelling reason not to. We are straightforward and kind with ourselves and each other.
"#;
let score = test.calculate_similarity(&resp, expected).await;
assert!(score > 0.8, "score: {}", score);
}
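
The test above drives the flow through the LocalAITest helpers. Stripped of the test harness, the same "chat with pdf" flow against LocalChatLLMChat looks roughly like the sketch below; the module paths, asset path, and question are assumptions, not code from this commit.

// Rough end-to-end sketch: index a PDF, then ask about its contents.
use std::path::PathBuf;
use appflowy_local_ai::llm_chat::LocalChatLLMChat; // assumed module path
use appflowy_plugin::error::PluginError;           // assumed module path

async fn chat_with_pdf(chat: &LocalChatLLMChat, chat_id: &str) -> Result<String, PluginError> {
  // Index the PDF into the chat's vector store for this chat id.
  chat
    .index_file(chat_id, PathBuf::from("tests/asset/AppFlowy_Values.pdf"))
    .await?;
  // Ask a question that should be answered from the indexed document.
  chat.ask_question(chat_id, "What are the AppFlowy values?").await
}
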
23 changes: 16 additions & 7 deletions appflowy-local-ai/tests/util.rs
@@ -6,7 +6,7 @@ use appflowy_plugin::manager::PluginManager;
use bytes::Bytes;
use simsimd::SpatialSimilarity;
use std::f64;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Once};
use tokio_stream::wrappers::ReceiverStream;
use tracing_subscriber::fmt::Subscriber;
@@ -44,7 +44,10 @@ impl LocalAITest {
config = config.with_related_model_path(related_question_model);
}

config = config.with_embedding_model_path(self.config.embedding_model_absolute_path());
let persist_dir = tempfile::tempdir().unwrap().path().to_path_buf();
config = config
.with_rag_enabled(self.config.embedding_model_absolute_path(), persist_dir)
.unwrap();

self.chat_manager.init_chat_plugin(config).await.unwrap();
}
@@ -67,7 +70,7 @@ impl LocalAITest {
pub async fn send_chat_message(&self, chat_id: &str, message: &str) -> String {
self
.chat_manager
.generate_answer(chat_id, message)
.ask_question(chat_id, message)
.await
.unwrap()
}
@@ -79,7 +82,7 @@
) -> ReceiverStream<Result<Bytes, PluginError>> {
self
.chat_manager
.ask_question(chat_id, message)
.stream_question(chat_id, message)
.await
.unwrap()
}
@@ -92,15 +95,15 @@
.unwrap()
}

pub async fn calculate_similarity(&self, message1: &str, message2: &str) -> f64 {
pub async fn calculate_similarity(&self, input: &str, expected: &str) -> f64 {
let left = self
.embedding_manager
.generate_embedding(message1)
.generate_embedding(input)
.await
.unwrap();
let right = self
.embedding_manager
.generate_embedding(message2)
.generate_embedding(expected)
.await
.unwrap();

@@ -185,3 +188,9 @@ pub fn setup_log() {
subscriber.try_init().unwrap();
});
}

pub fn get_asset_path(name: &str) -> PathBuf {
let file = format!("tests/asset/{name}");
let absolute_path = std::env::current_dir().unwrap().join(Path::new(&file));
absolute_path
}
7 changes: 3 additions & 4 deletions dev.env
@@ -1,8 +1,7 @@

CHAT_BIN_PATH=
EMBEDDING_BIN_PATH=
LOCAL_AI_MODEL_DIR=
LOCAL_AI_CHAT_MODEL_NAME=
LOCAL_AI_EMBEDDING_MODEL_NAME=
RUN_ALL_TEST=false
LOCAL_AI_MODEL_DIR='The parent directory of the model files'
LOCAL_AI_CHAT_MODEL_NAME='Meta-Llama-3-8B-Instruct.Q4_0.gguf'
LOCAL_AI_EMBEDDING_MODEL_NAME='all-MiniLM-L12-v2.Q4_0.gguf'
LOCAL_AI_APPLE_DEVICE=cpu
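
The test utilities resolve model files from these variables (util.rs calls embedding_model_absolute_path()). A hedged guess at how LOCAL_AI_MODEL_DIR and LOCAL_AI_EMBEDDING_MODEL_NAME might be combined; the implementation below is illustrative, not the repository's code.

// Illustrative helper: joins the model directory and embedding model name
// from dev.env into an absolute path.
use std::path::PathBuf;

fn embedding_model_absolute_path() -> PathBuf {
  let dir = std::env::var("LOCAL_AI_MODEL_DIR").expect("LOCAL_AI_MODEL_DIR is not set");
  let name = std::env::var("LOCAL_AI_EMBEDDING_MODEL_NAME")
    .expect("LOCAL_AI_EMBEDDING_MODEL_NAME is not set");
  PathBuf::from(dir).join(name)
}
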
