Skip to content

Commit 4eaeb94

Browse files
committed
adapt to inplace update for context
Signed-off-by: Mingshi Liu <[email protected]>
1 parent e1bc0e0 commit 4eaeb94

File tree

9 files changed

+224
-131
lines changed

9 files changed

+224
-131
lines changed

common/src/main/java/org/opensearch/ml/common/contextmanager/ContextManagerContext.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public class ContextManagerContext {
5959
* The tool interactions/results from tool executions
6060
*/
6161
@Builder.Default
62-
private List<Map<String, Object>> toolInteractions = new ArrayList<>();
62+
private List<String> toolInteractions = new ArrayList<>();
6363

6464
/**
6565
* Additional parameters for context processing
@@ -96,11 +96,8 @@ public int getEstimatedTokenCount() {
9696
}
9797

9898
// Estimate tokens for tool interactions
99-
for (Map<String, Object> interaction : toolInteractions) {
100-
Object output = interaction.get("output");
101-
if (output instanceof String) {
102-
tokenCount += estimateTokens((String) output);
103-
}
99+
for (String interaction : toolInteractions) {
100+
tokenCount += estimateTokens(interaction);
104101
}
105102

106103
return tokenCount;
@@ -133,7 +130,7 @@ private int estimateTokens(String text) {
133130
* Add a tool interaction to the context.
134131
* @param interaction the tool interaction to add
135132
*/
136-
public void addToolInteraction(Map<String, Object> interaction) {
133+
public void addToolInteraction(String interaction) {
137134
if (toolInteractions == null) {
138135
toolInteractions = new ArrayList<>();
139136
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/agents/AgentContextUtil.java

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -95,15 +95,7 @@ public static ContextManagerContext buildContextManagerContext(
9595
builder.toolConfigs(toolSpecs);
9696
}
9797

98-
List<Map<String, Object>> toolInteractions = new ArrayList<>();
99-
if (interactions != null) {
100-
for (String interaction : interactions) {
101-
Map<String, Object> toolInteraction = new HashMap<>();
102-
toolInteraction.put("output", interaction);
103-
toolInteractions.add(toolInteraction);
104-
}
105-
}
106-
builder.toolInteractions(toolInteractions);
98+
builder.toolInteractions(interactions != null ? interactions : new ArrayList<>());
10799

108100
Map<String, String> contextParameters = new HashMap<>();
109101
contextParameters.putAll(parameters);
@@ -152,10 +144,10 @@ public static ContextManagerContext emitPreLLMHook(
152144
HookRegistry hookRegistry
153145
) {
154146
ContextManagerContext context = buildContextManagerContext(parameters, interactions, toolSpecs, memory);
147+
155148
try {
156149
PreLLMEvent event = new PreLLMEvent(context, new HashMap<>());
157150
hookRegistry.emit(event);
158-
log.debug("Emitted PRE_LLM hook event and updated context");
159151
return context;
160152

161153
} catch (Exception e) {
@@ -177,16 +169,7 @@ public static void updateParametersFromContext(Map<String, String> parameters, C
177169
}
178170

179171
if (context.getToolInteractions() != null && !context.getToolInteractions().isEmpty()) {
180-
List<String> updatedInteractions = new ArrayList<>();
181-
for (Map<String, Object> toolInteraction : context.getToolInteractions()) {
182-
Object output = toolInteraction.get("output");
183-
if (output instanceof String) {
184-
updatedInteractions.add((String) output);
185-
}
186-
}
187-
if (!updatedInteractions.isEmpty()) {
188-
parameters.put(INTERACTIONS, ", " + String.join(", ", updatedInteractions));
189-
}
172+
parameters.put(INTERACTIONS, ", " + String.join(", ", context.getToolInteractions()));
190173
}
191174

192175
if (context.getParameters() != null) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,13 @@ private void saveRootInteractionAndExecute(
477477
*/
478478
private void processContextManagement(MLAgent mlAgent, HookRegistry hookRegistry, RemoteInferenceInputDataSet inputDataSet) {
479479
try {
480+
// Check if context_management is already specified in runtime parameters
481+
String runtimeContextManagement = inputDataSet.getParameters().get("context_management");
482+
if (runtimeContextManagement != null && !runtimeContextManagement.trim().isEmpty()) {
483+
log.info("Using runtime context management parameter: {}", runtimeContextManagement);
484+
return; // Runtime parameter takes precedence, let MLExecuteTaskRunner handle it
485+
}
486+
480487
ContextManagementTemplate template = null;
481488
String templateName = null;
482489

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ private void runReAct(
340340
StepListener<?> lastStepListener = firstListener;
341341

342342
StringBuilder scratchpadBuilder = new StringBuilder();
343-
List<String> interactions = new CopyOnWriteArrayList<>();
343+
final List<String> interactions = new CopyOnWriteArrayList<>();
344344

345345
StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}");
346346
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
@@ -548,9 +548,18 @@ private void runReAct(
548548
ContextManagerContext contextAfterEvent = AgentContextUtil
549549
.emitPreLLMHook(tmpParameters, interactions, currentToolSpecs, memory, hookRegistry);
550550

551-
if (tmpParameters.get(INTERACTIONS) != null || tmpParameters.get(INTERACTIONS) != "") {
552-
tmpParameters.put(INTERACTIONS, StringUtils.toJson(contextAfterEvent.getParameters().get(INTERACTIONS)));
551+
// Check if context managers actually modified the interactions
552+
List<String> updatedInteractions = contextAfterEvent.getToolInteractions();
553553

554+
if (updatedInteractions != null && !updatedInteractions.equals(interactions)) {
555+
interactions.clear();
556+
interactions.addAll(updatedInteractions);
557+
558+
// Update parameters if context manager set INTERACTIONS
559+
String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS);
560+
if (contextInteractions != null && !contextInteractions.isEmpty()) {
561+
tmpParameters.put(INTERACTIONS, contextInteractions);
562+
}
554563
}
555564
}
556565
ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId);
@@ -572,8 +581,17 @@ private void runReAct(
572581
ContextManagerContext contextAfterEvent = AgentContextUtil
573582
.emitPreLLMHook(tmpParameters, interactions, initialToolSpecs, memory, hookRegistry);
574583

575-
if (tmpParameters.get(INTERACTIONS) != null || tmpParameters.get(INTERACTIONS) != "") {
576-
tmpParameters.put(INTERACTIONS, StringUtils.toJson(contextAfterEvent.getParameters().get(INTERACTIONS)));
584+
// Check if context managers actually modified the interactions
585+
List<String> updatedInteractions = contextAfterEvent.getToolInteractions();
586+
if (updatedInteractions != null && !updatedInteractions.equals(interactions)) {
587+
interactions.clear();
588+
interactions.addAll(updatedInteractions);
589+
590+
// Update parameters if context manager set INTERACTIONS
591+
String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS);
592+
if (contextInteractions != null && !contextInteractions.isEmpty()) {
593+
tmpParameters.put(INTERACTIONS, contextInteractions);
594+
}
577595
}
578596
}
579597
ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import org.opensearch.ml.common.agent.LLMSpec;
5454
import org.opensearch.ml.common.agent.MLAgent;
5555
import org.opensearch.ml.common.agent.MLToolSpec;
56+
import org.opensearch.ml.common.contextmanager.ContextManagerContext;
5657
import org.opensearch.ml.common.conversation.Interaction;
5758
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5859
import org.opensearch.ml.common.exception.MLException;
@@ -297,7 +298,7 @@ public void run(MLAgent mlAgent, Map<String, String> apiParams, ActionListener<O
297298
conversationIndexMemoryFactory
298299
.create(allParams.get(USER_PROMPT_FIELD), memoryId, appType, ActionListener.<ConversationIndexMemory>wrap(memory -> {
299300
memory.getMessages(ActionListener.<List<Interaction>>wrap(interactions -> {
300-
List<String> completedSteps = new ArrayList<>();
301+
final List<String> completedSteps = new ArrayList<>();
301302
for (Interaction interaction : interactions) {
302303
String question = interaction.getInput();
303304
String response = interaction.getResponse();
@@ -398,14 +399,26 @@ private void executePlanningLoop(
398399

399400
allParams.put("_llm_model_id", llm.getModelId());
400401
if (hookRegistry != null && !completedSteps.isEmpty()) {
401-
allParams.put(INTERACTIONS, ", " + String.join(", ", completedSteps));
402+
402403
Map<String, String> requestParams = new HashMap<>(allParams);
404+
requestParams.put(INTERACTIONS, ", " + String.join(", ", completedSteps));
403405
try {
404-
AgentContextUtil.emitPreLLMHook(requestParams, completedSteps, null, memory, hookRegistry);
405-
406-
if (requestParams.get(INTERACTIONS) != null || requestParams.get(INTERACTIONS) != "") {
407-
allParams.put(COMPLETED_STEPS_FIELD, StringUtils.toJson(requestParams.get(INTERACTIONS)));
408-
allParams.put(INTERACTIONS, "");
406+
ContextManagerContext contextAfterEvent = AgentContextUtil
407+
.emitPreLLMHook(requestParams, completedSteps, null, memory, hookRegistry);
408+
409+
// Check if context managers actually modified the interactions
410+
List<String> updatedSteps = contextAfterEvent.getToolInteractions();
411+
if (updatedSteps != null && !updatedSteps.equals(completedSteps)) {
412+
completedSteps.clear();
413+
completedSteps.addAll(updatedSteps);
414+
415+
// Update parameters if context manager set INTERACTIONS
416+
String contextInteractions = contextAfterEvent.getParameters().get(INTERACTIONS);
417+
if (contextInteractions != null && !contextInteractions.isEmpty()) {
418+
allParams.put(COMPLETED_STEPS_FIELD, contextInteractions);
419+
// TODO should I always clear interactions after update the completed steps?
420+
allParams.put(INTERACTIONS, "");
421+
}
409422
}
410423
} catch (Exception e) {
411424
log.error("Failed to emit pre-LLM hook", e);

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/contextmanager/SlidingWindowManager.java

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,13 @@ public boolean shouldActivate(ContextManagerContext context) {
7878

7979
@Override
8080
public void execute(ContextManagerContext context) {
81-
List<Map<String, Object>> toolInteractions = context.getToolInteractions();
81+
List<String> interactions = context.getToolInteractions();
8282

83-
if (toolInteractions == null || toolInteractions.isEmpty()) {
83+
if (interactions == null || interactions.isEmpty()) {
8484
log.debug("No tool interactions to process");
8585
return;
8686
}
8787

88-
// Extract interactions from tool interactions
89-
List<String> interactions = new ArrayList<>();
90-
for (Map<String, Object> toolInteraction : toolInteractions) {
91-
Object output = toolInteraction.get("output");
92-
if (output instanceof String) {
93-
interactions.add((String) output);
94-
}
95-
}
96-
9788
if (interactions.isEmpty()) {
9889
log.debug("No string interactions found in tool interactions");
9990
return;
@@ -106,14 +97,14 @@ public void execute(ContextManagerContext context) {
10697
return;
10798
}
10899

109-
// Keep the most recent interactions
110-
List<String> updatedInteractions = new ArrayList<>(interactions.subList(originalSize - maxMessages, originalSize));
100+
// Find safe start point to avoid breaking tool pairs
101+
int startIndex = findSafeStartPoint(interactions, originalSize - maxMessages);
102+
103+
// Keep the most recent interactions from safe start point
104+
List<String> updatedInteractions = new ArrayList<>(interactions.subList(startIndex, originalSize));
111105

112106
// Update toolInteractions in context to keep only the most recent ones
113-
List<Map<String, Object>> updatedToolInteractions = new ArrayList<>(
114-
toolInteractions.subList(originalSize - maxMessages, originalSize)
115-
);
116-
context.setToolInteractions(updatedToolInteractions);
107+
context.setToolInteractions(updatedInteractions);
117108

118109
// Update the _interactions parameter with smaller size of updated interactions
119110
Map<String, String> parameters = context.getParameters();
@@ -123,8 +114,13 @@ public void execute(ContextManagerContext context) {
123114
}
124115
parameters.put("_interactions", ", " + String.join(", ", updatedInteractions));
125116

126-
int removedMessages = originalSize - maxMessages;
127-
log.info("Applied sliding window: kept {} most recent interactions, removed {} older interactions", maxMessages, removedMessages);
117+
int removedMessages = originalSize - updatedInteractions.size();
118+
log
119+
.info(
120+
"Applied sliding window: kept {} most recent interactions, removed {} older interactions",
121+
updatedInteractions.size(),
122+
removedMessages
123+
);
128124
}
129125

130126
private int parseIntegerConfig(Map<String, Object> config, String key, int defaultValue) {
@@ -149,4 +145,47 @@ private int parseIntegerConfig(Map<String, Object> config, String key, int defau
149145
return defaultValue;
150146
}
151147
}
148+
149+
/**
150+
* Find a safe start point that doesn't break assistant-tool message pairs
151+
* Same logic as SummarizationManager but for finding start point
152+
*/
153+
private int findSafeStartPoint(List<String> interactions, int targetStartPoint) {
154+
if (targetStartPoint <= 0) {
155+
return 0;
156+
}
157+
if (targetStartPoint >= interactions.size()) {
158+
return interactions.size();
159+
}
160+
161+
int startPoint = targetStartPoint;
162+
163+
while (startPoint < interactions.size()) {
164+
try {
165+
String messageAtStart = interactions.get(startPoint);
166+
167+
// Oldest message cannot be a toolResult because it needs a toolUse preceding it
168+
boolean hasToolResult = messageAtStart.contains("toolResult");
169+
170+
// Oldest message can be a toolUse only if a toolResult immediately follows it
171+
boolean hasToolUse = messageAtStart.contains("toolUse");
172+
boolean nextHasToolResult = false;
173+
if (startPoint + 1 < interactions.size()) {
174+
nextHasToolResult = interactions.get(startPoint + 1).contains("toolResult");
175+
}
176+
177+
if (hasToolResult || (hasToolUse && startPoint + 1 < interactions.size() && !nextHasToolResult)) {
178+
startPoint++;
179+
} else {
180+
break;
181+
}
182+
183+
} catch (Exception e) {
184+
log.warn("Error checking message at index {}: {}", startPoint, e.getMessage());
185+
startPoint++;
186+
}
187+
}
188+
189+
return startPoint;
190+
}
152191
}

0 commit comments

Comments
 (0)