diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c7bc3322e7f9..2b240ea0dec59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix SearchPhaseExecutionException to properly initCause ([#20320](https://github.com/opensearch-project/OpenSearch/pull/20320)) - Fix `cluster.remote..server_name` setting no populating SNI ([#20321](https://github.com/opensearch-project/OpenSearch/pull/20321)) - Fix X-Opaque-Id header propagation (along with other response headers) for streaming Reactor Netty 4 transport ([#20371](https://github.com/opensearch-project/OpenSearch/pull/20371)) +- Fix Thread context is not well restored for streaming ([#20361](https://github.com/opensearch-project/OpenSearch/pull/20361)) ### Dependencies - Bump `com.google.auth:google-auth-library-oauth2-http` from 1.38.0 to 1.41.0 ([#20183](https://github.com/opensearch-project/OpenSearch/pull/20183)) diff --git a/plugins/transport-reactor-netty4/src/internalClusterTest/java/org/opensearch/http/reactor/netty4/ReactorNetty4ThreadContextRestorationIT.java b/plugins/transport-reactor-netty4/src/internalClusterTest/java/org/opensearch/http/reactor/netty4/ReactorNetty4ThreadContextRestorationIT.java new file mode 100644 index 0000000000000..a47e5f05cd88d --- /dev/null +++ b/plugins/transport-reactor-netty4/src/internalClusterTest/java/org/opensearch/http/reactor/netty4/ReactorNetty4ThreadContextRestorationIT.java @@ -0,0 +1,241 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.http.reactor.netty4; + +import io.netty.util.NettyRuntime; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchReactorNetty4IntegTestCase; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.common.collect.Tuple; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.FeatureFlagSettings; +import org.opensearch.common.settings.IndexScopedSettings; +import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.settings.SettingsFilter; +import org.opensearch.common.util.FeatureFlags; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.transport.TransportAddress; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.http.HttpChunk; +import org.opensearch.http.HttpServerTransport; +import org.opensearch.plugins.ActionPlugin; +import org.opensearch.plugins.Plugin; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestController; +import org.opensearch.rest.RestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.StreamingRestChannel; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchIntegTestCase; +import org.opensearch.transport.Netty4ModulePlugin; +import org.opensearch.transport.client.node.NodeClient; +import org.opensearch.transport.reactor.ReactorNetty4Plugin; +import org.junit.Assert; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; + +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCounted; +import org.opensearch.transport.reactor.netty4.Netty4Utils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import static org.opensearch.rest.RestRequest.Method.POST; + +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, supportsDedicatedMasters = false, numDataNodes = 1) +public class ReactorNetty4ThreadContextRestorationIT extends OpenSearchReactorNetty4IntegTestCase { + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + protected Settings featureFlagSettings() { + Settings.Builder featureSettings = Settings.builder(); + for (Setting builtInFlag : FeatureFlagSettings.BUILT_IN_FEATURE_FLAGS) { + featureSettings.put(builtInFlag.getKey(), builtInFlag.getDefaultRaw(Settings.EMPTY)); + } + // disable Telemetry setting + featureSettings.put(FeatureFlags.TELEMETRY_SETTING.getKey(), false); + return featureSettings.build(); + } + + @Override + protected Collection> nodePlugins() { + return List.of(ReactorNetty4Plugin.class, Netty4ModulePlugin.class, MockStreamingPlugin.class); + } + + public static class MockStreamingPlugin extends Plugin implements ActionPlugin { + @Override + public List getRestHandlers( + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster + ) { + return List.of(new MockStreamingRestHandler()); + } + } + + public static class MockStreamingRestHandler extends BaseRestHandler { + @Override + public String getName() { + return "mock_ml_streaming_handler"; + } + + private final Logger logger = LogManager.getLogger(MockStreamingRestHandler.class); + + @Override + public List routes() { + return List.of(new Route(POST, "/_plugins/_ml/models/_stream")); + } + + @Override + public boolean supportsStreaming() { + return true; + } + + @Override + public boolean supportsContentStream() { + return true; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + return channel -> { + if (channel instanceof StreamingRestChannel) { + StreamingRestChannel streamingChannel = (StreamingRestChannel) channel; + + // Mimic ML streaming behavior - store thread context + final ThreadContext.StoredContext storedContext = client.threadPool().getThreadContext().newStoredContext(true); + + // Set streaming headers like ML action + Map> headers = Map.of( + "Content-Type", + List.of("text/event-stream"), + "Cache-Control", + List.of("no-cache"), + "Connection", + List.of("keep-alive") + ); + streamingChannel.prepareResponse(RestStatus.OK, headers); + + // Use Flux pattern like ML streaming action + Flux.from(streamingChannel).ofType(HttpChunk.class).collectList().flatMap(chunks -> { + try (ThreadContext.StoredContext context = storedContext) { + context.restore(); + + String opaqueId = request.header(Task.X_OPAQUE_ID); + streamingChannel.sendChunk( + createHttpChunk("data: {\"status\":\"streaming\",\"opaque_id\":\"" + opaqueId + "\"}\n\n", false) + ); + + final CompletableFuture future = new CompletableFuture<>(); + + Flux.just( + createHttpChunk("data: {\"content\":\"test chunk 1\"}\n\n", false), + createHttpChunk("data: {\"content\":\"test chunk 2\"}\n\n", false), + createHttpChunk("data: {\"content\":\"final chunk\",\"is_last\":true}\n\n", true) + ) + .delayElements(Duration.ofMillis(100)) + .doOnNext(streamingChannel::sendChunk) + .doOnComplete(() -> future.complete(createHttpChunk("", true))) + .doOnError(future::completeExceptionally) + .subscribe(); // Simulate streaming delay + + return Mono.fromCompletionStage(future); + } catch (Exception e) { + return Mono.error(e); + } + }) + .doOnNext(streamingChannel::sendChunk) + .doFinally(signalType -> storedContext.close()) // Restore context when done + .onErrorResume(ex -> { + try { + HttpChunk errorChunk = createHttpChunk("data: {\"error\":\"" + ex.getMessage() + "\"}\n\n", true); + streamingChannel.sendChunk(errorChunk); + } catch (Exception e) { + // Log error + } + return Mono.empty(); + }) + .subscribe(); + } + }; + } + + private HttpChunk createHttpChunk(String sseData, boolean isLast) { + BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes())); + return new HttpChunk() { + @Override + public void close() { + if (bytesRef instanceof Releasable) { + ((Releasable) bytesRef).close(); + } + } + + @Override + public boolean isLast() { + return isLast; + } + + @Override + public BytesReference content() { + return bytesRef; + } + }; + } + } + + public void testThreadContextRestorationWithDuplicateOpaqueId() throws Exception { + ensureGreen(); + + HttpServerTransport httpServerTransport = internalCluster().getInstance(HttpServerTransport.class); + TransportAddress[] boundAddresses = httpServerTransport.boundAddress().boundAddresses(); + TransportAddress transportAddress = randomFrom(boundAddresses); + + List> requests = new ArrayList<>(); + int cores = NettyRuntime.availableProcessors(); + for (int i = 0; i < cores + 1; i++) { + requests.add(Tuple.tuple("/_plugins/_ml/models/_stream", "dummy request body")); + } + int i = 0; + while (i++ < requests.size()) { + try (ReactorHttpClient nettyHttpClient = ReactorHttpClient.create(false)) { + Collection singleResponse = nettyHttpClient.post(transportAddress.address(), requests.subList(0, 1)); + try { + Assert.assertEquals(1, singleResponse.size()); + FullHttpResponse response = singleResponse.iterator().next(); + String responseBody = response.content().toString(CharsetUtil.UTF_8); + Assert.assertEquals(200, response.status().code()); + Assert.assertFalse(responseBody.contains("value for key [X-Opaque-Id] already present")); + } finally { + singleResponse.forEach(ReferenceCounted::release); + } + } + } + } +} diff --git a/plugins/transport-reactor-netty4/src/main/java/org/opensearch/http/reactor/netty4/ReactorNetty4HttpServerTransport.java b/plugins/transport-reactor-netty4/src/main/java/org/opensearch/http/reactor/netty4/ReactorNetty4HttpServerTransport.java index 78e5edf48cc57..284c460431c35 100644 --- a/plugins/transport-reactor-netty4/src/main/java/org/opensearch/http/reactor/netty4/ReactorNetty4HttpServerTransport.java +++ b/plugins/transport-reactor-netty4/src/main/java/org/opensearch/http/reactor/netty4/ReactorNetty4HttpServerTransport.java @@ -17,6 +17,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.BigArrays; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.util.io.IOUtils; import org.opensearch.common.util.net.NetUtils; import org.opensearch.core.common.unit.ByteSizeUnit; @@ -381,6 +382,8 @@ protected Publisher incomingRequest(HttpServerRequest request, HttpServerR request.params() ); if (dispatchHandlerOpt.map(RestHandler::supportsStreaming).orElse(false)) { + final ThreadContext.StoredContext storedContext = threadPool.getThreadContext().newStoredContext(false); + final ReactorNetty4StreamingRequestConsumer consumer = new ReactorNetty4StreamingRequestConsumer<>( this, request, @@ -392,7 +395,9 @@ protected Publisher incomingRequest(HttpServerRequest request, HttpServerR .subscribe(consumer, error -> {}, () -> consumer.accept(DefaultLastHttpContent.EMPTY_LAST_CONTENT)); incomingStream(new ReactorNetty4HttpRequest(request), consumer.httpChannel()); - return response.sendObject(consumer); + + // restore ThreadContext when streaming is finished + return response.sendObject(consumer).then(Mono.fromRunnable(storedContext::close)); } else { final ReactorNetty4NonStreamingRequestConsumer consumer = new ReactorNetty4NonStreamingRequestConsumer<>( this, diff --git a/server/src/main/java/org/opensearch/rest/RestController.java b/server/src/main/java/org/opensearch/rest/RestController.java index 5b0e45bdc2d27..b85d3b031f7d0 100644 --- a/server/src/main/java/org/opensearch/rest/RestController.java +++ b/server/src/main/java/org/opensearch/rest/RestController.java @@ -54,6 +54,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.http.HttpChunk; import org.opensearch.http.HttpServerTransport; +import org.opensearch.telemetry.tracing.channels.TraceableRestChannel; import org.opensearch.transport.client.node.NodeClient; import org.opensearch.usage.UsageService; @@ -345,6 +346,10 @@ private void dispatchRequest(RestRequest request, RestChannel channel, RestHandl } if (handler.supportsStreaming()) { + // for the case TraceableRestChannel wrapped steamingRestChannel, we need to unwrap it + if (channel instanceof TraceableRestChannel traceableRestChannel) { + channel = traceableRestChannel.unwrap(); + } // The handler may support streaming but not the engine, in this case we fail with the bad request if (channel instanceof StreamingRestChannel streamingRestChannel) { responseChannel = new StreamHandlingHttpChannel(streamingRestChannel, circuitBreakerService, contentLength); diff --git a/server/src/main/java/org/opensearch/telemetry/tracing/channels/TraceableRestChannel.java b/server/src/main/java/org/opensearch/telemetry/tracing/channels/TraceableRestChannel.java index e9bc230fb24b8..a0032a6960a3b 100644 --- a/server/src/main/java/org/opensearch/telemetry/tracing/channels/TraceableRestChannel.java +++ b/server/src/main/java/org/opensearch/telemetry/tracing/channels/TraceableRestChannel.java @@ -59,6 +59,14 @@ public static RestChannel create(RestChannel delegate, Span span, Tracer tracer) } } + /** + * Returns delegate. + * @return delegate + */ + public RestChannel unwrap() { + return delegate; + } + @Override public XContentBuilder newBuilder() throws IOException { return delegate.newBuilder();