Skip to content

Commit effc31d

Browse files
committed
Fix workflow issue in tool calling
1 parent 7788c89 commit effc31d

File tree

2 files changed

+30
-30
lines changed

2 files changed

+30
-30
lines changed

templates/components/agents/python/form_filling/app/engine/engine.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,17 @@ def get_chat_engine(
2222
query_engine = index.as_query_engine(similarity_top_k=top_k)
2323
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
2424

25-
configured_tools = ToolFactory.from_env(map_result=True) # type: dict[str, list[Any]]
25+
configured_tools = ToolFactory.from_env(map_result=True)
26+
extractor_tool = configured_tools.get("extract_questions")
27+
filling_tool = configured_tools.get("fill_form")
28+
29+
if extractor_tool is None or filling_tool is None:
30+
raise ValueError("Extractor or filling tool is not found!")
31+
2632
workflow = FormFillingWorkflow(
2733
query_engine_tool=query_engine_tool,
28-
extractor_tool=configured_tools.get("extract_questions"),
29-
filling_tool=configured_tools.get("fill_form"),
34+
extractor_tool=extractor_tool,
35+
filling_tool=filling_tool,
3036
chat_history=chat_history,
3137
)
3238

templates/components/agents/python/form_filling/app/engine/workflow.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ExtractMissingCellsEvent(Event):
3232

3333

3434
class FindAnswersEvent(Event):
35-
tool_call: ToolSelection
35+
missing_cells: list[MissingCell]
3636

3737

3838
class FillEvent(Event):
@@ -90,7 +90,6 @@ def __init__(
9090
self.query_engine_tool = query_engine_tool
9191
self.extractor_tool = extractor_tool
9292
self.filling_tool = filling_tool
93-
self.tools = [self.query_engine_tool, self.extractor_tool, self.filling_tool]
9493
self.llm: FunctionCallingLLM = llm or Settings.llm
9594
if not isinstance(self.llm, FunctionCallingLLM):
9695
raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
@@ -117,11 +116,11 @@ async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
117116
return InputEvent(input=chat_history)
118117

119118
@step(pass_context=True)
120-
async def handle_llm_input(
119+
async def handle_llm_input( # type: ignore
121120
self,
122121
ctx: Context,
123122
ev: InputEvent,
124-
) -> ExtractMissingCellsEvent | FindAnswersEvent | FillEvent | StopEvent:
123+
) -> ExtractMissingCellsEvent | FillEvent | StopEvent:
125124
"""
126125
Handle an LLM input and decide the next step.
127126
"""
@@ -133,21 +132,20 @@ async def handle_llm_input(
133132
is_tool_call = await generator.__anext__()
134133
if is_tool_call:
135134
full_response = await generator.__anext__()
136-
tool_calls = self.llm.get_tool_calls_from_response(full_response)
135+
tool_calls = self.llm.get_tool_calls_from_response(full_response) # type: ignore
137136
for tool_call in tool_calls:
138137
if tool_call.tool_name == self.extractor_tool.metadata.get_name():
139-
return ExtractMissingCellsEvent(tool_call=tool_call)
140-
elif tool_call.tool_name == self.query_engine_tool.metadata.get_name():
141-
return FindAnswersEvent(tool_call=tool_call)
138+
ctx.send_event(ExtractMissingCellsEvent(tool_call=tool_call))
142139
elif tool_call.tool_name == self.filling_tool.metadata.get_name():
143-
return FillEvent(tool_call=tool_call)
144-
# If no tool call, return the generator
145-
return StopEvent(result=generator)
140+
ctx.send_event(FillEvent(tool_call=tool_call))
141+
else:
142+
# If no tool call, return the generator
143+
return StopEvent(result=generator)
146144

147145
@step()
148146
async def extract_missing_cells(
149147
self, ctx: Context, ev: ExtractMissingCellsEvent
150-
) -> InputEvent:
148+
) -> InputEvent | FindAnswersEvent:
151149
"""
152150
Extract missing cells in a CSV file and generate questions to fill them.
153151
"""
@@ -168,7 +166,6 @@ async def extract_missing_cells(
168166
return InputEvent(input=self.memory.get())
169167

170168
missing_cells = response.raw_output.get("missing_cells", [])
171-
ctx.data["missing_cells"] = missing_cells
172169
message = ChatMessage(
173170
role=MessageRole.TOOL,
174171
content=str(missing_cells),
@@ -179,8 +176,8 @@ async def extract_missing_cells(
179176
)
180177
self.memory.put(message)
181178

182-
# send input event back with updated chat history
183-
return InputEvent(input=self.memory.get())
179+
# Forward missing cells information to find answers step
180+
return FindAnswersEvent(missing_cells=missing_cells)
184181

185182
@step()
186183
async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
@@ -193,7 +190,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
193190
msg="Finding answers for missing cells",
194191
)
195192
)
196-
missing_cells = ctx.data.get("missing_cells", None)
193+
missing_cells = ev.missing_cells
197194
# If missing cells information is not found, fallback to other tools
198195
# It means that the extractor tool has not been called yet
199196
# Fallback to input
@@ -220,8 +217,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
220217
# and stream the progress
221218
progress_id = str(uuid.uuid4())
222219
total_steps = len(missing_cells)
223-
for i, missing_cell in enumerate(missing_cells):
224-
cell = MissingCell(**missing_cell)
220+
for i, cell in enumerate(missing_cells):
225221
if cell.question_to_answer is None:
226222
continue
227223
ctx.write_event_to_stream(
@@ -248,15 +244,12 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
248244
value=str(answer),
249245
)
250246
)
251-
message = ChatMessage(
252-
role=MessageRole.TOOL,
253-
content=str(cell_values),
254-
additional_kwargs={
255-
"tool_call_id": ev.tool_call.tool_id,
256-
"name": ev.tool_call.tool_name,
257-
},
247+
self.memory.put(
248+
ChatMessage(
249+
role=MessageRole.ASSISTANT,
250+
content=str(cell_values),
251+
)
258252
)
259-
self.memory.put(message)
260253
return InputEvent(input=self.memory.get())
261254

262255
@step()
@@ -295,7 +288,8 @@ async def _tool_call_generator(
295288
self, chat_history: list[ChatMessage]
296289
) -> AsyncGenerator[ChatMessage | bool, None]:
297290
response_stream = await self.llm.astream_chat_with_tools(
298-
self.tools, chat_history=chat_history
291+
[self.extractor_tool, self.filling_tool],
292+
chat_history=chat_history,
299293
)
300294

301295
full_response = None

0 commit comments

Comments
 (0)