Skip to content

Commit 8a98e33

Browse files
authored
Struct version of tool & knowledge (#212)
1 parent 6b5b098 commit 8a98e33

File tree

9 files changed

+268
-189
lines changed

9 files changed

+268
-189
lines changed

v2/src/agent/agent.rs

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use futures::{Stream, StreamExt, lock::Mutex};
55

66
use crate::{
77
agent::SystemMessageRenderer,
8-
knowledge::Knowledge,
8+
knowledge::{Knowledge, KnowledgeBehavior as _},
99
model::{InferenceConfig, LangModel, LangModelInference as _},
10-
tool::Tool,
10+
tool::{Tool, ToolBehavior as _},
1111
utils::log,
1212
value::{Delta, FinishReason, Message, MessageDelta, Part, PartDelta, Role},
1313
warn,
@@ -16,9 +16,9 @@ use crate::{
1616
#[derive(Clone)]
1717
pub struct Agent {
1818
lm: LangModel,
19-
tools: Vec<Arc<dyn Tool>>,
19+
tools: Vec<Tool>,
2020
messages: Arc<Mutex<Vec<Message>>>,
21-
knowledge: Option<Arc<dyn Knowledge>>,
21+
knowledge: Option<Knowledge>,
2222
system_message_renderer: Arc<SystemMessageRenderer>,
2323
}
2424

@@ -34,7 +34,7 @@ pub struct AgentResponse {
3434
}
3535

3636
impl Agent {
37-
pub fn new(lm: LangModel, tools: impl IntoIterator<Item = Arc<dyn Tool>>) -> Self {
37+
pub fn new(lm: LangModel, tools: impl IntoIterator<Item = Tool>) -> Self {
3838
Self {
3939
lm,
4040
tools: tools.into_iter().collect(),
@@ -55,11 +55,11 @@ impl Agent {
5555
self.lm.clone()
5656
}
5757

58-
pub fn get_tools(&self) -> Vec<Arc<dyn Tool>> {
58+
pub fn get_tools(&self) -> Vec<Tool> {
5959
self.tools.clone()
6060
}
6161

62-
pub fn knowledge(&self) -> Option<Arc<dyn Knowledge>> {
62+
pub fn knowledge(&self) -> Option<Knowledge> {
6363
self.knowledge.clone()
6464
}
6565

@@ -68,7 +68,7 @@ impl Agent {
6868
Ok(())
6969
}
7070

71-
pub async fn add_tools(&mut self, tools: Vec<Arc<dyn Tool>>) -> anyhow::Result<()> {
71+
pub async fn add_tools(&mut self, tools: Vec<Tool>) -> anyhow::Result<()> {
7272
for tool in tools.iter() {
7373
let tool_name = tool.get_description().name;
7474

@@ -92,7 +92,7 @@ impl Agent {
9292
Ok(())
9393
}
9494

95-
pub async fn add_tool(&mut self, tool: Arc<dyn Tool>) -> anyhow::Result<()> {
95+
pub async fn add_tool(&mut self, tool: Tool) -> anyhow::Result<()> {
9696
self.add_tools(vec![tool]).await
9797
}
9898

@@ -109,17 +109,8 @@ impl Agent {
109109
self.remove_tools(vec![tool_name]).await
110110
}
111111

112-
pub async fn remove_mcp_tools(&mut self, client_name: String) -> anyhow::Result<()> {
113-
self.tools.retain(|t| {
114-
let tool_name = t.get_description().name;
115-
// Remove the MCP tool if its description name is prefixed with the provided client name.
116-
!tool_name.starts_with(format!("{}--", client_name).as_str())
117-
});
118-
Ok(())
119-
}
120-
121-
pub fn set_knowledge(&mut self, knowledge: impl Knowledge + 'static) {
122-
self.knowledge = Some(Arc::new(knowledge));
112+
pub fn set_knowledge(&mut self, knowledge: Knowledge) {
113+
self.knowledge = Some(knowledge);
123114
}
124115

125116
pub fn remove_knowledge(&mut self) {
@@ -140,10 +131,10 @@ impl Agent {
140131
let system_message_content = "You are helpful assistant.".to_string();
141132
let knowledge_results = if let Some(knowledge) = &self.knowledge {
142133
let query = contents.iter().filter(|&c| matches!(c, Part::Text{..})).map(|c| c.as_text().unwrap()).collect::<Vec<_>>().join("\n");
143-
let retrieved = match knowledge.retrieve(query.clone()).await {
134+
let retrieved = match knowledge.retrieve(query.clone(), 1).await {
144135
Ok(retrieved) => retrieved,
145136
Err(e) => {
146-
warn!("Failed to retrieve from knowledge {}: {}", knowledge.name(), e.to_string());
137+
warn!("Failed to retrieve from knowledge: {}", e.to_string());
147138
vec![]
148139
}
149140
};
@@ -319,7 +310,7 @@ mod tests {
319310
// Part::Text("104".to_owned())
320311
// }
321312
// }),
322-
// )) as Arc<dyn Tool>];
313+
// )) as Tool];
323314
// let mut agent = Agent::new(model, tools);
324315

325316
// let mut agg = MessageAggregator::new();

v2/src/knowledge/base.rs

Lines changed: 46 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@ use ailoy_macros::{maybe_send_sync, multi_platform_async_trait};
44
use serde::{Deserialize, Serialize};
55

66
use crate::{
7+
knowledge::{CustomKnowledge, VectorStoreKnowledge},
8+
model::EmbeddingModel,
79
to_value,
8-
tool::Tool,
10+
tool::ToolBehavior,
911
value::{ToolDesc, Value},
12+
vector_store::VectorStore,
1013
};
1114

1215
type Metadata = serde_json::Map<String, serde_json::Value>;
@@ -19,15 +22,17 @@ pub struct KnowledgeRetrieveResult {
1922

2023
#[maybe_send_sync]
2124
#[multi_platform_async_trait]
22-
pub trait Knowledge: std::fmt::Debug {
23-
fn name(&self) -> String;
24-
25-
async fn retrieve(&self, query: String) -> anyhow::Result<Vec<KnowledgeRetrieveResult>>;
25+
pub trait KnowledgeBehavior: std::fmt::Debug {
26+
async fn retrieve(
27+
&self,
28+
query: String,
29+
top_k: u32,
30+
) -> anyhow::Result<Vec<KnowledgeRetrieveResult>>;
2631
}
2732

2833
#[derive(Clone)]
2934
pub struct KnowledgeTool {
30-
inner: Arc<dyn Knowledge>,
35+
inner: Arc<dyn KnowledgeBehavior>,
3136
desc: ToolDesc,
3237
}
3338

@@ -36,15 +41,14 @@ impl std::fmt::Debug for KnowledgeTool {
3641
f.debug_struct("KnowledgeTool")
3742
.field("desc", &self.desc)
3843
.field("inner", &self.inner)
39-
.field("stringify", &"(Function)")
4044
.finish()
4145
}
4246
}
4347

4448
impl KnowledgeTool {
45-
pub fn from(knowledge: impl Knowledge + 'static) -> Self {
49+
pub fn from(knowledge: impl KnowledgeBehavior + 'static) -> Self {
4650
let default_desc = ToolDesc {
47-
name: format!("retrieve-{}", knowledge.name()),
51+
name: "retrieve-from-knowledge".into(),
4852
description: Some("Retrieve the relevant context from knowledge base.".into()),
4953
parameters: to_value!({
5054
"type": "object",
@@ -71,7 +75,7 @@ impl KnowledgeTool {
7175
}
7276

7377
#[multi_platform_async_trait]
74-
impl Tool for KnowledgeTool {
78+
impl ToolBehavior for KnowledgeTool {
7579
fn get_description(&self) -> ToolDesc {
7680
self.desc.clone()
7781
}
@@ -94,7 +98,7 @@ impl Tool for KnowledgeTool {
9498
}
9599
};
96100

97-
let results = match self.inner.retrieve(query.into()).await {
101+
let results = match self.inner.retrieve(query.into(), 1).await {
98102
Ok(results) => results,
99103
Err(e) => {
100104
return Ok(e.to_string().into());
@@ -106,58 +110,44 @@ impl Tool for KnowledgeTool {
106110
}
107111
}
108112

109-
#[cfg(test)]
110-
mod tests {
111-
use ailoy_macros::multi_platform_test;
112-
use futures::stream::StreamExt;
113-
114-
use super::*;
115-
use crate::{agent::Agent, model::LangModel, value::Part};
113+
#[derive(Debug, Clone)]
114+
pub enum KnowledgeInner {
115+
VectorStore(VectorStoreKnowledge),
116+
Custom(CustomKnowledge),
117+
}
116118

117-
#[derive(Debug)]
118-
struct CustomKnowledge {}
119+
#[derive(Debug, Clone)]
120+
pub struct Knowledge {
121+
inner: KnowledgeInner,
122+
}
119123

120-
#[multi_platform_async_trait]
121-
impl Knowledge for CustomKnowledge {
122-
fn name(&self) -> String {
123-
"about-ailoy".into()
124+
impl Knowledge {
125+
pub fn new_vector_store(
126+
store: impl VectorStore + 'static,
127+
embedding_model: EmbeddingModel,
128+
) -> Self {
129+
Self {
130+
inner: KnowledgeInner::VectorStore(VectorStoreKnowledge::new(store, embedding_model)),
124131
}
132+
}
125133

126-
async fn retrieve(&self, _query: String) -> anyhow::Result<Vec<KnowledgeRetrieveResult>> {
127-
let documents = vec![
128-
KnowledgeRetrieveResult {
129-
document: "Ailoy is an awesome AI agent framework.".into(),
130-
metadata: None,
131-
},
132-
KnowledgeRetrieveResult {
133-
document: "Ailoy supports Python, Javascript and Rust.".into(),
134-
metadata: None,
135-
},
136-
KnowledgeRetrieveResult {
137-
document: "Ailoy enables running LLMs in local environment easily.".into(),
138-
metadata: None,
139-
},
140-
];
141-
Ok(documents)
134+
pub fn new_custom(knowledge: CustomKnowledge) -> Self {
135+
Self {
136+
inner: KnowledgeInner::Custom(knowledge),
142137
}
143138
}
139+
}
144140

145-
#[multi_platform_test]
146-
async fn test_custom_knowledge_with_agent() -> anyhow::Result<()> {
147-
let knowledge = CustomKnowledge {};
148-
let model = LangModel::try_new_local("Qwen/Qwen3-0.6B").await.unwrap();
149-
let mut agent = Agent::new(model, vec![]);
150-
151-
agent.set_knowledge(knowledge);
152-
153-
let mut strm = Box::pin(agent.run(vec![Part::Text {
154-
text: "What is Ailoy?".into(),
155-
}]));
156-
while let Some(out) = strm.next().await {
157-
let out = out.unwrap();
158-
println!("{:?}", out);
141+
#[multi_platform_async_trait]
142+
impl KnowledgeBehavior for Knowledge {
143+
async fn retrieve(
144+
&self,
145+
query: String,
146+
top_k: u32,
147+
) -> anyhow::Result<Vec<KnowledgeRetrieveResult>> {
148+
match &self.inner {
149+
KnowledgeInner::VectorStore(knowledge) => knowledge.retrieve(query, top_k).await,
150+
KnowledgeInner::Custom(knowledge) => knowledge.retrieve(query, top_k).await,
159151
}
160-
161-
Ok(())
162152
}
163153
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
use std::{fmt::Debug, sync::Arc};
2+
3+
use ailoy_macros::multi_platform_async_trait;
4+
use futures::future::BoxFuture;
5+
6+
use crate::{
7+
knowledge::{KnowledgeBehavior, KnowledgeRetrieveResult},
8+
utils::{MaybeSend, MaybeSync},
9+
};
10+
11+
#[derive(Clone)]
12+
pub struct CustomKnowledge {
13+
f: Arc<
14+
dyn Fn(String, u32) -> BoxFuture<'static, anyhow::Result<Vec<KnowledgeRetrieveResult>>>
15+
+ MaybeSend
16+
+ MaybeSync,
17+
>,
18+
}
19+
20+
impl Debug for CustomKnowledge {
21+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22+
f.debug_struct("CustomKnowledge")
23+
.field("f", &"function")
24+
.finish()
25+
}
26+
}
27+
28+
impl CustomKnowledge {
29+
pub fn new(
30+
f: Arc<
31+
dyn Fn(String, u32) -> BoxFuture<'static, anyhow::Result<Vec<KnowledgeRetrieveResult>>>
32+
+ MaybeSend
33+
+ MaybeSync,
34+
>,
35+
) -> Self {
36+
Self { f }
37+
}
38+
}
39+
40+
#[multi_platform_async_trait]
41+
impl KnowledgeBehavior for CustomKnowledge {
42+
async fn retrieve(
43+
&self,
44+
query: String,
45+
top_k: u32,
46+
) -> anyhow::Result<Vec<KnowledgeRetrieveResult>> {
47+
(self.f)(query, top_k).await
48+
}
49+
}
50+
51+
#[cfg(test)]
52+
mod tests {
53+
use ailoy_macros::multi_platform_test;
54+
use futures::{FutureExt, stream::StreamExt};
55+
56+
use super::*;
57+
use crate::{agent::Agent, knowledge::Knowledge, model::LangModel, value::Part};
58+
59+
#[multi_platform_test]
60+
async fn test_custom_knowledge_with_agent() -> anyhow::Result<()> {
61+
let knowledge = Knowledge::new_custom(CustomKnowledge::new(Arc::new(|_, _| {
62+
async {
63+
let documents = vec![
64+
KnowledgeRetrieveResult {
65+
document: "Ailoy is an awesome AI agent framework.".into(),
66+
metadata: None,
67+
},
68+
KnowledgeRetrieveResult {
69+
document: "Ailoy supports Python, Javascript and Rust.".into(),
70+
metadata: None,
71+
},
72+
KnowledgeRetrieveResult {
73+
document: "Ailoy enables running LLMs in local environment easily.".into(),
74+
metadata: None,
75+
},
76+
];
77+
Ok(documents)
78+
}
79+
.boxed()
80+
})));
81+
let model = LangModel::try_new_local("Qwen/Qwen3-0.6B").await.unwrap();
82+
let mut agent = Agent::new(model, vec![]);
83+
84+
agent.set_knowledge(knowledge);
85+
86+
let mut strm = Box::pin(agent.run(vec![Part::Text {
87+
text: "What is Ailoy?".into(),
88+
}]));
89+
while let Some(out) = strm.next().await {
90+
let out = out.unwrap();
91+
println!("{:?}", out);
92+
}
93+
94+
Ok(())
95+
}
96+
}

v2/src/knowledge/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
mod base;
2+
mod custom_knowledge;
23
mod vector_store_knowledge;
34

45
pub use base::*;
6+
pub use custom_knowledge::*;
57
pub use vector_store_knowledge::*;

0 commit comments

Comments
 (0)