Skip to content

Commit 9a7f887

Browse files
committed
add to function calling
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 8bd5e51 commit 9a7f887

File tree

8 files changed

+228
-396
lines changed

8 files changed

+228
-396
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.engine.algorithms.agent;
7+
8+
import java.util.Map;
9+
10+
import org.opensearch.core.action.ActionListener;
11+
import org.opensearch.ml.common.spi.tools.Tool;
12+
13+
import lombok.extern.log4j.Log4j2;
14+
15+
/**
16+
* Placeholder tool for AG-UI frontend tools.
17+
* Frontend tools are not executed on the backend - they are executed in the browser.
18+
* This placeholder allows the LLM to see frontend tools in the unified tool list.
19+
*/
20+
@Log4j2
21+
public class AGUIFrontendTool implements Tool {
22+
private final String toolName;
23+
private final String toolDescription;
24+
private final Map<String, Object> toolAttributes;
25+
26+
public AGUIFrontendTool(String toolName, String toolDescription, Map<String, Object> toolAttributes) {
27+
this.toolName = toolName;
28+
this.toolDescription = toolDescription;
29+
this.toolAttributes = toolAttributes;
30+
}
31+
32+
@Override
33+
public String getName() {
34+
return toolName;
35+
}
36+
37+
@Override
38+
public void setName(String name) {}
39+
40+
@Override
41+
public String getDescription() {
42+
return toolDescription;
43+
}
44+
45+
@Override
46+
public void setDescription(String description) {}
47+
48+
@Override
49+
public Map<String, Object> getAttributes() {
50+
return toolAttributes;
51+
}
52+
53+
@Override
54+
public void setAttributes(Map<String, Object> attributes) {}
55+
56+
@Override
57+
@SuppressWarnings("unchecked")
58+
public <T> void run(Map<String, String> parameters, ActionListener<T> listener) {
59+
log.debug("AG-UI: Frontend tool {} executed with parameters: {}", toolName, parameters);
60+
String errorResult = String
61+
.format(
62+
"Error: Tool '%s' is a frontend tool and should be called via function calling in the final response, "
63+
+ "not during ReAct execution.",
64+
toolName
65+
);
66+
listener.onResponse((T) errorResult);
67+
}
68+
69+
@Override
70+
public boolean validate(Map<String, String> parameters) {
71+
return true;
72+
}
73+
74+
@Override
75+
public String getType() {
76+
return "AGUIFrontendTool";
77+
}
78+
79+
@Override
80+
public String getVersion() {
81+
return "1.0.0";
82+
}
83+
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,4 +1014,65 @@ public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<Strin
10141014

10151015
return tool;
10161016
}
1017+
1018+
public static List<Map<String, Object>> parseFrontendTools(String aguiTools) {
1019+
List<Map<String, Object>> frontendTools = new ArrayList<>();
1020+
if (aguiTools != null && !aguiTools.isEmpty() && !aguiTools.trim().equals("[]")) {
1021+
try {
1022+
Type listType = new TypeToken<List<Map<String, Object>>>() {
1023+
}.getType();
1024+
List<Map<String, Object>> parsed = gson.fromJson(aguiTools, listType);
1025+
if (parsed != null) {
1026+
frontendTools.addAll(parsed);
1027+
}
1028+
} catch (Exception e) {
1029+
log.warn("Failed to parse frontend tools: {}", e.getMessage());
1030+
}
1031+
}
1032+
return frontendTools;
1033+
}
1034+
1035+
public static ModelTensorOutput createFrontendToolCallResponse(String toolCallId, String action, String actionInput) {
1036+
Map<String, Object> toolCallData = Map
1037+
.of(
1038+
"tool_calls",
1039+
List.of(Map.of("id", toolCallId, "type", "function", "function", Map.of("name", action, "arguments", actionInput)))
1040+
);
1041+
1042+
ModelTensor responseTensor = ModelTensor.builder().name("response").dataAsMap(toolCallData).build();
1043+
1044+
org.opensearch.ml.common.output.model.ModelTensors modelTensors = org.opensearch.ml.common.output.model.ModelTensors
1045+
.builder()
1046+
.mlModelTensors(List.of(responseTensor))
1047+
.build();
1048+
1049+
return ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build();
1050+
}
1051+
1052+
public static Map<String, Tool> wrapFrontendToolsAsToolObjects(List<Map<String, Object>> frontendTools) {
1053+
Map<String, Tool> wrappedTools = new HashMap<>();
1054+
1055+
for (Map<String, Object> frontendTool : frontendTools) {
1056+
String toolName = (String) frontendTool.get("name");
1057+
String toolDescription = (String) frontendTool.get("description");
1058+
1059+
// Create frontend tool object with source marker
1060+
Map<String, Object> toolAttributes = new HashMap<>();
1061+
toolAttributes.put("source", "frontend");
1062+
toolAttributes.put("tool_definition", frontendTool);
1063+
1064+
Object parameters = frontendTool.get("parameters");
1065+
if (parameters != null) {
1066+
toolAttributes.put("input_schema", gson.toJson(parameters));
1067+
} else {
1068+
Map<String, Object> emptySchema = Map.of("type", "object", "properties", Map.of());
1069+
toolAttributes.put("input_schema", gson.toJson(emptySchema));
1070+
}
1071+
1072+
Tool frontendToolObj = new AGUIFrontendTool(toolName, toolDescription, toolAttributes);
1073+
wrappedTools.put(toolName, frontendToolObj);
1074+
}
1075+
1076+
return wrappedTools;
1077+
}
10171078
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAGUIAgentRunner.java

Lines changed: 13 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,11 @@
55

66
package org.opensearch.ml.engine.algorithms.agent;
77

8-
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ARGUMENTS;
98
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_CONTENT;
10-
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_FUNCTION;
119
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ID;
12-
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_NAME;
1310
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_ROLE;
1411
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALLS;
1512
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TOOL_CALL_ID;
16-
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_FIELD_TYPE;
1713
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_ASSISTANT_TOOL_CALL_MESSAGES;
1814
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_CONTEXT;
1915
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_MESSAGES;
@@ -160,16 +156,8 @@ private void processAgentResult(Object result, AGUIEventCollector eventCollector
160156
ModelTensorOutput tensorOutput = (ModelTensorOutput) result;
161157
// Extract tool calls and text responses from the tensor output
162158
processTensorOutput(tensorOutput, eventCollector);
163-
} else if (result instanceof String) {
164-
String resultString = (String) result;
165-
// Check if this is a frontend tool call response
166-
if (resultString.startsWith("FRONTEND_TOOL_CALL: ")) {
167-
log.debug("AG-UI: Detected frontend tool call response, processing...");
168-
processFrontendToolCall(resultString, eventCollector);
169-
} else {
170-
log.debug("AG-UI: String result is not a frontend tool call");
171-
}
172159
}
160+
173161
List<Object> messages = new ArrayList<>();
174162
String responseText = extractResponseText(result);
175163
messages.add(Map.of(AGUI_FIELD_ID, messageId, AGUI_FIELD_ROLE, AGUI_ROLE_ASSISTANT, AGUI_FIELD_CONTENT, responseText));
@@ -226,10 +214,6 @@ private void processModelTensor(ModelTensor tensor, AGUIEventCollector eventColl
226214
Map<String, ?> dataMap = tensor.getDataAsMap();
227215
if (dataMap != null) {
228216
processToolCallsFromDataMap(dataMap, eventCollector);
229-
} else if (tensor.getResult() != null) {
230-
// Handle text result that might contain tool call information
231-
String result = tensor.getResult();
232-
processTextResponseForToolCalls(result, eventCollector);
233217
}
234218
}
235219
}
@@ -283,80 +267,6 @@ private void processToolCallsFromDataMap(Map<String, ?> dataMap, AGUIEventCollec
283267
}
284268
}
285269

286-
private void processFrontendToolCall(String frontendToolCallResponse, AGUIEventCollector eventCollector) {
287-
log.debug("AG-UI: Processing frontend tool call response: {}", frontendToolCallResponse);
288-
try {
289-
// Extract the JSON part after "FRONTEND_TOOL_CALL: "
290-
String jsonPart = frontendToolCallResponse.substring("FRONTEND_TOOL_CALL: ".length());
291-
log.debug("AG-UI: Extracted JSON part: {}", jsonPart);
292-
293-
JsonElement element = gson.fromJson(jsonPart, JsonElement.class);
294-
295-
if (element.isJsonObject()) {
296-
JsonObject toolCallObj = element.getAsJsonObject();
297-
String toolName = toolCallObj.get("tool").getAsString();
298-
String toolInput = toolCallObj.get("input").getAsString();
299-
300-
log.debug("AG-UI: Processing frontend tool call - tool: {}, input: {}", toolName, toolInput);
301-
302-
// Generate AG-UI events for the frontend tool call
303-
String toolCallId = eventCollector.startToolCall(toolName, null);
304-
eventCollector.addToolCallArgs(toolCallId, toolInput);
305-
eventCollector.endToolCall(toolCallId);
306-
} else {
307-
log.warn("AG-UI: JSON element is not an object: {}", element);
308-
}
309-
} catch (Exception e) {
310-
log.error("Failed to process frontend tool call response: {}", frontendToolCallResponse, e);
311-
}
312-
}
313-
314-
private void processTextResponseForToolCalls(String result, AGUIEventCollector eventCollector) {
315-
// Try to parse JSON response that might contain tool calls
316-
try {
317-
JsonElement element = gson.fromJson(result, JsonElement.class);
318-
if (element.isJsonObject()) {
319-
JsonObject obj = element.getAsJsonObject();
320-
if (obj.has("tool_calls")) {
321-
JsonElement toolCallsElement = obj.get("tool_calls");
322-
if (toolCallsElement.isJsonArray()) {
323-
for (JsonElement toolCallElement : toolCallsElement.getAsJsonArray()) {
324-
if (toolCallElement.isJsonObject()) {
325-
JsonObject toolCall = toolCallElement.getAsJsonObject();
326-
String toolCallId = getStringField(toolCall, "id");
327-
JsonElement functionElement = toolCall.get("function");
328-
329-
if (functionElement != null && functionElement.isJsonObject()) {
330-
JsonObject function = functionElement.getAsJsonObject();
331-
String toolName = getStringField(function, "name");
332-
String arguments = getStringField(function, "arguments");
333-
334-
if (toolCallId != null && toolName != null) {
335-
eventCollector.startToolCall(toolName, null);
336-
337-
if (arguments != null && !arguments.isEmpty()) {
338-
eventCollector.addToolCallArgs(toolCallId, arguments);
339-
}
340-
341-
eventCollector.endToolCall(toolCallId);
342-
log
343-
.debug(
344-
"AG-UI: Generated tool call events from text response for tool={}, id={}",
345-
toolName,
346-
toolCallId
347-
);
348-
}
349-
}
350-
}
351-
}
352-
}
353-
}
354-
}
355-
} catch (Exception e) {
356-
// Not a tool call response, just regular text - no special processing needed
357-
}
358-
}
359-
360270
private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params, String llmInterface) {
361271
String aguiMessagesJson = params.get(AGUI_PARAM_MESSAGES);
362272
if (aguiMessagesJson == null || aguiMessagesJson.isEmpty()) {
@@ -415,78 +325,27 @@ private void processAGUIMessages(MLAgent mlAgent, Map<String, String> params, St
415325
if (AGUI_ROLE_ASSISTANT.equals(role) && message.has(AGUI_FIELD_TOOL_CALLS)) {
416326
toolCallMessageIndices.add(i);
417327

418-
// Convert to OpenAI format for interactions
328+
// Extract tool calls from AG-UI message (AG-UI uses OpenAI-compatible format)
419329
JsonElement toolCallsElement = message.get(AGUI_FIELD_TOOL_CALLS);
420330
if (toolCallsElement != null && toolCallsElement.isJsonArray()) {
421-
List<Map<String, Object>> toolCalls = new ArrayList<>();
422-
for (JsonElement tcElement : toolCallsElement.getAsJsonArray()) {
423-
if (tcElement.isJsonObject()) {
424-
JsonObject tc = tcElement.getAsJsonObject();
425-
Map<String, Object> toolCall = new HashMap<>();
426-
427-
// OpenAI format: id, type, and function at the same level
428-
String toolCallId = getStringField(tc, AGUI_FIELD_ID);
429-
String toolCallType = getStringField(tc, AGUI_FIELD_TYPE);
430-
431-
toolCall.put(AGUI_FIELD_ID, toolCallId);
432-
toolCall.put(AGUI_FIELD_TYPE, toolCallType != null ? toolCallType : "function");
433-
434-
JsonElement functionElement = tc.get(AGUI_FIELD_FUNCTION);
435-
if (functionElement != null && functionElement.isJsonObject()) {
436-
JsonObject func = functionElement.getAsJsonObject();
437-
Map<String, String> function = new HashMap<>();
438-
function.put(AGUI_FIELD_NAME, getStringField(func, AGUI_FIELD_NAME));
439-
function.put(AGUI_FIELD_ARGUMENTS, getStringField(func, AGUI_FIELD_ARGUMENTS));
440-
toolCall.put(AGUI_FIELD_FUNCTION, function);
441-
}
442-
toolCalls.add(toolCall);
443-
}
444-
}
331+
// Pass the JSON array directly to FunctionCalling for format conversion
332+
String toolCallsJson = gson.toJson(toolCallsElement);
445333

446-
// Create assistant message in the appropriate format based on LLM interface
334+
FunctionCalling functionCalling = FunctionCallingFactory.create(llmInterface);
447335
String assistantMessage;
448-
boolean isBedrockConverse = llmInterface != null && llmInterface.toLowerCase().contains("bedrock");
449-
450-
if (isBedrockConverse) {
451-
// Bedrock format: {"role": "assistant", "content": [{"toolUse": {...}}]}
452-
List<Map<String, Object>> contentBlocks = new ArrayList<>();
453-
for (Map<String, Object> toolCall : toolCalls) {
454-
Map<String, Object> toolUse = new HashMap<>();
455-
toolUse.put("toolUseId", toolCall.get("id"));
456-
457-
Map<String, Object> function = (Map<String, Object>) toolCall.get("function");
458-
if (function != null) {
459-
toolUse.put("name", function.get("name"));
460-
461-
// Parse arguments JSON string to object
462-
String argumentsJson = (String) function.get("arguments");
463-
try {
464-
Object argumentsObj = gson.fromJson(argumentsJson, Object.class);
465-
toolUse.put("input", argumentsObj);
466-
} catch (Exception e) {
467-
log.warn("AG-UI: Failed to parse tool arguments as JSON: {}", argumentsJson, e);
468-
toolUse.put("input", Map.of());
469-
}
470-
}
471-
472-
contentBlocks.add(Map.of("toolUse", toolUse));
473-
}
474336

475-
Map<String, Object> bedrockMsg = new HashMap<>();
476-
bedrockMsg.put(AGUI_FIELD_ROLE, AGUI_ROLE_ASSISTANT);
477-
bedrockMsg.put(AGUI_FIELD_CONTENT, contentBlocks);
478-
assistantMessage = gson.toJson(bedrockMsg);
337+
if (functionCalling != null) {
338+
// Use FunctionCalling to format the message in the correct LLM format
339+
assistantMessage = functionCalling.formatAGUIToolCalls(toolCallsJson);
340+
log.debug("AG-UI: Formatted assistant message using {}", functionCalling.getClass().getSimpleName());
479341
} else {
480-
// OpenAI format: {"role": "assistant", "tool_calls": [...]}
481-
Map<String, Object> assistantMsg = new HashMap<>();
482-
assistantMsg.put(AGUI_FIELD_ROLE, AGUI_ROLE_ASSISTANT);
483-
assistantMsg.put("tool_calls", toolCalls);
484-
assistantMessage = gson.toJson(assistantMsg);
485-
log.debug("AG-UI: Created OpenAI-format assistant message with {} tool calls", toolCalls.size());
342+
// Fallback to OpenAI format if no FunctionCalling available
343+
assistantMessage = "{\"role\":\"assistant\",\"tool_calls\":" + toolCallsJson + "}";
344+
log.debug("AG-UI: Created OpenAI-format assistant message (fallback)");
486345
}
487346

488347
assistantToolCallMessages.add(assistantMessage);
489-
log.debug("AG-UI: Extracted assistant message with {} tool calls at index {}", toolCalls.size(), i);
348+
log.debug("AG-UI: Extracted assistant message at index {}", i);
490349
log.debug("AG-UI: Assistant message JSON: {}", assistantMessage);
491350
}
492351
}

0 commit comments

Comments
 (0)