@@ -32,7 +32,7 @@ class ExtractMissingCellsEvent(Event):
32
32
33
33
34
34
class FindAnswersEvent (Event ):
35
- tool_call : ToolSelection
35
+ missing_cells : list [ MissingCell ]
36
36
37
37
38
38
class FillEvent (Event ):
@@ -90,7 +90,6 @@ def __init__(
90
90
self .query_engine_tool = query_engine_tool
91
91
self .extractor_tool = extractor_tool
92
92
self .filling_tool = filling_tool
93
- self .tools = [self .query_engine_tool , self .extractor_tool , self .filling_tool ]
94
93
self .llm : FunctionCallingLLM = llm or Settings .llm
95
94
if not isinstance (self .llm , FunctionCallingLLM ):
96
95
raise ValueError ("FormFillingWorkflow only supports FunctionCallingLLM." )
@@ -117,11 +116,11 @@ async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
117
116
return InputEvent (input = chat_history )
118
117
119
118
@step (pass_context = True )
120
- async def handle_llm_input (
119
+ async def handle_llm_input ( # type: ignore
121
120
self ,
122
121
ctx : Context ,
123
122
ev : InputEvent ,
124
- ) -> ExtractMissingCellsEvent | FindAnswersEvent | FillEvent | StopEvent :
123
+ ) -> ExtractMissingCellsEvent | FillEvent | StopEvent :
125
124
"""
126
125
Handle an LLM input and decide the next step.
127
126
"""
@@ -133,21 +132,20 @@ async def handle_llm_input(
133
132
is_tool_call = await generator .__anext__ ()
134
133
if is_tool_call :
135
134
full_response = await generator .__anext__ ()
136
- tool_calls = self .llm .get_tool_calls_from_response (full_response )
135
+ tool_calls = self .llm .get_tool_calls_from_response (full_response ) # type: ignore
137
136
for tool_call in tool_calls :
138
137
if tool_call .tool_name == self .extractor_tool .metadata .get_name ():
139
- return ExtractMissingCellsEvent (tool_call = tool_call )
140
- elif tool_call .tool_name == self .query_engine_tool .metadata .get_name ():
141
- return FindAnswersEvent (tool_call = tool_call )
138
+ ctx .send_event (ExtractMissingCellsEvent (tool_call = tool_call ))
142
139
elif tool_call .tool_name == self .filling_tool .metadata .get_name ():
143
- return FillEvent (tool_call = tool_call )
144
- # If no tool call, return the generator
145
- return StopEvent (result = generator )
140
+ ctx .send_event (FillEvent (tool_call = tool_call ))
141
+ else :
142
+ # If no tool call, return the generator
143
+ return StopEvent (result = generator )
146
144
147
145
@step ()
148
146
async def extract_missing_cells (
149
147
self , ctx : Context , ev : ExtractMissingCellsEvent
150
- ) -> InputEvent :
148
+ ) -> InputEvent | FindAnswersEvent :
151
149
"""
152
150
Extract missing cells in a CSV file and generate questions to fill them.
153
151
"""
@@ -168,7 +166,6 @@ async def extract_missing_cells(
168
166
return InputEvent (input = self .memory .get ())
169
167
170
168
missing_cells = response .raw_output .get ("missing_cells" , [])
171
- ctx .data ["missing_cells" ] = missing_cells
172
169
message = ChatMessage (
173
170
role = MessageRole .TOOL ,
174
171
content = str (missing_cells ),
@@ -179,8 +176,8 @@ async def extract_missing_cells(
179
176
)
180
177
self .memory .put (message )
181
178
182
- # send input event back with updated chat history
183
- return InputEvent ( input = self . memory . get () )
179
+ # Forward missing cells information to find answers step
180
+ return FindAnswersEvent ( missing_cells = missing_cells )
184
181
185
182
@step ()
186
183
async def find_answers (self , ctx : Context , ev : FindAnswersEvent ) -> InputEvent :
@@ -193,7 +190,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
193
190
msg = "Finding answers for missing cells" ,
194
191
)
195
192
)
196
- missing_cells = ctx . data . get ( " missing_cells" , None )
193
+ missing_cells = ev . missing_cells
197
194
# If missing cells information is not found, fallback to other tools
198
195
# It means that the extractor tool has not been called yet
199
196
# Fallback to input
@@ -220,8 +217,7 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
220
217
# and stream the progress
221
218
progress_id = str (uuid .uuid4 ())
222
219
total_steps = len (missing_cells )
223
- for i , missing_cell in enumerate (missing_cells ):
224
- cell = MissingCell (** missing_cell )
220
+ for i , cell in enumerate (missing_cells ):
225
221
if cell .question_to_answer is None :
226
222
continue
227
223
ctx .write_event_to_stream (
@@ -248,15 +244,12 @@ async def find_answers(self, ctx: Context, ev: FindAnswersEvent) -> InputEvent:
248
244
value = str (answer ),
249
245
)
250
246
)
251
- message = ChatMessage (
252
- role = MessageRole .TOOL ,
253
- content = str (cell_values ),
254
- additional_kwargs = {
255
- "tool_call_id" : ev .tool_call .tool_id ,
256
- "name" : ev .tool_call .tool_name ,
257
- },
247
+ self .memory .put (
248
+ ChatMessage (
249
+ role = MessageRole .ASSISTANT ,
250
+ content = str (cell_values ),
251
+ )
258
252
)
259
- self .memory .put (message )
260
253
return InputEvent (input = self .memory .get ())
261
254
262
255
@step ()
@@ -295,7 +288,8 @@ async def _tool_call_generator(
295
288
self , chat_history : list [ChatMessage ]
296
289
) -> AsyncGenerator [ChatMessage | bool , None ]:
297
290
response_stream = await self .llm .astream_chat_with_tools (
298
- self .tools , chat_history = chat_history
291
+ [self .extractor_tool , self .filling_tool ],
292
+ chat_history = chat_history ,
299
293
)
300
294
301
295
full_response = None
0 commit comments