Skip to content

Commit 71f1fab

Browse files
GaganjunejaGagan Junejareta
authored
Add support for conditional Transient header propagation (opensearch-project#11490)
* Clear transient header from system context Signed-off-by: Gagan Juneja <[email protected]> * Clear transient header from system context Signed-off-by: Gagan Juneja <[email protected]> * Adds changelog Signed-off-by: Gagan Juneja <[email protected]> * Update CHANGELOG.md Co-authored-by: Andriy Redko <[email protected]> Signed-off-by: Gagan Juneja <[email protected]> * Adds unit tests Signed-off-by: Gagan Juneja <[email protected]> * Refactor code Signed-off-by: Gagan Juneja <[email protected]> * Refactor code Signed-off-by: Gagan Juneja <[email protected]> * Refactor code Signed-off-by: Gagan Juneja <[email protected]> * Supress warning Signed-off-by: Gagan Juneja <[email protected]> * Refactor code Signed-off-by: Gagan Juneja <[email protected]> --------- Signed-off-by: Gagan Juneja <[email protected]> Signed-off-by: Gagan Juneja <[email protected]> Co-authored-by: Gagan Juneja <[email protected]> Co-authored-by: Andriy Redko <[email protected]>
1 parent 3f5ec64 commit 71f1fab

File tree

9 files changed

+193
-28
lines changed

9 files changed

+193
-28
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
214214
- Fix template setting override for replication type ([#11417](https://github.com/opensearch-project/OpenSearch/pull/11417))
215215
- Fix Automatic addition of protocol broken in #11512 ([#11609](https://github.com/opensearch-project/OpenSearch/pull/11609))
216216
- Fix issue when calling Delete PIT endpoint and no PITs exist ([#11711](https://github.com/opensearch-project/OpenSearch/pull/11711))
217+
- Fix tracing context propagation for local transport instrumentation ([#11490](https://github.com/opensearch-project/OpenSearch/pull/11490))
217218

218219
### Security
219220

server/src/main/java/org/opensearch/common/util/concurrent/ThreadContext.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ public StoredContext stashContext() {
161161
);
162162
}
163163

164-
final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders);
164+
final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders, context.isSystemContext);
165165
if (!transientHeaders.isEmpty()) {
166166
threadContextStruct = threadContextStruct.putTransient(transientHeaders);
167167
}
@@ -182,7 +182,7 @@ public StoredContext stashContext() {
182182
public Writeable captureAsWriteable() {
183183
final ThreadContextStruct context = threadLocal.get();
184184
return out -> {
185-
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
185+
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
186186
context.writeTo(out, defaultHeader, propagatedHeaders);
187187
};
188188
}
@@ -245,7 +245,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
245245
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);
246246

247247
boolean transientHeadersModified = false;
248-
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders);
248+
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders, originalContext.isSystemContext);
249249
if (!transientHeaders.isEmpty()) {
250250
newTransientHeaders.putAll(transientHeaders);
251251
transientHeadersModified = true;
@@ -322,7 +322,7 @@ public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
322322
@Override
323323
public void writeTo(StreamOutput out) throws IOException {
324324
final ThreadContextStruct context = threadLocal.get();
325-
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
325+
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
326326
context.writeTo(out, defaultHeader, propagatedHeaders);
327327
}
328328

@@ -534,7 +534,7 @@ boolean isDefaultContext() {
534534
* by the system itself rather than by a user action.
535535
*/
536536
public void markAsSystemContext() {
537-
threadLocal.set(threadLocal.get().setSystemContext());
537+
threadLocal.set(threadLocal.get().setSystemContext(propagators));
538538
}
539539

540540
/**
@@ -573,15 +573,15 @@ public static Map<String, String> buildDefaultHeaders(Settings settings) {
573573
}
574574
}
575575

576-
private Map<String, Object> propagateTransients(Map<String, Object> source) {
576+
private Map<String, Object> propagateTransients(Map<String, Object> source, boolean isSystemContext) {
577577
final Map<String, Object> transients = new HashMap<>();
578-
propagators.forEach(p -> transients.putAll(p.transients(source)));
578+
propagators.forEach(p -> transients.putAll(p.transients(source, isSystemContext)));
579579
return transients;
580580
}
581581

582-
private Map<String, String> propagateHeaders(Map<String, Object> source) {
582+
private Map<String, String> propagateHeaders(Map<String, Object> source, boolean isSystemContext) {
583583
final Map<String, String> headers = new HashMap<>();
584-
propagators.forEach(p -> headers.putAll(p.headers(source)));
584+
propagators.forEach(p -> headers.putAll(p.headers(source, isSystemContext)));
585585
return headers;
586586
}
587587

@@ -603,11 +603,13 @@ private static final class ThreadContextStruct {
603603
// saving current warning headers' size not to recalculate the size with every new warning header
604604
private final long warningHeadersSize;
605605

606-
private ThreadContextStruct setSystemContext() {
606+
private ThreadContextStruct setSystemContext(final List<ThreadContextStatePropagator> propagators) {
607607
if (isSystemContext) {
608608
return this;
609609
}
610-
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true);
610+
final Map<String, Object> transients = new HashMap<>();
611+
propagators.forEach(p -> transients.putAll(p.transients(transientHeaders, true)));
612+
return new ThreadContextStruct(requestHeaders, responseHeaders, transients, persistentHeaders, true);
611613
}
612614

613615
private ThreadContextStruct(

server/src/main/java/org/opensearch/common/util/concurrent/ThreadContextStatePropagator.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,41 @@
2222
public interface ThreadContextStatePropagator {
2323
/**
2424
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
25-
* @param source current context transient headers
25+
*
26+
* @param source current context transient headers
2627
* @return the list of transient headers that needs to be propagated from current context to new thread context
2728
*/
29+
@Deprecated(since = "2.12.0", forRemoval = true)
2830
Map<String, Object> transients(Map<String, Object> source);
2931

32+
/**
33+
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
34+
*
35+
* @param source current context transient headers
36+
* @param isSystemContext if the propagation is for system context.
37+
* @return the list of transient headers that needs to be propagated from current context to new thread context
38+
*/
39+
default Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
40+
return transients(source);
41+
};
42+
3043
/**
3144
* Returns the list of request headers that needs to be propagated from current context to request.
32-
* @param source current context headers
45+
*
46+
* @param source current context headers
3347
* @return the list of request headers that needs to be propagated from current context to request
3448
*/
49+
@Deprecated(since = "2.12.0", forRemoval = true)
3550
Map<String, String> headers(Map<String, Object> source);
51+
52+
/**
53+
* Returns the list of request headers that needs to be propagated from current context to request.
54+
*
55+
* @param source current context headers
56+
* @param isSystemContext if the propagation is for system context.
57+
* @return the list of request headers that needs to be propagated from current context to request
58+
*/
59+
default Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
60+
return headers(source);
61+
}
3662
}

server/src/main/java/org/opensearch/tasks/TaskThreadContextStatePropagator.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
* Propagates TASK_ID across thread contexts
2121
*/
2222
public class TaskThreadContextStatePropagator implements ThreadContextStatePropagator {
23+
2324
@Override
25+
@SuppressWarnings("removal")
2426
public Map<String, Object> transients(Map<String, Object> source) {
2527
final Map<String, Object> transients = new HashMap<>();
2628

@@ -32,7 +34,18 @@ public Map<String, Object> transients(Map<String, Object> source) {
3234
}
3335

3436
@Override
37+
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
38+
return transients(source);
39+
}
40+
41+
@Override
42+
@SuppressWarnings("removal")
3543
public Map<String, String> headers(Map<String, Object> source) {
3644
return Collections.emptyMap();
3745
}
46+
47+
@Override
48+
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
49+
return headers(source);
50+
}
3851
}

server/src/main/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorage.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.opensearch.common.util.concurrent.ThreadContext;
1313
import org.opensearch.common.util.concurrent.ThreadContextStatePropagator;
1414

15+
import java.util.Collections;
1516
import java.util.HashMap;
1617
import java.util.Map;
1718
import java.util.Objects;
@@ -50,20 +51,29 @@ public void put(String key, Span span) {
5051
}
5152

5253
@Override
54+
@SuppressWarnings("removal")
5355
public Map<String, Object> transients(Map<String, Object> source) {
5456
final Map<String, Object> transients = new HashMap<>();
55-
5657
if (source.containsKey(CURRENT_SPAN)) {
5758
final SpanReference current = (SpanReference) source.get(CURRENT_SPAN);
5859
if (current != null) {
5960
transients.put(CURRENT_SPAN, new SpanReference(current.getSpan()));
6061
}
6162
}
62-
6363
return transients;
6464
}
6565

6666
@Override
67+
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
68+
if (isSystemContext == true) {
69+
return Collections.emptyMap();
70+
} else {
71+
return transients(source);
72+
}
73+
}
74+
75+
@Override
76+
@SuppressWarnings("removal")
6777
public Map<String, String> headers(Map<String, Object> source) {
6878
final Map<String, String> headers = new HashMap<>();
6979

@@ -77,6 +87,11 @@ public Map<String, String> headers(Map<String, Object> source) {
7787
return headers;
7888
}
7989

90+
@Override
91+
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
92+
return headers(source);
93+
}
94+
8095
Span getCurrentSpan(String key) {
8196
SpanReference currentSpanRef = threadContext.getTransient(key);
8297
return (currentSpanRef == null) ? null : currentSpanRef.getSpan();

server/src/main/java/org/opensearch/transport/TransportService.java

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -868,19 +868,10 @@ public final <T extends TransportResponse> void sendRequest(
868868
final TransportRequestOptions options,
869869
final TransportResponseHandler<T> handler
870870
) {
871-
if (connection == localNodeConnection) {
872-
// See please https://github.com/opensearch-project/OpenSearch/issues/10291
873-
sendRequestAsync(connection, action, request, options, handler);
874-
} else {
875-
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
876-
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
877-
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(
878-
handler,
879-
span,
880-
tracer
881-
);
882-
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
883-
}
871+
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
872+
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
873+
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(handler, span, tracer);
874+
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
884875
}
885876
}
886877

server/src/test/java/org/opensearch/common/util/concurrent/ThreadContextTests.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
import java.util.Map;
4545
import java.util.function.Supplier;
4646

47+
import org.mockito.Mockito;
48+
4749
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
4850
import static org.hamcrest.Matchers.equalTo;
4951
import static org.hamcrest.Matchers.hasItem;
@@ -740,6 +742,71 @@ public void testMarkAsSystemContext() throws IOException {
740742
assertFalse(threadContext.isSystemContext());
741743
}
742744

745+
public void testSystemContextWithPropagator() {
746+
Settings build = Settings.builder().put("request.headers.default", "1").build();
747+
Map<String, Object> transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test");
748+
Map<String, Object> transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
749+
Map<String, Object> headerMap = Collections.singletonMap("test_transient_propagation_key", "test");
750+
Map<String, String> headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
751+
ThreadContext threadContext = new ThreadContext(build);
752+
ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class);
753+
Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap());
754+
Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap);
755+
756+
Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap);
757+
Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap);
758+
threadContext.registerThreadContextStatePropagator(mockPropagator);
759+
threadContext.putHeader("foo", "bar");
760+
threadContext.putTransient("test_transient_propagation_key", 1);
761+
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
762+
assertEquals("bar", threadContext.getHeader("foo"));
763+
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
764+
threadContext.markAsSystemContext();
765+
assertNull(threadContext.getHeader("foo"));
766+
assertNull(threadContext.getTransient("test_transient_propagation_key"));
767+
assertEquals("1", threadContext.getHeader("default"));
768+
}
769+
770+
assertEquals("bar", threadContext.getHeader("foo"));
771+
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
772+
assertEquals("1", threadContext.getHeader("default"));
773+
}
774+
775+
public void testSerializeSystemContext() throws IOException {
776+
Settings build = Settings.builder().put("request.headers.default", "1").build();
777+
Map<String, Object> transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test");
778+
Map<String, Object> transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
779+
Map<String, Object> headerMap = Collections.singletonMap("test_transient_propagation_key", "test");
780+
Map<String, String> headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
781+
ThreadContext threadContext = new ThreadContext(build);
782+
ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class);
783+
Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap());
784+
Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap);
785+
786+
Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap);
787+
Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap);
788+
threadContext.registerThreadContextStatePropagator(mockPropagator);
789+
threadContext.putHeader("foo", "bar");
790+
threadContext.putTransient("test_transient_propagation_key", "test");
791+
BytesStreamOutput out = new BytesStreamOutput();
792+
BytesStreamOutput outFromSystemContext = new BytesStreamOutput();
793+
threadContext.writeTo(out);
794+
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
795+
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
796+
threadContext.markAsSystemContext();
797+
threadContext.writeTo(outFromSystemContext);
798+
assertNull(threadContext.getHeader("foo"));
799+
assertNull(threadContext.getTransient("test_transient_propagation_key"));
800+
threadContext.readHeaders(outFromSystemContext.bytes().streamInput());
801+
assertNull(threadContext.getHeader("test_transient_propagation_key"));
802+
}
803+
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
804+
threadContext.readHeaders(out.bytes().streamInput());
805+
assertEquals("bar", threadContext.getHeader("foo"));
806+
assertEquals("test", threadContext.getHeader("test_transient_propagation_key"));
807+
assertEquals("1", threadContext.getHeader("default"));
808+
}
809+
743810
public void testPutHeaders() {
744811
Settings build = Settings.builder().put("request.headers.default", "1").build();
745812
ThreadContext threadContext = new ThreadContext(build);
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.tasks;
10+
11+
import org.opensearch.test.OpenSearchTestCase;
12+
13+
import java.util.HashMap;
14+
import java.util.Map;
15+
16+
import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
17+
18+
public class TaskThreadContextStatePropagatorTests extends OpenSearchTestCase {
19+
private final TaskThreadContextStatePropagator taskThreadContextStatePropagator = new TaskThreadContextStatePropagator();
20+
21+
public void testTransient() {
22+
Map<String, Object> transientHeader = new HashMap<>();
23+
transientHeader.put(TASK_ID, "t_1");
24+
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, false);
25+
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
26+
}
27+
28+
public void testTransientForSystemContext() {
29+
Map<String, Object> transientHeader = new HashMap<>();
30+
transientHeader.put(TASK_ID, "t_1");
31+
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, true);
32+
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
33+
}
34+
}

server/src/test/java/org/opensearch/telemetry/tracing/ThreadContextBasedTracerContextStorageTests.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,4 +252,20 @@ public void run() {
252252
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
253253
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
254254
}
255+
256+
public void testSpanNotPropagatedToChildSystemThreadContext() {
257+
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));
258+
259+
try (SpanScope scope = tracer.withSpanInScope(span)) {
260+
try (StoredContext ignored = threadContext.stashContext()) {
261+
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
262+
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
263+
threadContext.markAsSystemContext();
264+
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
265+
}
266+
}
267+
268+
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
269+
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
270+
}
255271
}

0 commit comments

Comments
 (0)