From 56d001b2823a6de0cd55f68f50a0479e96ef469c Mon Sep 17 00:00:00 2001
From: Santhosh Gandhe <1909520+san81@users.noreply.github.com>
Date: Sun, 26 Jan 2025 22:39:36 -0800
Subject: [PATCH] otal flaky test fix

---
 .../otel-trace-raw-processor/build.gradle     |   1 +
 .../oteltrace/OTelTraceRawProcessorTest.java  | 110 ++++++++++--------
 2 files changed, 62 insertions(+), 49 deletions(-)

diff --git a/data-prepper-plugins/otel-trace-raw-processor/build.gradle b/data-prepper-plugins/otel-trace-raw-processor/build.gradle
index 2df90630d8..447b9e82b5 100644
--- a/data-prepper-plugins/otel-trace-raw-processor/build.gradle
+++ b/data-prepper-plugins/otel-trace-raw-processor/build.gradle
@@ -20,6 +20,7 @@ dependencies {
     implementation 'com.fasterxml.jackson.dataformat:jackson-dataformat-yaml'
     implementation libs.caffeine
     testImplementation 'org.assertj:assertj-core:3.25.3'
+    testImplementation project(':data-prepper-test-common')
 }
 
 jacocoTestCoverageVerification {
diff --git a/data-prepper-plugins/otel-trace-raw-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/oteltrace/OTelTraceRawProcessorTest.java b/data-prepper-plugins/otel-trace-raw-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/oteltrace/OTelTraceRawProcessorTest.java
index 623a2c9cae..47cc86925c 100644
--- a/data-prepper-plugins/otel-trace-raw-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/oteltrace/OTelTraceRawProcessorTest.java
+++ b/data-prepper-plugins/otel-trace-raw-processor/src/test/java/org/opensearch/dataprepper/plugins/processor/oteltrace/OTelTraceRawProcessorTest.java
@@ -7,6 +7,8 @@
 
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
+import com.github.benmanes.caffeine.cache.Cache;
+import com.github.benmanes.caffeine.cache.Caffeine;
 import org.assertj.core.api.Assertions;
 import org.hamcrest.MatcherAssert;
 import org.junit.jupiter.api.AfterEach;
@@ -22,6 +24,8 @@
 import org.opensearch.dataprepper.model.trace.JacksonSpan;
 import org.opensearch.dataprepper.model.trace.Span;
 import org.opensearch.dataprepper.model.trace.TraceGroupFields;
+import org.opensearch.dataprepper.plugins.processor.oteltrace.model.TraceGroup;
+import org.opensearch.dataprepper.test.helper.ReflectivelySetField;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -65,27 +69,67 @@ class OTelTraceRawProcessorTest {
     private static final String TEST_TRACE_GROUP_2_ROOT_SPAN_JSON_FILE = "trace-group-2-root-span.json";
     private static final String TEST_TRACE_GROUP_2_CHILD_SPAN_1_JSON_FILE = "trace-group-2-child-span-1.json";
     private static final String TEST_TRACE_GROUP_2_CHILD_SPAN_2_JSON_FILE = "trace-group-2-child-span-2.json";
-
+    public OTelTraceRawProcessor oTelTraceRawProcessor;
+    public ExecutorService executorService;
     private Span TEST_TRACE_GROUP_1_ROOT_SPAN;
     private Span TEST_TRACE_GROUP_1_CHILD_SPAN_1;
     private Span TEST_TRACE_GROUP_1_CHILD_SPAN_2;
     private Span TEST_TRACE_GROUP_2_ROOT_SPAN;
     private Span TEST_TRACE_GROUP_2_CHILD_SPAN_1;
     private Span TEST_TRACE_GROUP_2_CHILD_SPAN_2;
-
     private List<Record<Span>> TEST_ONE_FULL_TRACE_GROUP_RECORDS;
     private List<Record<Span>> TEST_ONE_TRACE_GROUP_MISSING_ROOT_RECORDS;
     private List<Record<Span>> TEST_TWO_FULL_TRACE_GROUP_RECORDS;
     private List<Record<Span>> TEST_TWO_TRACE_GROUP_INTERLEAVED_PART_1_RECORDS;
     private List<Record<Span>> TEST_TWO_TRACE_GROUP_INTERLEAVED_PART_2_RECORDS;
     private List<Record<Span>> TEST_TWO_TRACE_GROUP_MISSING_ROOT_RECORDS;
-
     private OtelTraceRawProcessorConfig config;
     private PluginMetrics pluginMetrics;
-    public OTelTraceRawProcessor oTelTraceRawProcessor;
-    public ExecutorService executorService;
     private PipelineDescription pipelineDescription;
 
+    private static Span buildSpanFromJsonFile(final String jsonFileName) {
+        JacksonSpan.Builder spanBuilder = JacksonSpan.builder();
+        try (final InputStream inputStream = Objects.requireNonNull(
+                OTelTraceRawProcessorTest.class.getClassLoader().getResourceAsStream(jsonFileName))) {
+            final Map<String, Object> spanMap = OBJECT_MAPPER.readValue(inputStream, new TypeReference<Map<String, Object>>() {
+            });
+            final String traceId = (String) spanMap.get("traceId");
+            final String spanId = (String) spanMap.get("spanId");
+            final String parentSpanId = (String) spanMap.get("parentSpanId");
+            final String traceState = (String) spanMap.get("traceState");
+            final String name = (String) spanMap.get("name");
+            final String kind = (String) spanMap.get("kind");
+            final Long durationInNanos = ((Number) spanMap.get("durationInNanos")).longValue();
+            final String startTime = (String) spanMap.get("startTime");
+            final String endTime = (String) spanMap.get("endTime");
+            spanBuilder = spanBuilder
+                    .withTraceId(traceId)
+                    .withSpanId(spanId)
+                    .withParentSpanId(parentSpanId)
+                    .withTraceState(traceState)
+                    .withName(name)
+                    .withKind(kind)
+                    .withDurationInNanos(durationInNanos)
+                    .withStartTime(startTime)
+                    .withEndTime(endTime)
+                    .withTraceGroup(null);
+            DefaultTraceGroupFields.Builder traceGroupFieldsBuilder = DefaultTraceGroupFields.builder();
+            if (parentSpanId.isEmpty()) {
+                final Integer statusCode = (Integer) ((Map<String, Object>) spanMap.get("traceGroupFields")).get("statusCode");
+                traceGroupFieldsBuilder = traceGroupFieldsBuilder
+                        .withStatusCode(statusCode)
+                        .withEndTime(endTime)
+                        .withDurationInNanos(durationInNanos);
+                final String traceGroup = (String) spanMap.get("traceGroup");
+                spanBuilder = spanBuilder.withTraceGroup(traceGroup);
+            }
+            spanBuilder.withTraceGroupFields(traceGroupFieldsBuilder.build());
+        } catch (IOException e) {
+            throw new RuntimeException(e);
+        }
+        return spanBuilder.build();
+    }
+
     @BeforeEach
     void setup() {
         TEST_TRACE_GROUP_1_ROOT_SPAN = buildSpanFromJsonFile(TEST_TRACE_GROUP_1_ROOT_SPAN_JSON_FILE);
@@ -255,13 +299,23 @@ void testMetricsOnSpanSet() {
             "1, 4",
             "2, 6"
     })
-    void traceGroupCacheMaxSize_provides_an_upper_bound(final long cacheMaxSize, final int expectedProcessedRecords) {
+    void traceGroupCacheMaxSize_provides_an_upper_bound(final long cacheMaxSize, final int expectedProcessedRecords)
+            throws NoSuchFieldException, IllegalAccessException {
         reset(config);
         when(config.getTraceFlushIntervalSeconds()).thenReturn(TEST_TRACE_FLUSH_INTERVAL);
         when(config.getTraceGroupCacheMaxSize()).thenReturn(cacheMaxSize);
         when(config.getTraceGroupCacheTimeToLive()).thenReturn(OtelTraceRawProcessorConfig.DEFAULT_TRACE_ID_TTL);
 
         oTelTraceRawProcessor = new OTelTraceRawProcessor(config, pipelineDescription, pluginMetrics);
+        // By default Caffeine evictions are async so not guaranteed to happen within the test execution
+        // by adding executor(Runnable::run) makes the operations synchronous
+        final Cache<String, TraceGroup> traceIdTraceGroupCache = Caffeine.newBuilder()
+                .maximumSize(cacheMaxSize)
+                .executor(Runnable::run)
+                .expireAfterWrite(OtelTraceRawProcessorConfig.DEFAULT_TRACE_ID_TTL.toMillis(), TimeUnit.MILLISECONDS)
+                .build();
+        ReflectivelySetField.setField(OTelTraceRawProcessor.class, oTelTraceRawProcessor,
+                "traceIdTraceGroupCache", traceIdTraceGroupCache);
 
         final Collection<Record<Span>> processedRecords = new ArrayList<>();
         processedRecords.addAll(oTelTraceRawProcessor.doExecute(TEST_TWO_TRACE_GROUP_INTERLEAVED_PART_1_RECORDS));
@@ -293,48 +347,6 @@ void traceGroupCacheTimeToLive_causes_trace_group_expiry(final long traceIdTtlMi
         MatcherAssert.assertThat(getMissingTraceGroupFieldsSpanCount(processedRecords), equalTo(0));
     }
 
-    private static Span buildSpanFromJsonFile(final String jsonFileName) {
-        JacksonSpan.Builder spanBuilder = JacksonSpan.builder();
-        try (final InputStream inputStream = Objects.requireNonNull(
-                OTelTraceRawProcessorTest.class.getClassLoader().getResourceAsStream(jsonFileName))){
-            final Map<String, Object> spanMap = OBJECT_MAPPER.readValue(inputStream, new TypeReference<Map<String, Object>>() {});
-            final String traceId = (String) spanMap.get("traceId");
-            final String spanId = (String) spanMap.get("spanId");
-            final String parentSpanId = (String) spanMap.get("parentSpanId");
-            final String traceState = (String) spanMap.get("traceState");
-            final String name = (String) spanMap.get("name");
-            final String kind = (String) spanMap.get("kind");
-            final Long durationInNanos = ((Number) spanMap.get("durationInNanos")).longValue();
-            final String startTime = (String) spanMap.get("startTime");
-            final String endTime = (String) spanMap.get("endTime");
-            spanBuilder = spanBuilder
-                    .withTraceId(traceId)
-                    .withSpanId(spanId)
-                    .withParentSpanId(parentSpanId)
-                    .withTraceState(traceState)
-                    .withName(name)
-                    .withKind(kind)
-                    .withDurationInNanos(durationInNanos)
-                    .withStartTime(startTime)
-                    .withEndTime(endTime)
-                    .withTraceGroup(null);
-            DefaultTraceGroupFields.Builder traceGroupFieldsBuilder = DefaultTraceGroupFields.builder();
-            if (parentSpanId.isEmpty()) {
-                final Integer statusCode = (Integer) ((Map<String, Object>) spanMap.get("traceGroupFields")).get("statusCode");
-                traceGroupFieldsBuilder = traceGroupFieldsBuilder
-                        .withStatusCode(statusCode)
-                        .withEndTime(endTime)
-                        .withDurationInNanos(durationInNanos);
-                final String traceGroup = (String) spanMap.get("traceGroup");
-                spanBuilder = spanBuilder.withTraceGroup(traceGroup);
-            }
-            spanBuilder.withTraceGroupFields(traceGroupFieldsBuilder.build());
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-        return spanBuilder.build();
-    }
-
     private List<Future<Collection<Record<Span>>>> submitRecords(Collection<Record<Span>> records) {
         final List<Future<Collection<Record<Span>>>> futures = new ArrayList<>();
         futures.add(executorService.submit(() -> oTelTraceRawProcessor.doExecute(records)));
@@ -343,7 +355,7 @@ private List<Future<Collection<Record<Span>>>> submitRecords(Collection<Record<S
 
     private int getMissingTraceGroupFieldsSpanCount(final Collection<Record<Span>> records) {
         int count = 0;
-        for (Record<Span> record: records) {
+        for (Record<Span> record : records) {
             final Span span = record.getData();
             final String traceGroup = span.getTraceGroup();
             final TraceGroupFields traceGroupFields = span.getTraceGroupFields();