Skip to content

Commit

Permalink
chore: add test
Browse files Browse the repository at this point in the history
  • Loading branch information
appflowy committed Jul 31, 2024
1 parent e31a029 commit 7dd879a
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@ use appflowy_plugin::core::parser::{DefaultResponseParser, ResponseParser};
use appflowy_plugin::core::plugin::Plugin;
use appflowy_plugin::error::{PluginError, RemoteError};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Weak;
use tokio_stream::wrappers::ReceiverStream;
use tracing::instrument;

pub struct ChatPluginOperation {
pub struct AIPluginOperation {
plugin: Weak<Plugin>,
}

impl ChatPluginOperation {
impl AIPluginOperation {
pub fn new(plugin: Weak<Plugin>) -> Self {
ChatPluginOperation { plugin }
AIPluginOperation { plugin }
}

fn get_plugin(&self) -> Result<std::sync::Arc<Plugin>, PluginError> {
Expand Down Expand Up @@ -118,6 +120,43 @@ impl ChatPluginOperation {
});
plugin.stream_request::<ChatStreamResponseParser>("handle", &params)
}

#[instrument(level = "debug", skip(self), err)]
pub async fn summary_row(&self, row: HashMap<String, String>) -> Result<String, PluginError> {
let params = json!({"params": row });
self
.send_request::<DatabaseSummaryResponseParser>("database_summary", params)
.await
}

#[instrument(level = "debug", skip(self), err)]
pub async fn translate_row(
&self,
data: LocalAITranslateRowData,
) -> Result<LocalAITranslateRowResponse, PluginError> {
let params = json!({"params": data });
self
.send_request::<DatabaseTranslateResponseParser>("database_translate", params)
.await
}
}

#[derive(Clone, Debug, Serialize)]
pub struct LocalAITranslateRowData {
pub cells: Vec<LocalAITranslateItem>,
pub language: String,
pub include_header: bool,
}

#[derive(Clone, Debug, Serialize)]
pub struct LocalAITranslateItem {
pub title: String,
pub content: String,
}

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct LocalAITranslateRowResponse {
pub items: Vec<HashMap<String, String>>,
}

pub struct ChatResponseParser;
Expand Down Expand Up @@ -189,3 +228,28 @@ impl From<u8> for CompleteTextType {
}
}
}

pub struct DatabaseSummaryResponseParser;
impl ResponseParser for DatabaseSummaryResponseParser {
type ValueType = String;

fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
json
.get("data")
.and_then(|data| data.as_str())
.map(|s| s.to_string())
.ok_or(RemoteError::ParseResponse(json))
}
}

pub struct DatabaseTranslateResponseParser;
impl ResponseParser for DatabaseTranslateResponseParser {
type ValueType = LocalAITranslateRowResponse;

fn parse_json(json: JsonValue) -> Result<Self::ValueType, RemoteError> {
json
.get("data")
.and_then(|data| LocalAITranslateRowResponse::deserialize(data).ok())
.ok_or(RemoteError::ParseResponse(json))
}
}
63 changes: 45 additions & 18 deletions appflowy-local-ai/src/chat_plugin.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::chat_ops::{ChatPluginOperation, CompleteTextType};
use crate::ai_ops::{
AIPluginOperation, CompleteTextType, LocalAITranslateRowData, LocalAITranslateRowResponse,
};
use anyhow::{anyhow, Result};
use appflowy_plugin::core::plugin::{
Plugin, PluginInfo, RunningState, RunningStateReceiver, RunningStateSender,
Expand All @@ -8,6 +10,7 @@ use appflowy_plugin::manager::PluginManager;
use appflowy_plugin::util::{get_operating_system, OperatingSystem};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::path::PathBuf;
use std::sync::{Arc, Weak};
Expand All @@ -34,7 +37,7 @@ impl LocalLLMSetting {
}
}

pub struct LocalChatLLMChat {
pub struct AppFlowyLocalAI {
plugin_manager: Arc<PluginManager>,
plugin_config: RwLock<Option<AIPluginConfig>>,
running_state: RunningStateSender,
Expand All @@ -43,7 +46,7 @@ pub struct LocalChatLLMChat {
running_state_rx: RunningStateReceiver,
}

impl LocalChatLLMChat {
impl AppFlowyLocalAI {
pub fn new(plugin_manager: Arc<PluginManager>) -> Self {
let (running_state, rx) = tokio::sync::watch::channel(RunningState::Connecting);
Self {
Expand All @@ -67,8 +70,8 @@ impl LocalChatLLMChat {
trace!("[AI Plugin] create chat: {}", chat_id);
self.wait_until_plugin_ready().await?;

let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
operation.create_chat(chat_id, true).await?;
Ok(())
}
Expand All @@ -84,8 +87,8 @@ impl LocalChatLLMChat {
/// A `Result<()>` indicating success or failure.
pub async fn close_chat(&self, chat_id: &str) -> Result<()> {
trace!("[AI Plugin] close chat: {}", chat_id);
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
operation.close_chat(chat_id).await?;
Ok(())
}
Expand Down Expand Up @@ -115,16 +118,16 @@ impl LocalChatLLMChat {
) -> Result<ReceiverStream<anyhow::Result<Bytes, PluginError>>, PluginError> {
trace!("[AI Plugin] ask question: {}", message);
self.wait_until_plugin_ready().await?;
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
let stream = operation.stream_message(chat_id, message, true).await?;
Ok(stream)
}

pub async fn get_related_question(&self, chat_id: &str) -> Result<Vec<String>, PluginError> {
self.wait_until_plugin_ready().await?;
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
let values = operation.get_related_questions(chat_id).await?;
Ok(values)
}
Expand All @@ -143,8 +146,8 @@ impl LocalChatLLMChat {
)))?;

self.wait_until_plugin_ready().await?;
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
trace!("[AI Plugin] indexing file: {}", file_path);
operation.index_file(chat_id, file_path).await?;
Ok(())
Expand All @@ -162,8 +165,8 @@ impl LocalChatLLMChat {
/// A `Result<String>` containing the generated answer.
pub async fn ask_question(&self, chat_id: &str, message: &str) -> Result<String, PluginError> {
self.wait_until_plugin_ready().await?;
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
let answer = operation.send_message(chat_id, message, true).await?;
Ok(answer)
}
Expand All @@ -187,12 +190,36 @@ impl LocalChatLLMChat {
) -> Result<ReceiverStream<anyhow::Result<Bytes, PluginError>>, PluginError> {
trace!("[AI Plugin] complete text: {}", message);
self.wait_until_plugin_ready().await?;
let plugin = self.get_chat_plugin().await?;
let operation = ChatPluginOperation::new(plugin);
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
let stream = operation.complete_text(message, complete_type).await?;
Ok(stream)
}

pub async fn summary_database_row(
&self,
row: HashMap<String, String>,
) -> Result<String, PluginError> {
trace!("[AI Plugin] summary database row: {:?}", row);
self.wait_until_plugin_ready().await?;
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
let text = operation.summary_row(row).await?;
Ok(text)
}

pub async fn translate_database_row(
&self,
row: LocalAITranslateRowData,
) -> Result<LocalAITranslateRowResponse, PluginError> {
trace!("[AI Plugin] summary database row: {:?}", row);
self.wait_until_plugin_ready().await?;
let plugin = self.get_ai_plugin().await?;
let operation = AIPluginOperation::new(plugin);
let resp = operation.translate_row(row).await?;
Ok(resp)
}

#[instrument(skip_all, err)]
pub async fn init_chat_plugin(&self, config: AIPluginConfig) -> Result<()> {
let state = self.running_state.borrow().clone();
Expand Down Expand Up @@ -320,7 +347,7 @@ impl LocalChatLLMChat {
/// # Returns
///
/// A `Result<Weak<Plugin>>` containing a weak reference to the plugin.
pub async fn get_chat_plugin(&self) -> Result<Weak<Plugin>, PluginError> {
pub async fn get_ai_plugin(&self) -> Result<Weak<Plugin>, PluginError> {
let plugin_id = self
.running_state
.borrow()
Expand Down
2 changes: 1 addition & 1 deletion appflowy-local-ai/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod chat_ops;
pub mod ai_ops;
pub mod chat_plugin;
pub mod embedding_ops;
pub mod embedding_plugin;
Expand Down
Loading

0 comments on commit 7dd879a

Please sign in to comment.