|
17 | 17 | import static org.opensearch.ml.utils.RestActionUtils.isAsync; |
18 | 18 | import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; |
19 | 19 |
|
| 20 | +import java.io.ByteArrayOutputStream; |
20 | 21 | import java.io.IOException; |
21 | | -import java.io.UncheckedIOException; |
22 | 22 | import java.nio.ByteBuffer; |
23 | 23 | import java.util.LinkedHashMap; |
24 | 24 | import java.util.List; |
|
48 | 48 | import org.opensearch.ml.common.MLModel; |
49 | 49 | import org.opensearch.ml.common.agent.MLAgent; |
50 | 50 | import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; |
51 | | -import org.opensearch.ml.common.exception.MLException; |
52 | 51 | import org.opensearch.ml.common.input.Input; |
53 | 52 | import org.opensearch.ml.common.input.MLInput; |
54 | 53 | import org.opensearch.ml.common.input.execute.agent.AgentMLInput; |
@@ -158,10 +157,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client |
158 | 157 | ); |
159 | 158 | channel.prepareResponse(RestStatus.OK, headers); |
160 | 159 |
|
161 | | - Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> { |
162 | | - final CompletableFuture<HttpChunk> future = new CompletableFuture<>(); |
| 160 | + Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> { |
163 | 161 | try { |
164 | | - MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content()); |
| 162 | + BytesReference completeContent = combineChunks(chunks); |
| 163 | + MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent); |
| 164 | + |
| 165 | + final CompletableFuture<HttpChunk> future = new CompletableFuture<>(); |
165 | 166 | StreamTransportResponseHandler<MLTaskResponse> handler = new StreamTransportResponseHandler<MLTaskResponse>() { |
166 | 167 | @Override |
167 | 168 | public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) { |
@@ -214,19 +215,23 @@ public MLTaskResponse read(StreamInput in) throws IOException { |
214 | 215 | handler |
215 | 216 | ); |
216 | 217 |
|
217 | | - } catch (IOException e) { |
218 | | - throw new MLException("Got an exception in flux.", e); |
| 218 | + return Mono.fromCompletionStage(future); |
| 219 | + } catch (Exception e) { |
| 220 | + log.error("Failed to parse or process request", e); |
| 221 | + return Mono.error(e); |
219 | 222 | } |
220 | | - |
221 | | - return Mono.fromCompletionStage(future); |
222 | | - }).doOnNext(channel::sendChunk).onErrorComplete(ex -> { |
223 | | - // Error handling |
| 223 | + }).doOnNext(channel::sendChunk).onErrorResume(ex -> { |
| 224 | + log.error("Error occurred", ex); |
224 | 225 | try { |
225 | | - channel.sendResponse(new BytesRestResponse(channel, (Exception) ex)); |
226 | | - return true; |
227 | | - } catch (final IOException e) { |
228 | | - throw new UncheckedIOException(e); |
| 226 | + String errorMessage = ex instanceof IOException |
| 227 | + ? "Failed to parse request: " + ex.getMessage() |
| 228 | + : "Error processing request: " + ex.getMessage(); |
| 229 | + HttpChunk errorChunk = createHttpChunk("data: {\"error\": \"" + errorMessage.replace("\"", "\\\"") + "\"}\n\n", true); |
| 230 | + channel.sendChunk(errorChunk); |
| 231 | + } catch (Exception e) { |
| 232 | + log.error("Failed to send error chunk", e); |
229 | 233 | } |
| 234 | + return Mono.empty(); |
230 | 235 | }).subscribe(); |
231 | 236 | }; |
232 | 237 |
|
@@ -402,6 +407,20 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) { |
402 | 407 | return Map.of(); |
403 | 408 | } |
404 | 409 |
|
| 410 | + @VisibleForTesting |
| 411 | + BytesReference combineChunks(List<HttpChunk> chunks) { |
| 412 | + try { |
| 413 | + ByteArrayOutputStream buffer = new ByteArrayOutputStream(); |
| 414 | + for (HttpChunk chunk : chunks) { |
| 415 | + chunk.content().writeTo(buffer); |
| 416 | + } |
| 417 | + return BytesReference.fromByteBuffer(ByteBuffer.wrap(buffer.toByteArray())); |
| 418 | + } catch (IOException e) { |
| 419 | + log.error("Failed to combine chunks", e); |
| 420 | + throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e); |
| 421 | + } |
| 422 | + } |
| 423 | + |
405 | 424 | private HttpChunk createHttpChunk(String sseData, boolean isLast) { |
406 | 425 | BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes())); |
407 | 426 | return new HttpChunk() { |
|
0 commit comments