diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/AsyncContext.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/AsyncContext.java index 62a522afc5..f1c5b495f4 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/AsyncContext.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/AsyncContext.java @@ -98,6 +98,10 @@ public static ContextMap context() { return provider().context(); } + public static CapturedContext capturedContext() { + return provider().captureContext(); + } + /** * Convenience method to put a new entry to the current context. * diff --git a/servicetalk-opentelemetry-http/src/test/java/io/servicetalk/opentelemetry/http/OpenTelemetryHttpRequestFilterTest.java b/servicetalk-opentelemetry-http/src/test/java/io/servicetalk/opentelemetry/http/OpenTelemetryHttpRequestFilterTest.java index fdc481e850..bb6612b0d6 100755 --- a/servicetalk-opentelemetry-http/src/test/java/io/servicetalk/opentelemetry/http/OpenTelemetryHttpRequestFilterTest.java +++ b/servicetalk-opentelemetry-http/src/test/java/io/servicetalk/opentelemetry/http/OpenTelemetryHttpRequestFilterTest.java @@ -16,11 +16,14 @@ package io.servicetalk.opentelemetry.http; +import io.servicetalk.client.api.TransportObserverConnectionFactoryFilter; import io.servicetalk.http.api.HttpClient; import io.servicetalk.http.api.HttpResponse; import io.servicetalk.http.api.HttpServerBuilder; import io.servicetalk.http.netty.HttpServers; import io.servicetalk.log4j2.mdc.utils.LoggerStringWriter; +import io.servicetalk.transport.api.ConnectionInfo; +import io.servicetalk.transport.api.ConnectionObserver; import io.servicetalk.transport.api.ServerContext; import io.opentelemetry.api.OpenTelemetry; @@ -32,6 +35,8 @@ import io.opentelemetry.context.propagation.ContextPropagators; import io.opentelemetry.sdk.testing.junit5.OpenTelemetryExtension; import io.opentelemetry.sdk.trace.data.SpanData; +import io.servicetalk.transport.api.TransportObserver; +import io.servicetalk.transport.netty.internal.NoopTransportObserver; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -41,7 +46,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.Nullable; import java.lang.invoke.MethodHandles; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; import static io.opentelemetry.semconv.HttpAttributes.HTTP_REQUEST_METHOD; import static io.opentelemetry.semconv.NetworkAttributes.NETWORK_PROTOCOL_NAME; @@ -248,6 +257,110 @@ void testCaptureHeader() throws Exception { } } + @Test + void transportObserver() throws Exception { + final String requestUrl = "/"; + OpenTelemetry openTelemetry = otelTesting.getOpenTelemetry(); + BlockingQueue errors = new LinkedBlockingQueue<>(); + TransportObserver transportObserver = new TransportObserver() { + + final AtomicReference span = new AtomicReference<>(); + + private void checkSpan() { + Span current = Span.current(); + if (!current.equals(span.get())) { + errors.add(new AssertionError("Unexpected span: " + current + + " (expected " + span.get() + ").")); + } + } + + private void checkNoSpan() { + Span current = Span.current(); + if (!current.equals(Span.getInvalid())) { + errors.add(new AssertionError("Unexpected span: " + current + + " (expected " + span.get() + ").")); + } + } + + @Override + public ConnectionObserver onNewConnection(@Nullable Object localAddress, Object remoteAddress) { + if (!span.compareAndSet(null, Span.current())) { + errors.add(new AssertionError("onNewConnection called twice")); + } + return new ConnectionObserver() { + @Override + public void onDataRead(int size) { + checkNoSpan(); + } + + @Override + public void onDataWrite(int size) { + checkNoSpan(); + } + + @Override + public void onFlush() { + checkNoSpan(); + } + + @Override + public DataObserver connectionEstablished(ConnectionInfo info) { + // TODO: Getting null span info here. + checkSpan(); + return NoopTransportObserver.NoopDataObserver.INSTANCE; + } + + @Override + public MultiplexedObserver multiplexedConnectionEstablished(ConnectionInfo info) { + checkSpan(); + return NoopTransportObserver.NoopMultiplexedObserver.INSTANCE; + } + + @Override + public void connectionClosed(Throwable error) { + // TODO: We should have a test for when session establishment fails, we have the span. + checkNoSpan(); + } + + @Override + public void connectionClosed() { + // TODO: We should have a test for when session establishment fails, we have the span. + checkNoSpan(); + } + }; + } + }; + try (ServerContext context = buildServer(openTelemetry, false)) { + try (HttpClient client = forSingleAddress(serverHostAndPort(context)) + .appendClientFilter(new OpenTelemetryHttpRequestFilter(openTelemetry, "testClient")) + .appendClientFilter(new TestTracingClientLoggerFilter(TRACING_TEST_LOG_LINE_PREFIX)) + .appendConnectionFactoryFilter( + new TransportObserverConnectionFactoryFilter<>(transportObserver)).build()) { + HttpResponse response = client.request(client.get(requestUrl)).toFuture().get(); + TestSpanState serverSpanState = response.payloadBody(SPAN_STATE_SERIALIZER); + + verifyTraceIdPresentInLogs(loggerStringWriter.stableAccumulated(1000), requestUrl, + serverSpanState.getTraceId(), serverSpanState.getSpanId(), + TRACING_TEST_LOG_LINE_PREFIX); + assertThat(otelTesting.getSpans()).hasSize(1); + assertThat(otelTesting.getSpans()).extracting("traceId") + .containsExactly(serverSpanState.getTraceId()); + assertThat(otelTesting.getSpans()).extracting("spanId") + .containsAnyOf(serverSpanState.getSpanId()); + otelTesting.assertTraces() + .hasTracesSatisfyingExactly(ta -> ta.hasTraceId(serverSpanState.getTraceId())); + + otelTesting.assertTraces() + .hasTracesSatisfyingExactly(ta -> { + SpanData span = ta.getSpan(0); + }); + } + } + if (!errors.isEmpty()) { + throw errors.poll(); + } + } + private static ServerContext buildServer(OpenTelemetry givenOpentelemetry, boolean addFilter) throws Exception { HttpServerBuilder httpServerBuilder = HttpServers.forAddress(localAddress(0)); if (addFilter) { diff --git a/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java b/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java index 97776f84ba..63a53f1bc1 100644 --- a/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java +++ b/servicetalk-tcp-netty-internal/src/main/java/io/servicetalk/tcp/netty/internal/TcpConnector.java @@ -18,13 +18,19 @@ import io.servicetalk.client.api.RetryableConnectException; import io.servicetalk.concurrent.Cancellable; import io.servicetalk.concurrent.SingleSource.Subscriber; +import io.servicetalk.concurrent.api.AsyncContext; +import io.servicetalk.concurrent.api.CapturedContext; import io.servicetalk.concurrent.api.ListenableAsyncCloseable; +import io.servicetalk.concurrent.api.Scope; import io.servicetalk.concurrent.api.Single; import io.servicetalk.concurrent.api.internal.SubscribableSingle; import io.servicetalk.concurrent.internal.DelayedCancellable; +import io.servicetalk.transport.api.ConnectionInfo; import io.servicetalk.transport.api.ConnectionObserver; +import io.servicetalk.transport.api.DelegatingConnectionObserver; import io.servicetalk.transport.api.ExecutionContext; import io.servicetalk.transport.api.FileDescriptorSocketAddress; +import io.servicetalk.transport.api.SslConfig; import io.servicetalk.transport.api.TransportObserver; import io.servicetalk.transport.netty.internal.NettyConnection; @@ -45,6 +51,7 @@ import io.netty.util.concurrent.EventExecutor; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.ImmediateEventExecutor; +import io.servicetalk.transport.netty.internal.NoopTransportObserver; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -102,7 +109,7 @@ public static Single connect( @Override protected void handleSubscribe(final Subscriber subscriber) { final ConnectionObserver connectionObserver = - observer.onNewConnection(localAddress, resolvedRemoteAddress); + getConnectionObserver(localAddress, resolvedRemoteAddress, observer); final ConnectHandler connectHandler; try { connectHandler = new ConnectHandler<>(subscriber, connectionFactory, connectionObserver); @@ -122,6 +129,17 @@ protected void handleSubscribe(final Subscriber subscriber) { }; } + private static ConnectionObserver getConnectionObserver(@Nullable final SocketAddress localAddress, + final Object resolvedRemoteAddress, + TransportObserver observer) { + ConnectionObserver connectionObserver = + observer.onNewConnection(localAddress, resolvedRemoteAddress); + if (connectionObserver != NoopTransportObserver.NoopConnectionObserver.INSTANCE) { + connectionObserver = new ContextPreservingConnectionObserver(connectionObserver); + } + return connectionObserver; + } + private static Future connect0(@Nullable SocketAddress localAddress, Object resolvedRemoteAddress, ReadOnlyTcpClientConfig config, boolean autoRead, ExecutionContext executionContext, @@ -325,4 +343,95 @@ void unexpectedFailure(final Throwable cause) { } } } + + // + private static final class ContextPreservingConnectionObserver extends DelegatingConnectionObserver { + + private volatile boolean useContext = true; + private final CapturedContext capturedContext; + + ContextPreservingConnectionObserver(ConnectionObserver delegate) { + super(delegate); + this.capturedContext = AsyncContext.capturedContext(); + } + + @Override + public void connectionClosed(Throwable error) { + try (Scope unused = attachContext()) { + delegate().connectionClosed(error); + } finally { + unattachContext(); + } + } + + @Override + public void connectionClosed() { + try (Scope unused = attachContext()) { + delegate().connectionClosed(); + } finally { + unattachContext(); + } + } + + @Override + public SecurityHandshakeObserver onSecurityHandshake() { + try (Scope unused = attachContext()) { + return delegate().onSecurityHandshake(); + } + } + + @Override + public SecurityHandshakeObserver onSecurityHandshake(SslConfig sslConfig) { + try (Scope unused = attachContext()) { + return delegate().onSecurityHandshake(sslConfig); + } + } + + @Override + public ProxyConnectObserver onProxyConnect(Object connectMsg) { + try (Scope unused = attachContext()) { + return delegate().onProxyConnect(connectMsg); + } + } + + @Override + public void onTransportHandshakeComplete() { + try (Scope unused = attachContext()) { + delegate().onTransportHandshakeComplete(); + } + } + + @Override + public void onTransportHandshakeComplete(ConnectionInfo info) { + try (Scope unused = attachContext()) { + delegate().onTransportHandshakeComplete(info); + } + } + + @Override + public DataObserver connectionEstablished(ConnectionInfo info) { + try (Scope unused = attachContext()) { + return delegate().connectionEstablished(info); + } finally { + unattachContext(); + } + } + + @Override + public MultiplexedObserver multiplexedConnectionEstablished(ConnectionInfo info) { + try (Scope unused = attachContext()) { + return delegate().multiplexedConnectionEstablished(info); + } finally { + unattachContext(); + } + } + + private Scope attachContext() { + return useContext ? capturedContext.attachContext() : Scope.NOOP; + } + + private void unattachContext() { + useContext = false; + } + } } diff --git a/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/DelegatingConnectionObserver.java b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/DelegatingConnectionObserver.java new file mode 100644 index 0000000000..a176d5c786 --- /dev/null +++ b/servicetalk-transport-api/src/main/java/io/servicetalk/transport/api/DelegatingConnectionObserver.java @@ -0,0 +1,81 @@ +package io.servicetalk.transport.api; + +import static java.util.Objects.requireNonNull; + +public class DelegatingConnectionObserver implements ConnectionObserver { + + public DelegatingConnectionObserver(ConnectionObserver delegate) { + this.delegate = requireNonNull(delegate, "delegate"); + } + + private final ConnectionObserver delegate; + + public ConnectionObserver delegate() { + return delegate; + } + + @Override + public void onDataRead(int size) { + delegate.onDataRead(size); + } + + @Override + public void onDataWrite(int size) { + delegate.onDataWrite(size); + } + + @Override + public void onFlush() { + delegate.onFlush(); + } + + @Override + public void onTransportHandshakeComplete() { + delegate.onTransportHandshakeComplete(); + } + + @Override + public void onTransportHandshakeComplete(ConnectionInfo info) { + delegate().onTransportHandshakeComplete(info); + } + + @Override + public ProxyConnectObserver onProxyConnect(Object connectMsg) { + return delegate().onProxyConnect(connectMsg); + } + + @Override + public SecurityHandshakeObserver onSecurityHandshake() { + return delegate().onSecurityHandshake(); + } + + @Override + public SecurityHandshakeObserver onSecurityHandshake(SslConfig sslConfig) { + return delegate().onSecurityHandshake(sslConfig); + } + + @Override + public DataObserver connectionEstablished(ConnectionInfo info) { + return delegate().connectionEstablished(info); + } + + @Override + public MultiplexedObserver multiplexedConnectionEstablished(ConnectionInfo info) { + return delegate().multiplexedConnectionEstablished(info); + } + + @Override + public void connectionWritabilityChanged(boolean isWritable) { + delegate().connectionWritabilityChanged(isWritable); + } + + @Override + public void connectionClosed(Throwable error) { + delegate().connectionClosed(error); + } + + @Override + public void connectionClosed() { + delegate().connectionClosed(); + } +}