55
66package org .opensearch .ml .common .agui ;
77
8+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_CONTENT ;
9+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_CONTEXT ;
10+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_FORWARDED_PROPS ;
11+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_MESSAGES ;
12+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_ROLE ;
13+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_RUN_ID ;
14+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_STATE ;
15+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_THREAD_ID ;
16+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_TOOLS ;
17+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_FIELD_TOOL_CALL_ID ;
18+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_CONTEXT ;
19+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_FORWARDED_PROPS ;
20+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_MESSAGES ;
21+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_RUN_ID ;
22+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_STATE ;
23+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_THREAD_ID ;
24+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_TOOLS ;
25+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_PARAM_TOOL_CALL_RESULTS ;
26+ import static org .opensearch .ml .common .agui .AGUIConstants .AGUI_ROLE_USER ;
27+
828import java .util .HashMap ;
929import java .util .List ;
1030import java .util .Map ;
@@ -27,7 +47,28 @@ public class AGUIInputConverter {
2747 public static boolean isAGUIInput (String inputJson ) {
2848 try {
2949 JsonObject jsonObj = JsonParser .parseString (inputJson ).getAsJsonObject ();
30- return jsonObj .has ("threadId" ) && jsonObj .has ("runId" ) && jsonObj .has ("messages" ) && jsonObj .has ("tools" );
50+
51+ // Check required fields exist
52+ if (!jsonObj .has (AGUI_FIELD_THREAD_ID )
53+ || !jsonObj .has (AGUI_FIELD_RUN_ID )
54+ || !jsonObj .has (AGUI_FIELD_MESSAGES )
55+ || !jsonObj .has (AGUI_FIELD_TOOLS )) {
56+ return false ;
57+ }
58+
59+ // Validate messages is an array
60+ JsonElement messages = jsonObj .get (AGUI_FIELD_MESSAGES );
61+ if (!messages .isJsonArray ()) {
62+ return false ;
63+ }
64+
65+ // Validate tools is an array
66+ JsonElement tools = jsonObj .get (AGUI_FIELD_TOOLS );
67+ if (!tools .isJsonArray ()) {
68+ return false ;
69+ }
70+
71+ return true ;
3172 } catch (Exception e ) {
3273 log .debug ("Failed to parse input as JSON for AG-UI detection" , e );
3374 return false ;
@@ -38,37 +79,37 @@ public static AgentMLInput convertFromAGUIInput(String aguiInputJson, String age
3879 try {
3980 JsonObject aguiInput = JsonParser .parseString (aguiInputJson ).getAsJsonObject ();
4081
41- String threadId = getStringField (aguiInput , "threadId" );
42- String runId = getStringField (aguiInput , "runId" );
43- JsonElement state = aguiInput .get ("state" );
44- JsonElement messages = aguiInput .get ("messages" );
45- JsonElement tools = aguiInput .get ("tools" );
46- JsonElement context = aguiInput .get ("context" );
47- JsonElement forwardedProps = aguiInput .get ("forwardedProps" );
82+ String threadId = getStringField (aguiInput , AGUI_FIELD_THREAD_ID );
83+ String runId = getStringField (aguiInput , AGUI_FIELD_RUN_ID );
84+ JsonElement state = aguiInput .get (AGUI_FIELD_STATE );
85+ JsonElement messages = aguiInput .get (AGUI_FIELD_MESSAGES );
86+ JsonElement tools = aguiInput .get (AGUI_FIELD_TOOLS );
87+ JsonElement context = aguiInput .get (AGUI_FIELD_CONTEXT );
88+ JsonElement forwardedProps = aguiInput .get (AGUI_FIELD_FORWARDED_PROPS );
4889
4990 Map <String , String > parameters = new HashMap <>();
50- parameters .put ("agui_thread_id" , threadId );
51- parameters .put ("agui_run_id" , runId );
91+ parameters .put (AGUI_PARAM_THREAD_ID , threadId );
92+ parameters .put (AGUI_PARAM_RUN_ID , runId );
5293
5394 if (state != null ) {
54- parameters .put ("agui_state" , gson .toJson (state ));
95+ parameters .put (AGUI_PARAM_STATE , gson .toJson (state ));
5596 }
5697
5798 if (messages != null ) {
58- parameters .put ("agui_messages" , gson .toJson (messages ));
99+ parameters .put (AGUI_PARAM_MESSAGES , gson .toJson (messages ));
59100 extractUserQuestion (messages , parameters );
60101 }
61102
62103 if (tools != null ) {
63- parameters .put ("agui_tools" , gson .toJson (tools ));
104+ parameters .put (AGUI_PARAM_TOOLS , gson .toJson (tools ));
64105 }
65106
66107 if (context != null ) {
67- parameters .put ("agui_context" , gson .toJson (context ));
108+ parameters .put (AGUI_PARAM_CONTEXT , gson .toJson (context ));
68109 }
69110
70111 if (forwardedProps != null ) {
71- parameters .put ("agui_forwarded_props" , gson .toJson (forwardedProps ));
112+ parameters .put (AGUI_PARAM_FORWARDED_PROPS , gson .toJson (forwardedProps ));
72113 }
73114 RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet .builder ().parameters (parameters ).build ();
74115 AgentMLInput agentMLInput = new AgentMLInput (
@@ -93,40 +134,6 @@ private static String getStringField(JsonObject obj, String fieldName) {
93134 return element != null && !element .isJsonNull () ? element .getAsString () : null ;
94135 }
95136
96- public static JsonObject reconstructAGUIInput (Map <String , String > parameters ) {
97- JsonObject aguiInput = new JsonObject ();
98-
99- try {
100- String threadId = parameters .get ("agui_thread_id" );
101- String runId = parameters .get ("agui_run_id" );
102- String stateJson = parameters .get ("agui_state" );
103- String messagesJson = parameters .get ("agui_messages" );
104- String toolsJson = parameters .get ("agui_tools" );
105- String contextJson = parameters .get ("agui_context" );
106- String forwardedPropsJson = parameters .get ("agui_forwarded_props" );
107-
108- if (threadId != null )
109- aguiInput .addProperty ("threadId" , threadId );
110- if (runId != null )
111- aguiInput .addProperty ("runId" , runId );
112- if (stateJson != null )
113- aguiInput .add ("state" , JsonParser .parseString (stateJson ));
114- if (messagesJson != null )
115- aguiInput .add ("messages" , JsonParser .parseString (messagesJson ));
116- if (toolsJson != null )
117- aguiInput .add ("tools" , JsonParser .parseString (toolsJson ));
118- if (contextJson != null )
119- aguiInput .add ("context" , JsonParser .parseString (contextJson ));
120- if (forwardedPropsJson != null )
121- aguiInput .add ("forwardedProps" , JsonParser .parseString (forwardedPropsJson ));
122-
123- } catch (Exception e ) {
124- log .error ("Failed to reconstruct AG-UI input from parameters" , e );
125- }
126-
127- return aguiInput ;
128- }
129-
130137 private static void extractUserQuestion (JsonElement messages , Map <String , String > parameters ) {
131138 if (messages == null || !messages .isJsonArray ()) {
132139 throw new IllegalArgumentException ("Invalid AG-UI messages" );
@@ -140,12 +147,12 @@ private static void extractUserQuestion(JsonElement messages, Map<String, String
140147 for (JsonElement messageElement : messages .getAsJsonArray ()) {
141148 if (messageElement .isJsonObject ()) {
142149 JsonObject message = messageElement .getAsJsonObject ();
143- JsonElement roleElement = message .get ("role" );
144- JsonElement contentElement = message .get ("content" );
145- JsonElement toolCallIdElement = message .get ("toolCallId" );
150+ JsonElement roleElement = message .get (AGUI_FIELD_ROLE );
151+ JsonElement contentElement = message .get (AGUI_FIELD_CONTENT );
152+ JsonElement toolCallIdElement = message .get (AGUI_FIELD_TOOL_CALL_ID );
146153
147154 if (roleElement != null
148- && "user" .equals (roleElement .getAsString ())
155+ && AGUI_ROLE_USER .equals (roleElement .getAsString ())
149156 && contentElement != null
150157 && !contentElement .isJsonNull ()) {
151158
@@ -173,7 +180,7 @@ private static void extractUserQuestion(JsonElement messages, Map<String, String
173180
174181 // Set appropriate parameters based on what was found
175182 if (toolCallResults != null ) {
176- parameters .put ("agui_tool_call_results" , toolCallResults );
183+ parameters .put (AGUI_PARAM_TOOL_CALL_RESULTS , toolCallResults );
177184 log .debug ("Detected AG-UI tool call results: {}" , toolCallResults );
178185 } else if (lastUserMessage != null ) {
179186 parameters .put ("question" , lastUserMessage );
0 commit comments