Skip to content

Commit 7b6cb32

Browse files
committed
refactor tool use formatting
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 9a7f887 commit 7b6cb32

File tree

16 files changed

+338
-937
lines changed

16 files changed

+338
-937
lines changed

common/src/main/java/org/opensearch/ml/common/agui/AGUIInputConverter.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOLS;
2525
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_PARAM_TOOL_CALL_RESULTS;
2626
import static org.opensearch.ml.common.agui.AGUIConstants.AGUI_ROLE_USER;
27+
import static org.opensearch.ml.common.utils.StringUtils.getStringField;
2728

2829
import java.util.HashMap;
2930
import java.util.List;
@@ -124,11 +125,6 @@ public static AgentMLInput convertFromAGUIInput(String aguiInputJson, String age
124125
}
125126
}
126127

127-
private static String getStringField(JsonObject obj, String fieldName) {
128-
JsonElement element = obj.get(fieldName);
129-
return element != null && !element.isJsonNull() ? element.getAsString() : null;
130-
}
131-
132128
private static void extractUserQuestion(JsonElement messages, Map<String, String> parameters) {
133129
if (messages == null || !messages.isJsonArray()) {
134130
throw new IllegalArgumentException("Invalid AG-UI messages");

common/src/main/java/org/opensearch/ml/common/output/AGUIOutput.java

Lines changed: 0 additions & 72 deletions
This file was deleted.

common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskResponse.java

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import java.io.ByteArrayOutputStream;
1010
import java.io.IOException;
1111
import java.io.UncheckedIOException;
12-
import java.util.ArrayList;
13-
import java.util.List;
1412

1513
import org.opensearch.core.action.ActionResponse;
1614
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
@@ -21,12 +19,7 @@
2119
import org.opensearch.core.xcontent.XContentBuilder;
2220
import org.opensearch.ml.common.FunctionName;
2321
import org.opensearch.ml.common.MLCommonsClassLoader;
24-
import org.opensearch.ml.common.output.AGUIOutput;
2522
import org.opensearch.ml.common.output.Output;
26-
import org.opensearch.ml.common.output.model.ModelTensorOutput;
27-
import org.opensearch.ml.common.output.model.ModelTensors;
28-
29-
import com.google.gson.Gson;
3023

3124
import lombok.Builder;
3225
import lombok.Getter;
@@ -82,34 +75,6 @@ public static MLExecuteTaskResponse fromActionResponse(ActionResponse actionResp
8275

8376
@Override
8477
public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException {
85-
if (functionName == FunctionName.AGENT && output instanceof ModelTensorOutput) {
86-
ModelTensorOutput modelOutput = (ModelTensorOutput) output;
87-
if (isAGUIOutput(modelOutput)) {
88-
return extractAGUIOutput(modelOutput).toXContent(builder, params);
89-
}
90-
}
9178
return output.toXContent(builder, params);
9279
}
93-
94-
private boolean isAGUIOutput(ModelTensorOutput modelOutput) {
95-
if (modelOutput.getMlModelOutputs() != null && modelOutput.getMlModelOutputs().size() == 1) {
96-
ModelTensors modelTensors = modelOutput.getMlModelOutputs().get(0);
97-
if (modelTensors.getMlModelTensors() != null && modelTensors.getMlModelTensors().size() == 1) {
98-
return "ag_ui_events".equals(modelTensors.getMlModelTensors().get(0).getName());
99-
}
100-
}
101-
return false;
102-
}
103-
104-
private AGUIOutput extractAGUIOutput(ModelTensorOutput modelOutput) {
105-
try {
106-
ModelTensors modelTensors = modelOutput.getMlModelOutputs().get(0);
107-
String eventsJson = modelTensors.getMlModelTensors().get(0).getResult();
108-
Gson gson = new Gson();
109-
List<Object> events = gson.fromJson(eventsJson, List.class);
110-
return AGUIOutput.builder().events(events).build();
111-
} catch (Exception e) {
112-
return AGUIOutput.builder().events(new ArrayList<>()).build();
113-
}
114-
}
11580
}

common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,4 +726,9 @@ public Float read(JsonReader in) throws IOException {
726726
return f;
727727
}
728728
}
729+
730+
public static String getStringField(JsonObject obj, String fieldName) {
731+
JsonElement element = obj.get(fieldName);
732+
return element != null && !element.isJsonNull() ? element.getAsString() : null;
733+
}
729734
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AGUIEventCollector.java

Lines changed: 0 additions & 124 deletions
This file was deleted.

0 commit comments

Comments
 (0)