Skip to content

Commit 305cb8e

Browse files
authored
combine json chunks from requests (#4317)
1 parent f8b403b commit 305cb8e

File tree

2 files changed

+103
-15
lines changed

2 files changed

+103
-15
lines changed

plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import static org.opensearch.ml.utils.RestActionUtils.isAsync;
1818
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;
1919

20+
import java.io.ByteArrayOutputStream;
2021
import java.io.IOException;
21-
import java.io.UncheckedIOException;
2222
import java.nio.ByteBuffer;
2323
import java.util.LinkedHashMap;
2424
import java.util.List;
@@ -48,7 +48,6 @@
4848
import org.opensearch.ml.common.MLModel;
4949
import org.opensearch.ml.common.agent.MLAgent;
5050
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
51-
import org.opensearch.ml.common.exception.MLException;
5251
import org.opensearch.ml.common.input.Input;
5352
import org.opensearch.ml.common.input.MLInput;
5453
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
@@ -158,10 +157,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
158157
);
159158
channel.prepareResponse(RestStatus.OK, headers);
160159

161-
Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> {
162-
final CompletableFuture<HttpChunk> future = new CompletableFuture<>();
160+
Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> {
163161
try {
164-
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content());
162+
BytesReference completeContent = combineChunks(chunks);
163+
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent);
164+
165+
final CompletableFuture<HttpChunk> future = new CompletableFuture<>();
165166
StreamTransportResponseHandler<MLTaskResponse> handler = new StreamTransportResponseHandler<MLTaskResponse>() {
166167
@Override
167168
public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) {
@@ -214,19 +215,23 @@ public MLTaskResponse read(StreamInput in) throws IOException {
214215
handler
215216
);
216217

217-
} catch (IOException e) {
218-
throw new MLException("Got an exception in flux.", e);
218+
return Mono.fromCompletionStage(future);
219+
} catch (Exception e) {
220+
log.error("Failed to parse or process request", e);
221+
return Mono.error(e);
219222
}
220-
221-
return Mono.fromCompletionStage(future);
222-
}).doOnNext(channel::sendChunk).onErrorComplete(ex -> {
223-
// Error handling
223+
}).doOnNext(channel::sendChunk).onErrorResume(ex -> {
224+
log.error("Error occurred", ex);
224225
try {
225-
channel.sendResponse(new BytesRestResponse(channel, (Exception) ex));
226-
return true;
227-
} catch (final IOException e) {
228-
throw new UncheckedIOException(e);
226+
String errorMessage = ex instanceof IOException
227+
? "Failed to parse request: " + ex.getMessage()
228+
: "Error processing request: " + ex.getMessage();
229+
HttpChunk errorChunk = createHttpChunk("data: {\"error\": \"" + errorMessage.replace("\"", "\\\"") + "\"}\n\n", true);
230+
channel.sendChunk(errorChunk);
231+
} catch (Exception e) {
232+
log.error("Failed to send error chunk", e);
229233
}
234+
return Mono.empty();
230235
}).subscribe();
231236
};
232237

@@ -402,6 +407,20 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) {
402407
return Map.of();
403408
}
404409

410+
@VisibleForTesting
411+
BytesReference combineChunks(List<HttpChunk> chunks) {
412+
try {
413+
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
414+
for (HttpChunk chunk : chunks) {
415+
chunk.content().writeTo(buffer);
416+
}
417+
return BytesReference.fromByteBuffer(ByteBuffer.wrap(buffer.toByteArray()));
418+
} catch (IOException e) {
419+
log.error("Failed to combine chunks", e);
420+
throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e);
421+
}
422+
}
423+
405424
private HttpChunk createHttpChunk(String sseData, boolean isLast) {
406425
BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes()));
407426
return new HttpChunk() {

plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
import org.opensearch.common.xcontent.XContentType;
3434
import org.opensearch.core.action.ActionListener;
3535
import org.opensearch.core.common.bytes.BytesArray;
36+
import org.opensearch.core.common.bytes.BytesReference;
3637
import org.opensearch.core.rest.RestStatus;
38+
import org.opensearch.http.HttpChunk;
3739
import org.opensearch.ml.common.FunctionName;
3840
import org.opensearch.ml.common.MLModel;
3941
import org.opensearch.ml.common.agent.LLMSpec;
@@ -302,4 +304,71 @@ public void testGetRequestAgentFrameworkDisabled() {
302304
when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false);
303305
assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client));
304306
}
307+
308+
@Test
309+
public void testCombineChunksWithSingleChunk() {
310+
String testContent = "{\"parameters\":{\"question\":\"test\"}}";
311+
BytesArray bytesArray = new BytesArray(testContent);
312+
313+
HttpChunk mockChunk = mock(HttpChunk.class);
314+
when(mockChunk.content()).thenReturn(bytesArray);
315+
316+
BytesReference result = restAction.combineChunks(List.of(mockChunk));
317+
318+
assertNotNull(result);
319+
assertEquals(testContent, result.utf8ToString());
320+
}
321+
322+
@Test
323+
public void testCombineChunksWithMultipleChunks() {
324+
String chunk1Content = "{\"parameters\":";
325+
String chunk2Content = "{\"question\":";
326+
String chunk3Content = "\"test\"}}";
327+
328+
BytesArray bytes1 = new BytesArray(chunk1Content);
329+
BytesArray bytes2 = new BytesArray(chunk2Content);
330+
BytesArray bytes3 = new BytesArray(chunk3Content);
331+
332+
HttpChunk mockChunk1 = mock(HttpChunk.class);
333+
HttpChunk mockChunk2 = mock(HttpChunk.class);
334+
HttpChunk mockChunk3 = mock(HttpChunk.class);
335+
336+
when(mockChunk1.content()).thenReturn(bytes1);
337+
when(mockChunk2.content()).thenReturn(bytes2);
338+
when(mockChunk3.content()).thenReturn(bytes3);
339+
340+
BytesReference result = restAction.combineChunks(List.of(mockChunk1, mockChunk2, mockChunk3));
341+
342+
assertNotNull(result);
343+
String expectedContent = chunk1Content + chunk2Content + chunk3Content;
344+
assertEquals(expectedContent, result.utf8ToString());
345+
}
346+
347+
@Test
348+
public void testCombineChunksWithEmptyList() {
349+
BytesReference result = restAction.combineChunks(List.of());
350+
351+
assertNotNull(result);
352+
assertEquals(0, result.length());
353+
}
354+
355+
@Test
356+
public void testCombineChunksWithLargeContent() {
357+
StringBuilder largeContent = new StringBuilder();
358+
for (int i = 0; i < 1000; i++) {
359+
largeContent.append("chunk").append(i).append(",");
360+
}
361+
String content = largeContent.toString();
362+
363+
BytesArray bytesArray = new BytesArray(content);
364+
365+
HttpChunk mockChunk = mock(HttpChunk.class);
366+
when(mockChunk.content()).thenReturn(bytesArray);
367+
368+
BytesReference result = restAction.combineChunks(List.of(mockChunk));
369+
370+
assertNotNull(result);
371+
assertEquals(content.length(), result.length());
372+
assertEquals(content, result.utf8ToString());
373+
}
305374
}

0 commit comments

Comments
 (0)