Skip to content

Commit 842720a

Browse files
committed
change agent type
1 parent ab3cad2 commit 842720a

File tree

5 files changed

+98
-56
lines changed

5 files changed

+98
-56
lines changed

templates/components/agents/python/blog/app/agents/workflow.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33

44
from app.agents.publisher import create_publisher
55
from app.agents.researcher import create_researcher
6-
from app.workflows.single import AgentRunEvent, AgentRunResult, FunctionCallingAgent
6+
from app.workflows.single import (
7+
AgentRunEvent,
8+
AgentRunResult,
9+
AgentTask,
10+
FunctionCallingAgent,
11+
)
712
from llama_index.core.chat_engine.types import ChatMessage
813
from llama_index.core.prompts import PromptTemplate
914
from llama_index.core.settings import Settings
@@ -177,7 +182,9 @@ async def write(
177182
ctx.write_event_to_stream(
178183
AgentRunEvent(
179184
name=writer.name,
180-
msg=f"Too many attempts ({MAX_ATTEMPTS}) to write the blog post. Proceeding with the current version.",
185+
data=AgentTask(
186+
msg=f"Too many attempts ({MAX_ATTEMPTS}) to write the blog post. Proceeding with the current version.",
187+
),
181188
)
182189
)
183190
if ev.is_good or too_many_attempts:
@@ -204,7 +211,9 @@ async def review(
204211
ctx.write_event_to_stream(
205212
AgentRunEvent(
206213
name=reviewer.name,
207-
msg=f"The post is {'not ' if not post_is_good else ''}good enough for publishing. Sending back to the writer{' for publication.' if post_is_good else '.'}",
214+
data=AgentTask(
215+
msg=f"The post is {'not ' if not post_is_good else ''}good enough for publishing. Sending back to the writer{' for publication.' if post_is_good else '.'}",
216+
),
208217
)
209218
)
210219
if post_is_good:
@@ -246,7 +255,9 @@ async def publish(
246255
ctx.write_event_to_stream(
247256
AgentRunEvent(
248257
name=publisher.name,
249-
msg=f"Error publishing: {e}",
258+
data=AgentTask(
259+
msg=f"Error publishing: {e}",
260+
),
250261
)
251262
)
252263
return StopEvent(result=None)

templates/components/agents/python/financial_report/app/agents/workflow.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from app.agents.analyst import create_analyst
55
from app.agents.reporter import create_reporter
66
from app.agents.researcher import create_researcher
7-
from app.workflows.single import AgentRunEvent, AgentRunResult, FunctionCallingAgent
7+
from app.workflows.single import (
8+
AgentRunEvent,
9+
AgentRunResult,
10+
AgentTask,
11+
FunctionCallingAgent,
12+
)
813
from llama_index.core.chat_engine.types import ChatMessage
914
from llama_index.core.prompts import PromptTemplate
1015
from llama_index.core.settings import Settings
@@ -156,7 +161,9 @@ async def report(
156161
ctx.write_event_to_stream(
157162
AgentRunEvent(
158163
name=reporter.name,
159-
msg=f"Error creating a report: {e}",
164+
data=AgentTask(
165+
msg=f"Error creating a report: {e}",
166+
),
160167
)
161168
)
162169
return StopEvent(result=None)

templates/components/agents/python/form_filling/app/engine/workflow.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import json
21
import os
32
import uuid
4-
from enum import Enum
53
from typing import AsyncGenerator, List, Optional
64

75
from app.engine.tools.form_filling import CellValue, MissingCell
@@ -19,7 +17,7 @@
1917
step,
2018
)
2119
from llama_index.llms.openai import OpenAI
22-
from pydantic import Field
20+
from pydantic import BaseModel
2321

2422

2523
class InputEvent(Event):
@@ -39,25 +37,20 @@ class FillEvent(Event):
3937
tool_call: ToolSelection
4038

4139

42-
class AgentRunEventType(Enum):
43-
TEXT = "text"
44-
PROGRESS = "progress"
40+
class ProgressData(BaseModel):
41+
current: int
42+
total: int
43+
44+
45+
class AgentTask(BaseModel):
46+
id: str
47+
msg: str
48+
progress: Optional[ProgressData] = None
4549

4650

4751
class AgentRunEvent(Event):
4852
name: str
49-
msg: str
50-
event_type: AgentRunEventType = Field(default=AgentRunEventType.TEXT)
51-
52-
def to_response(self) -> dict:
53-
return {
54-
"type": "agent",
55-
"data": {
56-
"name": self.name,
57-
"event_type": self.event_type.value,
58-
"msg": self.msg,
59-
},
60-
}
53+
data: AgentTask
6154

6255

6356
class FormFillingWorkflow(Workflow):
@@ -156,7 +149,9 @@ async def extract_missing_cells(
156149
ctx.write_event_to_stream(
157150
AgentRunEvent(
158151
name="Extractor",
159-
msg="Extracting missing cells",
152+
data=AgentTask(
153+
msg="Extracting missing cells",
154+
),
160155
)
161156
)
162157
# Call the extract questions tool
@@ -192,7 +187,9 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
192187
ctx.write_event_to_stream(
193188
AgentRunEvent(
194189
name="Researcher",
195-
msg="Finding answers for missing cells",
190+
data=AgentTask(
191+
msg="Finding answers for missing cells",
192+
),
196193
)
197194
)
198195
missing_cells = ctx.data.get("missing_cells", None)
@@ -203,7 +200,9 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
203200
ctx.write_event_to_stream(
204201
AgentRunEvent(
205202
name="Researcher",
206-
msg="Error: Missing cells information not found. Fallback to other tools.",
203+
data=AgentTask(
204+
msg="Error: Missing cells information not found. Fallback to other tools.",
205+
),
207206
)
208207
)
209208
message = ChatMessage(
@@ -229,16 +228,14 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
229228
ctx.write_event_to_stream(
230229
AgentRunEvent(
231230
name="Researcher",
232-
# TODO: Add typing for the progress message
233-
msg=json.dumps(
234-
{
235-
"progress_id": progress_id,
236-
"total_steps": total_steps,
237-
"current_step": i,
238-
"step_msg": f"Querying for: {cell.question_to_answer}",
239-
}
231+
data=AgentTask(
232+
id=progress_id,
233+
msg=f"Querying for: {cell.question_to_answer}",
234+
progress=ProgressData(
235+
current=i,
236+
total=total_steps,
237+
),
240238
),
241-
event_type=AgentRunEventType.PROGRESS,
242239
)
243240
)
244241
# Call query engine tool directly
@@ -269,7 +266,9 @@ async def fill_cells(self, ctx: Context, ev: FillEvent) -> InputEvent:
269266
ctx.write_event_to_stream(
270267
AgentRunEvent(
271268
name="Processor",
272-
msg="Filling missing cells",
269+
data=AgentTask(
270+
msg="Filling missing cells",
271+
),
273272
)
274273
)
275274
# Call the fill cells tool
@@ -341,7 +340,9 @@ def _call_tool(
341340
ctx.write_event_to_stream(
342341
AgentRunEvent(
343342
name=agent_name,
344-
msg=f"Error: {str(e)}",
343+
data=AgentTask(
344+
msg=f"Error: {str(e)}",
345+
),
345346
)
346347
)
347348
message = ChatMessage(

templates/components/multiagent/python/app/workflows/single.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
import uuid
12
from abc import abstractmethod
2-
from enum import Enum
33
from typing import Any, AsyncGenerator, List, Optional
44

55
from llama_index.core.llms import ChatMessage, ChatResponse
@@ -27,25 +27,23 @@ class ToolCallEvent(Event):
2727
tool_calls: list[ToolSelection]
2828

2929

30-
class AgentRunEventType(Enum):
31-
TEXT = "text"
32-
PROGRESS = "progress"
30+
class ProgressData(BaseModel):
31+
current: int
32+
total: int
33+
34+
35+
class AgentTask(BaseModel):
36+
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
37+
msg: str
38+
progress: Optional[ProgressData] = None
3339

3440

3541
class AgentRunEvent(Event):
3642
name: str
37-
msg: str
38-
event_type: AgentRunEventType = Field(default=AgentRunEventType.TEXT)
43+
data: AgentTask
3944

4045
def to_response(self) -> dict:
41-
return {
42-
"type": "agent",
43-
"data": {
44-
"name": self.name,
45-
"event_type": self.event_type.value,
46-
"msg": self.msg,
47-
},
48-
}
46+
return self.model_dump()
4947

5048

5149
class AgentRunResult(BaseModel):
@@ -111,7 +109,12 @@ async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent
111109
self.memory.put(user_msg)
112110
if self.write_events:
113111
ctx.write_event_to_stream(
114-
AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}")
112+
AgentRunEvent(
113+
name=self.name,
114+
data=AgentTask(
115+
msg=f"Start to work on: {user_input}",
116+
),
117+
)
115118
)
116119

117120
# get chat history
@@ -139,7 +142,12 @@ async def handle_llm_input(
139142
if not tool_calls:
140143
if self.write_events:
141144
ctx.write_event_to_stream(
142-
AgentRunEvent(name=self.name, msg="Finished task")
145+
AgentRunEvent(
146+
name=self.name,
147+
data=AgentTask(
148+
msg="Finished task",
149+
),
150+
)
143151
)
144152
return StopEvent(
145153
result=AgentRunResult(response=response, sources=[*self.sources])
@@ -194,7 +202,12 @@ async def response_generator() -> AsyncGenerator:
194202
# If we've reached here, it's not an immediate tool call, so we return the generator
195203
if self.write_events:
196204
ctx.write_event_to_stream(
197-
AgentRunEvent(name=self.name, msg="Finished task")
205+
AgentRunEvent(
206+
name=self.name,
207+
data=AgentTask(
208+
msg="Finished task",
209+
),
210+
)
198211
)
199212
return StopEvent(result=generator)
200213

templates/types/streaming/nextjs/app/components/ui/chat/index.ts

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,18 @@ export type EventData = {
5858

5959
export type AgentEventData = {
6060
name: string;
61+
data: AgentTask;
62+
};
63+
64+
type AgentTask = {
65+
id: string;
6166
msg: string;
62-
event_type: "text" | "progress";
67+
progress?: ProgressData;
68+
};
69+
70+
type ProgressData = {
71+
current: number;
72+
total: number;
6373
};
6474

6575
export type ToolData = {

0 commit comments

Comments
 (0)