Skip to content

Commit e9f36fe

Browse files
committed
Fix workflow issue in tool calling
1 parent 7788c89 commit e9f36fe

File tree

1 file changed

+15
-23
lines changed
  • templates/components/agents/python/form_filling/app/engine

1 file changed

+15
-23
lines changed

templates/components/agents/python/form_filling/app/engine/workflow.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ExtractMissingCellsEvent(Event):
3232

3333

3434
class FindAnswersEvent(Event):
35-
tool_call: ToolSelection
35+
missing_cells: list[MissingCell]
3636

3737

3838
class FillEvent(Event):
@@ -90,8 +90,7 @@ def __init__(
9090
self.query_engine_tool = query_engine_tool
9191
self.extractor_tool = extractor_tool
9292
self.filling_tool = filling_tool
93-
self.tools = [self.query_engine_tool, self.extractor_tool, self.filling_tool]
94-
self.llm: FunctionCallingLLM = llm or Settings.llm
93+
self.llm = llm or Settings.llm
9594
if not isinstance(self.llm, FunctionCallingLLM):
9695
raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
9796
self.memory = ChatMemoryBuffer.from_defaults(
@@ -121,7 +120,7 @@ async def handle_llm_input(
121120
self,
122121
ctx: Context,
123122
ev: InputEvent,
124-
) -> ExtractMissingCellsEvent | FindAnswersEvent | FillEvent | StopEvent:
123+
) -> ExtractMissingCellsEvent | FillEvent | StopEvent:
125124
"""
126125
Handle an LLM input and decide the next step.
127126
"""
@@ -137,8 +136,6 @@ async def handle_llm_input(
137136
for tool_call in tool_calls:
138137
if tool_call.tool_name == self.extractor_tool.metadata.get_name():
139138
return ExtractMissingCellsEvent(tool_call=tool_call)
140-
elif tool_call.tool_name == self.query_engine_tool.metadata.get_name():
141-
return FindAnswersEvent(tool_call=tool_call)
142139
elif tool_call.tool_name == self.filling_tool.metadata.get_name():
143140
return FillEvent(tool_call=tool_call)
144141
# If no tool call, return the generator
@@ -147,7 +144,7 @@ async def handle_llm_input(
147144
@step()
148145
async def extract_missing_cells(
149146
self, ctx: Context, ev: ExtractMissingCellsEvent
150-
) -> InputEvent:
147+
) -> InputEvent | FindAnswersEvent:
151148
"""
152149
Extract missing cells in a CSV file and generate questions to fill them.
153150
"""
@@ -168,7 +165,6 @@ async def extract_missing_cells(
168165
return InputEvent(input=self.memory.get())
169166

170167
missing_cells = response.raw_output.get("missing_cells", [])
171-
ctx.data["missing_cells"] = missing_cells
172168
message = ChatMessage(
173169
role=MessageRole.TOOL,
174170
content=str(missing_cells),
@@ -179,8 +175,8 @@ async def extract_missing_cells(
179175
)
180176
self.memory.put(message)
181177

182-
# send input event back with updated chat history
183-
return InputEvent(input=self.memory.get())
178+
# Forward missing cells information to find answers step
179+
return FindAnswersEvent(missing_cells=missing_cells)
184180

185181
@step()
186182
async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
@@ -193,7 +189,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
193189
msg="Finding answers for missing cells",
194190
)
195191
)
196-
missing_cells = ctx.data.get("missing_cells", None)
192+
missing_cells = ev.missing_cells
197193
# If missing cells information is not found, fallback to other tools
198194
# It means that the extractor tool has not been called yet
199195
# Fallback to input
@@ -220,8 +216,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
220216
# and stream the progress
221217
progress_id = str(uuid.uuid4())
222218
total_steps = len(missing_cells)
223-
for i, missing_cell in enumerate(missing_cells):
224-
cell = MissingCell(**missing_cell)
219+
for i, cell in enumerate(missing_cells):
225220
if cell.question_to_answer is None:
226221
continue
227222
ctx.write_event_to_stream(
@@ -248,15 +243,12 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
248243
value=str(answer),
249244
)
250245
)
251-
message = ChatMessage(
252-
role=MessageRole.TOOL,
253-
content=str(cell_values),
254-
additional_kwargs={
255-
"tool_call_id": ev.tool_call.tool_id,
256-
"name": ev.tool_call.tool_name,
257-
},
246+
self.memory.put(
247+
ChatMessage(
248+
role=MessageRole.ASSISTANT,
249+
content=str(cell_values),
250+
)
258251
)
259-
self.memory.put(message)
260252
return InputEvent(input=self.memory.get())
261253

262254
@step()
@@ -295,7 +287,8 @@ async def _tool_call_generator(
295287
self, chat_history: list[ChatMessage]
296288
) -> AsyncGenerator[ChatMessage | bool, None]:
297289
response_stream = await self.llm.astream_chat_with_tools(
298-
self.tools, chat_history=chat_history
290+
[self.extractor_tool, self.filling_tool],
291+
chat_history=chat_history,
299292
)
300293

301294
full_response = None
@@ -321,7 +314,6 @@ async def _tool_call_generator(
321314
self.memory.put(full_response.message)
322315
yield full_response
323316

324-
# TODO: Implement a _acall_tool method
325317
def _call_tool(
326318
self,
327319
ctx: Context,

0 commit comments

Comments
 (0)