@@ -32,7 +32,7 @@ class ExtractMissingCellsEvent(Event):
32
32
33
33
34
34
class FindAnswersEvent (Event ):
35
- tool_call : ToolSelection
35
+ missing_cells : list [ MissingCell ]
36
36
37
37
38
38
class FillEvent (Event ):
@@ -90,8 +90,7 @@ def __init__(
90
90
self .query_engine_tool = query_engine_tool
91
91
self .extractor_tool = extractor_tool
92
92
self .filling_tool = filling_tool
93
- self .tools = [self .query_engine_tool , self .extractor_tool , self .filling_tool ]
94
- self .llm : FunctionCallingLLM = llm or Settings .llm
93
+ self .llm = llm or Settings .llm
95
94
if not isinstance (self .llm , FunctionCallingLLM ):
96
95
raise ValueError ("FormFillingWorkflow only supports FunctionCallingLLM." )
97
96
self .memory = ChatMemoryBuffer .from_defaults (
@@ -121,7 +120,7 @@ async def handle_llm_input(
121
120
self ,
122
121
ctx : Context ,
123
122
ev : InputEvent ,
124
- ) -> ExtractMissingCellsEvent | FindAnswersEvent | FillEvent | StopEvent :
123
+ ) -> ExtractMissingCellsEvent | FillEvent | StopEvent :
125
124
"""
126
125
Handle an LLM input and decide the next step.
127
126
"""
@@ -137,8 +136,6 @@ async def handle_llm_input(
137
136
for tool_call in tool_calls :
138
137
if tool_call .tool_name == self .extractor_tool .metadata .get_name ():
139
138
return ExtractMissingCellsEvent (tool_call = tool_call )
140
- elif tool_call .tool_name == self .query_engine_tool .metadata .get_name ():
141
- return FindAnswersEvent (tool_call = tool_call )
142
139
elif tool_call .tool_name == self .filling_tool .metadata .get_name ():
143
140
return FillEvent (tool_call = tool_call )
144
141
# If no tool call, return the generator
@@ -147,7 +144,7 @@ async def handle_llm_input(
147
144
@step ()
148
145
async def extract_missing_cells (
149
146
self , ctx : Context , ev : ExtractMissingCellsEvent
150
- ) -> InputEvent :
147
+ ) -> InputEvent | FindAnswersEvent :
151
148
"""
152
149
Extract missing cells in a CSV file and generate questions to fill them.
153
150
"""
@@ -168,7 +165,6 @@ async def extract_missing_cells(
168
165
return InputEvent (input = self .memory .get ())
169
166
170
167
missing_cells = response .raw_output .get ("missing_cells" , [])
171
- ctx .data ["missing_cells" ] = missing_cells
172
168
message = ChatMessage (
173
169
role = MessageRole .TOOL ,
174
170
content = str (missing_cells ),
@@ -179,8 +175,8 @@ async def extract_missing_cells(
179
175
)
180
176
self .memory .put (message )
181
177
182
- # send input event back with updated chat history
183
- return InputEvent ( input = self . memory . get () )
178
+ # Forward missing cells information to find answers step
179
+ return FindAnswersEvent ( missing_cells = missing_cells )
184
180
185
181
@step ()
186
182
async def find_answers (self , ctx : Context , ev : FindAnswersEvent ) -> InputEvent :
@@ -193,7 +189,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
193
189
msg = "Finding answers for missing cells" ,
194
190
)
195
191
)
196
- missing_cells = ctx . data . get ( " missing_cells" , None )
192
+ missing_cells = ev . missing_cells
197
193
# If missing cells information is not found, fallback to other tools
198
194
# It means that the extractor tool has not been called yet
199
195
# Fallback to input
@@ -220,8 +216,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
220
216
# and stream the progress
221
217
progress_id = str (uuid .uuid4 ())
222
218
total_steps = len (missing_cells )
223
- for i , missing_cell in enumerate (missing_cells ):
224
- cell = MissingCell (** missing_cell )
219
+ for i , cell in enumerate (missing_cells ):
225
220
if cell .question_to_answer is None :
226
221
continue
227
222
ctx .write_event_to_stream (
@@ -248,15 +243,12 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
248
243
value = str (answer ),
249
244
)
250
245
)
251
- message = ChatMessage (
252
- role = MessageRole .TOOL ,
253
- content = str (cell_values ),
254
- additional_kwargs = {
255
- "tool_call_id" : ev .tool_call .tool_id ,
256
- "name" : ev .tool_call .tool_name ,
257
- },
246
+ self .memory .put (
247
+ ChatMessage (
248
+ role = MessageRole .ASSISTANT ,
249
+ content = str (cell_values ),
250
+ )
258
251
)
259
- self .memory .put (message )
260
252
return InputEvent (input = self .memory .get ())
261
253
262
254
@step ()
@@ -295,7 +287,8 @@ async def _tool_call_generator(
295
287
self , chat_history : list [ChatMessage ]
296
288
) -> AsyncGenerator [ChatMessage | bool , None ]:
297
289
response_stream = await self .llm .astream_chat_with_tools (
298
- self .tools , chat_history = chat_history
290
+ [self .extractor_tool , self .filling_tool ],
291
+ chat_history = chat_history ,
299
292
)
300
293
301
294
full_response = None
@@ -321,7 +314,6 @@ async def _tool_call_generator(
321
314
self .memory .put (full_response .message )
322
315
yield full_response
323
316
324
- # TODO: Implement a _acall_tool method
325
317
def _call_tool (
326
318
self ,
327
319
ctx : Context ,
0 commit comments