diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/ThresholdCheck.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/ThresholdCheck.java
index 2055c3259d..7611a03f3b 100644
--- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/ThresholdCheck.java
+++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/common/util/ThresholdCheck.java
@@ -26,4 +26,9 @@ public static boolean checkThresholdExceed(final Buffer currentBuffer, final int
                 currentBuffer.getSize() > maxBytes.getBytes();
         }
     }
+
+    // Threshold check for the timeout condition only
+    public static boolean checkTimeoutExceeded(final Buffer currentBuffer, final Duration maxCollectionDuration) {
+        return currentBuffer.getDuration().compareTo(maxCollectionDuration) > 0;
+    }
 }
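
One behavioral note on checkTimeoutExceeded(): the comparison is strict, so a buffer whose age is exactly maxCollectionDuration is not yet considered timed out; the flush only fires on the first doOutput() call after the duration has been exceeded. A self-contained sketch of the comparison semantics (TimeoutCheckDemo is illustrative only, not part of this change):

    import java.time.Duration;

    public class TimeoutCheckDemo {
        public static void main(String[] args) {
            Duration maxCollectionDuration = Duration.ofSeconds(30);
            // compareTo(...) > 0 holds only when the buffer age is strictly
            // greater than the limit, mirroring checkTimeoutExceeded above.
            System.out.println(Duration.ofSeconds(31).compareTo(maxCollectionDuration) > 0); // true
            System.out.println(Duration.ofSeconds(30).compareTo(maxCollectionDuration) > 0); // false
        }
    }
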
diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java
index 8417c912f3..f09f598421 100644
--- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java
+++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSink.java
@@ -46,6 +46,7 @@
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantLock;
 
 import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
 import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess;
@@ -85,6 +86,7 @@ public class LambdaSink extends AbstractSink<Record<Event>> {
     // The partial buffer that may not yet have reached threshold.
     // Access must be synchronized
     private Buffer statefulBuffer;
+    private ReentrantLock reentrantLock;
 
     @DataPrepperPluginConstructor
     public LambdaSink(final PluginSetting pluginSetting,
@@ -127,6 +129,7 @@ public LambdaSink(final PluginSetting pluginSetting,
         if (lambdaSinkConfig.getDlqPluginSetting() != null) {
             this.dlqPushHandler = new DlqPushHandler(pluginFactory, pluginSetting,
                     lambdaSinkConfig.getDlq(), lambdaSinkConfig.getAwsAuthenticationOptions());
         }
+        reentrantLock = new ReentrantLock();
     }
 
@@ -177,32 +180,46 @@ public void doOutput(final Collection<Record<Event>> records) {
             LOG.warn("LambdaSink doOutput called before initialization");
             return;
         }
 
-        if (records.isEmpty()) {
-            return;
-        }
-
-        // We'll collect any "full" buffers in a local list, flush them at the end
-        List<Buffer> fullBuffers = new ArrayList<>();
-
-        // Add to the persistent buffer, check threshold
-        for (Record<Event> record : records) {
-            //statefulBuffer is either empty or partially filled(from previous run)
-            statefulBuffer.addRecord(record);
-
-            if (isThresholdExceeded(statefulBuffer)) {
-                // This buffer is full
-                fullBuffers.add(statefulBuffer);
-                // Create new partial buffer
+        reentrantLock.lock();
+        try {
+            // Check whether the old partial buffer needs to be flushed due to timeout
+            if (statefulBuffer.getEventCount() > 0
+                    && ThresholdCheck.checkTimeoutExceeded(statefulBuffer, maxCollectTime)) {
+                LOG.debug("Flushing partial buffer due to timeout of {}", maxCollectTime);
+                final Buffer bufferToFlush = statefulBuffer;
+                // Create a new partial buffer
                 statefulBuffer = new InMemoryBufferSynchronized(
                         lambdaSinkConfig.getBatchOptions().getKeyName(),
                         outputCodecContext
                 );
+                flushBuffers(Collections.singletonList(bufferToFlush));
             }
-        }
 
-        // Flush any full buffers
-        if (!fullBuffers.isEmpty()) {
-            flushBuffers(fullBuffers);
+            // We'll collect any "full" buffers in a local list, flush them at the end
+            List<Buffer> fullBuffers = new ArrayList<>();
+
+            // Add to the persistent buffer, check threshold
+            for (Record<Event> record : records) {
+                //statefulBuffer is either empty or partially filled(from previous run)
+                statefulBuffer.addRecord(record);
+
+                if (isThresholdExceeded(statefulBuffer)) {
+                    // This buffer is full
+                    fullBuffers.add(statefulBuffer);
+                    // Create new partial buffer
+                    statefulBuffer = new InMemoryBufferSynchronized(
+                            lambdaSinkConfig.getBatchOptions().getKeyName(),
+                            outputCodecContext
+                    );
+                }
+            }
+
+            // Flush any full buffers
+            if (!fullBuffers.isEmpty()) {
+                flushBuffers(fullBuffers);
+            }
+        } finally {
+            reentrantLock.unlock();
         }
     }
@@ -263,7 +280,7 @@ private void releaseEventHandles(Collection<Record<Event>> records, boolean success) {
         }
     }
 
-    private void flushBuffers(final List<Buffer> buffersToFlush) {
+    void flushBuffers(final List<Buffer> buffersToFlush) {
 
         Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap;
 
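
The doOutput() rewrite above follows a swap-under-lock pattern: while the lock is held, the timed-out buffer is detached into a local variable and replaced with a fresh one, so no concurrent caller can append to a buffer that is being flushed. A stripped-down sketch of the pattern, using hypothetical stand-in types (a List of String records instead of the plugin's Buffer and Record<Event>):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.locks.ReentrantLock;

    class SwapUnderLockSketch {
        private final ReentrantLock lock = new ReentrantLock();
        private List<String> partialBuffer = new ArrayList<>();

        void add(String record, boolean timedOut) {
            lock.lock();
            try {
                if (timedOut && !partialBuffer.isEmpty()) {
                    List<String> toFlush = partialBuffer; // detach the old buffer
                    partialBuffer = new ArrayList<>();    // later writers see a fresh buffer
                    flush(toFlush);                       // no other thread holds a reference now
                }
                partialBuffer.add(record);
            } finally {
                lock.unlock(); // always release, even if flush(...) throws
            }
        }

        private void flush(List<String> toFlush) {
            System.out.println("flushing " + toFlush.size() + " records");
        }
    }

Note that flushBuffers(...) runs while the lock is held, so a slow Lambda invocation in the timeout path can block every other doOutput() caller until it completes; that is the price of keeping the buffer swap and the flush atomic.
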
diff --git a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java
index fc4fc29a7c..0701b8f710 100644
--- a/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java
+++ b/data-prepper-plugins/aws-lambda/src/test/java/org/opensearch/dataprepper/plugins/lambda/sink/LambdaSinkTest.java
@@ -15,11 +15,13 @@
 import org.opensearch.dataprepper.model.event.Event;
 import org.opensearch.dataprepper.model.plugin.PluginFactory;
 import org.opensearch.dataprepper.model.record.Record;
+import org.opensearch.dataprepper.model.sink.OutputCodecContext;
 import org.opensearch.dataprepper.model.sink.SinkContext;
 import org.opensearch.dataprepper.model.types.ByteCount;
 import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler;
 import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
 import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer;
+import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferSynchronized;
 import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
 import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
 import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
@@ -34,22 +36,32 @@
 
 import java.lang.reflect.Field;
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 
+import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyDouble;
 import static org.mockito.ArgumentMatchers.anyList;
 import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mockStatic;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -377,6 +389,80 @@ public void testHandleFailure_WithoutDlq() throws Exception {
         verify(dlqPushHandler, never()).perform(anyList());
     }
 
+    @Test
+    void testFlushDueToTimeoutWhenSecondCallHasNoRecords() throws Exception {
+        // Set maxCollectTime to a very short duration so the buffer is considered timed out.
+        setPrivateField(lambdaSink, "maxCollectTime", Duration.ofMillis(1));
+
+        // Create a spy on lambdaSink to intercept flushBuffers() calls.
+        LambdaSink spySink = spy(lambdaSink);
+
+        // Use an AtomicBoolean flag to record whether flushBuffers() is called.
+        AtomicBoolean flushCalled = new AtomicBoolean(false);
+        doAnswer(invocation -> {
+            flushCalled.set(true);
+            return null;
+        }).when(spySink).flushBuffers(anyList());
+
+        // First call: add one record so that the buffer is non-empty.
+        List<Record<Event>> records = Collections.singletonList(new Record<>(mock(Event.class)));
+        spySink.doOutput(records);
+
+        // Wait briefly so the buffer's duration exceeds maxCollectTime.
+        Thread.sleep(10);
+
+        // Second call: pass an empty collection; this should trigger the timeout flush.
+        spySink.doOutput(Collections.emptyList());
+
+        // Verify that flushBuffers() was called due to the timeout.
+        assertTrue(flushCalled.get(), "Expected flushBuffers() to be called on the second call due to timeout.");
+    }
+
+    @Test
+    void testConcurrentDoOutputWithMultipleThreads() throws Exception {
+        // Set maxCollectTime to a very short duration to force the timeout flush.
+        setPrivateField(lambdaSink, "maxCollectTime", Duration.ofMillis(1));
+
+        // Create a spy of the sink so we can intercept flushBuffers() calls.
+        LambdaSink spySink = spy(lambdaSink);
+
+        // Use an AtomicInteger to count how many times flushBuffers() is called.
+        AtomicInteger flushCount = new AtomicInteger(0);
+        doAnswer(invocation -> {
+            flushCount.incrementAndGet();
+            return null;
+        }).when(spySink).flushBuffers(anyList());
+
+        // Create a thread pool with multiple threads.
+        int threadCount = 20;
+        ExecutorService executor = Executors.newFixedThreadPool(threadCount);
+        int iterations = 1000;
+        List<Future<?>> futures = new ArrayList<>();
+
+        // Each task calls doOutput() with one record.
+        for (int i = 0; i < iterations; i++) {
+            futures.add(executor.submit(() -> {
+                spySink.doOutput(Collections.singletonList(new Record<>(mock(Event.class))));
+            }));
+        }
+
+        // Wait for all tasks to complete.
+        for (Future<?> future : futures) {
+            future.get();
+        }
+
+        // Call doOutput() once more with an empty collection to trigger any pending timeout flush.
+        spySink.doOutput(Collections.emptyList());
+
+        executor.shutdown();
+
+        // Assert that flushBuffers() was called at least once under concurrent load.
+        // With a very short timeout, the partial buffer eventually becomes "old" and flushes
+        // even though each call also adds a record. If the locking works, flushCount is > 0.
+        assertTrue(flushCount.get() > 0,
+                "Expected at least one flush due to timeout under concurrent calls, but flushCount is " + flushCount.get());
+    }
+
+    // Utility to set private fields
     private static void setPrivateField(Object target, String fieldName, Object value) {
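
A Mockito detail these tests depend on: spy(lambdaSink) creates a copy of the real object and copies its field values at the moment the spy is created, so setPrivateField(...) must run before spy(...); mutating lambdaSink afterwards would not be visible to spySink. A minimal illustration (Counter and SpyCopyDemo are hypothetical, not from the codebase):

    import static org.mockito.Mockito.spy;

    class Counter {
        int limit = 10;
    }

    class SpyCopyDemo {
        public static void main(String[] args) {
            Counter real = new Counter();
            real.limit = 1;                  // set state BEFORE creating the spy
            Counter spied = spy(real);       // field values are copied here
            real.limit = 99;                 // too late: the spy keeps limit == 1
            System.out.println(spied.limit); // prints 1
        }
    }
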