|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.engine.algorithms.agent; |
7 | 7 |
|
8 | | -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ARGUMENTS; |
9 | 8 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT; |
10 | | -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_FUNCTION; |
11 | 9 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ID; |
12 | | -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_NAME; |
13 | 10 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE; |
14 | 11 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALLS; |
15 | 12 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID; |
16 | | -import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TYPE; |
17 | 13 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES; |
18 | 14 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT; |
19 | 15 | import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES; |
@@ -160,16 +156,8 @@ private void processAgentResult(Object result, AGUIEventCollector eventCollector |
160 | 156 | ModelTensorOutput tensorOutput = (ModelTensorOutput) result; |
161 | 157 | // Extract tool calls and text responses from the tensor output |
162 | 158 | processTensorOutput(tensorOutput, eventCollector); |
163 | | - } else if (result instanceof String) { |
164 | | - String resultString = (String) result; |
165 | | - // Check if this is a frontend tool call response |
166 | | - if (resultString.startsWith("FRONTEND_TOOL_CALL: ")) { |
167 | | - log.debug("AG-UI: Detected frontend tool call response, processing..."); |
168 | | - processFrontendToolCall(resultString, eventCollector); |
169 | | - } else { |
170 | | - log.debug("AG-UI: String result is not a frontend tool call"); |
171 | | - } |
172 | 159 | } |
| 160 | + |
173 | 161 | List<Object> messages = new ArrayList<>(); |
174 | 162 | String responseText = extractResponseText(result); |
175 | 163 | messages.add(Map.of(AGUI_FIELD_ID, messageId, AGUI_FIELD_ROLE, AGUI_ROLE_ASSISTANT, AGUI_FIELD_CONTENT, responseText)); |
@@ -226,10 +214,6 @@ private void processModelTensor(ModelTensor tensor, AGUIEventCollector eventColl |
226 | 214 | Map<String, ?> dataMap = tensor.getDataAsMap(); |
227 | 215 | if (dataMap != null) { |
228 | 216 | processToolCallsFromDataMap(dataMap, eventCollector); |
229 | | - } else if (tensor.getResult() != null) { |
230 | | - // Handle text result that might contain tool call information |
231 | | - String result = tensor.getResult(); |
232 | | - processTextResponseForToolCalls(result, eventCollector); |
233 | 217 | } |
234 | 218 | } |
235 | 219 | } |
@@ -283,80 +267,6 @@ private void processToolCallsFromDataMap(Map<String, ?> dataMap, AGUIEventCollec |
283 | 267 | } |
284 | 268 | } |
285 | 269 |
|
286 | | - private void processFrontendToolCall(String frontendToolCallResponse, AGUIEventCollector eventCollector) { |
287 | | - log.debug("AG-UI: Processing frontend tool call response: {}", frontendToolCallResponse); |
288 | | - try { |
289 | | - // Extract the JSON part after "FRONTEND_TOOL_CALL: " |
290 | | - String jsonPart = frontendToolCallResponse.substring("FRONTEND_TOOL_CALL: ".length()); |
291 | | - log.debug("AG-UI: Extracted JSON part: {}", jsonPart); |
292 | | - |
293 | | - JsonElement element = gson.fromJson(jsonPart, JsonElement.class); |
294 | | - |
295 | | - if (element.isJsonObject()) { |
296 | | - JsonObject toolCallObj = element.getAsJsonObject(); |
297 | | - String toolName = toolCallObj.get("tool").getAsString(); |
298 | | - String toolInput = toolCallObj.get("input").getAsString(); |
299 | | - |
300 | | - log.debug("AG-UI: Processing frontend tool call - tool: {}, input: {}", toolName, toolInput); |
301 | | - |
302 | | - // Generate AG-UI events for the frontend tool call |
303 | | - String toolCallId = eventCollector.startToolCall(toolName, null); |
304 | | - eventCollector.addToolCallArgs(toolCallId, toolInput); |
305 | | - eventCollector.endToolCall(toolCallId); |
306 | | - } else { |
307 | | - log.warn("AG-UI: JSON element is not an object: {}", element); |
308 | | - } |
309 | | - } catch (Exception e) { |
310 | | - log.error("Failed to process frontend tool call response: {}", frontendToolCallResponse, e); |
311 | | - } |
312 | | - } |
313 | | - |
314 | | - private void processTextResponseForToolCalls(String result, AGUIEventCollector eventCollector) { |
315 | | - // Try to parse JSON response that might contain tool calls |
316 | | - try { |
317 | | - JsonElement element = gson.fromJson(result, JsonElement.class); |
318 | | - if (element.isJsonObject()) { |
319 | | - JsonObject obj = element.getAsJsonObject(); |
320 | | - if (obj.has("tool_calls")) { |
321 | | - JsonElement toolCallsElement = obj.get("tool_calls"); |
322 | | - if (toolCallsElement.isJsonArray()) { |
323 | | - for (JsonElement toolCallElement : toolCallsElement.getAsJsonArray()) { |
324 | | - if (toolCallElement.isJsonObject()) { |
325 | | - JsonObject toolCall = toolCallElement.getAsJsonObject(); |
326 | | - String toolCallId = getStringField(toolCall, "id"); |
327 | | - JsonElement functionElement = toolCall.get("function"); |
328 | | - |
329 | | - if (functionElement != null && functionElement.isJsonObject()) { |
330 | | - JsonObject function = functionElement.getAsJsonObject(); |
331 | | - String toolName = getStringField(function, "name"); |
332 | | - String arguments = getStringField(function, "arguments"); |
333 | | - |
334 | | - if (toolCallId != null && toolName != null) { |
335 | | - eventCollector.startToolCall(toolName, null); |
336 | | - |
337 | | - if (arguments != null && !arguments.isEmpty()) { |
338 | | - eventCollector.addToolCallArgs(toolCallId, arguments); |
339 | | - } |
340 | | - |
341 | | - eventCollector.endToolCall(toolCallId); |
342 | | - log |
343 | | - .debug( |
344 | | - "AG-UI: Generated tool call events from text response for tool={}, id={}", |
345 | | - toolName, |
346 | | - toolCallId |
347 | | - ); |
348 | | - } |
349 | | - } |
350 | | - } |
351 | | - } |
352 | | - } |
353 | | - } |
354 | | - } |
355 | | - } catch (Exception e) { |
356 | | - // Not a tool call response, just regular text - no special processing needed |
357 | | - } |
358 | | - } |
359 | | - |
360 | 270 | private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params, String llmInterface) { |
361 | 271 | String aguiMessagesJson = params.get(AGUI_PARAM_MESSAGES); |
362 | 272 | if (aguiMessagesJson == null || aguiMessagesJson.isEmpty()) { |
@@ -415,78 +325,27 @@ private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params, St |
415 | 325 | if (AGUI_ROLE_ASSISTANT.equals(role) && message.has(AGUI_FIELD_TOOL_CALLS)) { |
416 | 326 | toolCallMessageIndices.add(i); |
417 | 327 |
|
418 | | - // Convert to OpenAI format for interactions |
| 328 | + // Extract tool calls from AG-UI message (AG-UI uses OpenAI-compatible format) |
419 | 329 | JsonElement toolCallsElement = message.get(AGUI_FIELD_TOOL_CALLS); |
420 | 330 | if (toolCallsElement != null && toolCallsElement.isJsonArray()) { |
421 | | - List<Map<String, Object>> toolCalls = new ArrayList<>(); |
422 | | - for (JsonElement tcElement : toolCallsElement.getAsJsonArray()) { |
423 | | - if (tcElement.isJsonObject()) { |
424 | | - JsonObject tc = tcElement.getAsJsonObject(); |
425 | | - Map<String, Object> toolCall = new HashMap<>(); |
426 | | - |
427 | | - // OpenAI format: id, type, and function at the same level |
428 | | - String toolCallId = getStringField(tc, AGUI_FIELD_ID); |
429 | | - String toolCallType = getStringField(tc, AGUI_FIELD_TYPE); |
430 | | - |
431 | | - toolCall.put(AGUI_FIELD_ID, toolCallId); |
432 | | - toolCall.put(AGUI_FIELD_TYPE, toolCallType != null ? toolCallType : "function"); |
433 | | - |
434 | | - JsonElement functionElement = tc.get(AGUI_FIELD_FUNCTION); |
435 | | - if (functionElement != null && functionElement.isJsonObject()) { |
436 | | - JsonObject func = functionElement.getAsJsonObject(); |
437 | | - Map<String, String> function = new HashMap<>(); |
438 | | - function.put(AGUI_FIELD_NAME, getStringField(func, AGUI_FIELD_NAME)); |
439 | | - function.put(AGUI_FIELD_ARGUMENTS, getStringField(func, AGUI_FIELD_ARGUMENTS)); |
440 | | - toolCall.put(AGUI_FIELD_FUNCTION, function); |
441 | | - } |
442 | | - toolCalls.add(toolCall); |
443 | | - } |
444 | | - } |
| 331 | + // Pass the JSON array directly to FunctionCalling for format conversion |
| 332 | + String toolCallsJson = gson.toJson(toolCallsElement); |
445 | 333 |
|
446 | | - // Create assistant message in the appropriate format based on LLM interface |
| 334 | + FunctionCalling functionCalling = FunctionCallingFactory.create(llmInterface); |
447 | 335 | String assistantMessage; |
448 | | - boolean isBedrockConverse = llmInterface != null && llmInterface.toLowerCase().contains("bedrock"); |
449 | | - |
450 | | - if (isBedrockConverse) { |
451 | | - // Bedrock format: {"role": "assistant", "content": [{"toolUse": {...}}]} |
452 | | - List<Map<String, Object>> contentBlocks = new ArrayList<>(); |
453 | | - for (Map<String, Object> toolCall : toolCalls) { |
454 | | - Map<String, Object> toolUse = new HashMap<>(); |
455 | | - toolUse.put("toolUseId", toolCall.get("id")); |
456 | | - |
457 | | - Map<String, Object> function = (Map<String, Object>) toolCall.get("function"); |
458 | | - if (function != null) { |
459 | | - toolUse.put("name", function.get("name")); |
460 | | - |
461 | | - // Parse arguments JSON string to object |
462 | | - String argumentsJson = (String) function.get("arguments"); |
463 | | - try { |
464 | | - Object argumentsObj = gson.fromJson(argumentsJson, Object.class); |
465 | | - toolUse.put("input", argumentsObj); |
466 | | - } catch (Exception e) { |
467 | | - log.warn("AG-UI: Failed to parse tool arguments as JSON: {}", argumentsJson, e); |
468 | | - toolUse.put("input", Map.of()); |
469 | | - } |
470 | | - } |
471 | | - |
472 | | - contentBlocks.add(Map.of("toolUse", toolUse)); |
473 | | - } |
474 | 336 |
|
475 | | - Map<String, Object> bedrockMsg = new HashMap<>(); |
476 | | - bedrockMsg.put(AGUI_FIELD_ROLE, AGUI_ROLE_ASSISTANT); |
477 | | - bedrockMsg.put(AGUI_FIELD_CONTENT, contentBlocks); |
478 | | - assistantMessage = gson.toJson(bedrockMsg); |
| 337 | + if (functionCalling != null) { |
| 338 | + // Use FunctionCalling to format the message in the correct LLM format |
| 339 | + assistantMessage = functionCalling.formatAGUIToolCalls(toolCallsJson); |
| 340 | + log.debug("AG-UI: Formatted assistant message using {}", functionCalling.getClass().getSimpleName()); |
479 | 341 | } else { |
480 | | - // OpenAI format: {"role": "assistant", "tool_calls": [...]} |
481 | | - Map<String, Object> assistantMsg = new HashMap<>(); |
482 | | - assistantMsg.put(AGUI_FIELD_ROLE, AGUI_ROLE_ASSISTANT); |
483 | | - assistantMsg.put("tool_calls", toolCalls); |
484 | | - assistantMessage = gson.toJson(assistantMsg); |
485 | | - log.debug("AG-UI: Created OpenAI-format assistant message with {} tool calls", toolCalls.size()); |
| 342 | + // Fallback to OpenAI format if no FunctionCalling available |
| 343 | + assistantMessage = "{\"role\":\"assistant\",\"tool_calls\":" + toolCallsJson + "}"; |
| 344 | + log.debug("AG-UI: Created OpenAI-format assistant message (fallback)"); |
486 | 345 | } |
487 | 346 |
|
488 | 347 | assistantToolCallMessages.add(assistantMessage); |
489 | | - log.debug("AG-UI: Extracted assistant message with {} tool calls at index {}", toolCalls.size(), i); |
| 348 | + log.debug("AG-UI: Extracted assistant message at index {}", i); |
490 | 349 | log.debug("AG-UI: Assistant message JSON: {}", assistantMessage); |
491 | 350 | } |
492 | 351 | } |
|
0 commit comments