Address timeout threshold comment
Signed-off-by: Srikanth Govindarajan <[email protected]>
srikanthjg committed Feb 1, 2025
1 parent 3d9fb37 commit debad91
Showing 3 changed files with 130 additions and 21 deletions.

ThresholdCheck.java
@@ -26,4 +26,9 @@ public static boolean checkThresholdExceed(final Buffer currentBuffer, final int
                 currentBuffer.getSize() > maxBytes.getBytes();
     }
-}
+
+    // Add threshold check for timeout only
+    public static boolean checkTimeoutExceeded(final Buffer currentBuffer, final Duration maxCollectionDuration) {
+        return currentBuffer.getDuration().compareTo(maxCollectionDuration) > 0;
+    }
+}
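
The new check treats a buffer as timed out once its accumulated collection time strictly exceeds maxCollectionDuration. A minimal, self-contained sketch of that comparison (the one-method Buffer stub below is hypothetical; the plugin's real interface lives in its accumlator package):

    import java.time.Duration;

    public class TimeoutCheckSketch {
        // Hypothetical one-method stand-in for the plugin's Buffer interface.
        interface Buffer {
            Duration getDuration();   // how long this buffer has been collecting
        }

        public static void main(String[] args) {
            final Duration maxCollectionDuration = Duration.ofSeconds(30);
            final Buffer buffer = () -> Duration.ofSeconds(45);   // buffer open for 45 seconds

            // The same comparison checkTimeoutExceeded performs:
            final boolean timedOut = buffer.getDuration().compareTo(maxCollectionDuration) > 0;
            System.out.println(timedOut);   // true -> the partial buffer should be flushed
        }
    }

Because the comparison is strictly greater-than, a buffer whose age exactly equals the maximum is not yet considered timed out.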

LambdaSink.java
@@ -38,6 +38,7 @@
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.net.HttpURLConnection;
import java.nio.channels.AsynchronousFileChannel;
import java.time.Duration;
import java.util.Collection;
import java.util.ArrayList;
@@ -46,6 +47,7 @@
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;

import static org.opensearch.dataprepper.logging.DataPrepperMarkers.NOISY;
import static org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler.isSuccess;
@@ -85,6 +87,7 @@ public class LambdaSink extends AbstractSink<Record<Event>> {
// The partial buffer that may not yet have reached threshold.
// Access must be synchronized
private Buffer statefulBuffer;
private ReentrantLock reentrantLock;

@DataPrepperPluginConstructor
public LambdaSink(final PluginSetting pluginSetting,
@@ -127,6 +130,7 @@ public LambdaSink(final PluginSetting pluginSetting,
if (lambdaSinkConfig.getDlqPluginSetting() != null) {
this.dlqPushHandler = new DlqPushHandler(pluginFactory, pluginSetting, lambdaSinkConfig.getDlq(), lambdaSinkConfig.getAwsAuthenticationOptions());
}
reentrantLock = new ReentrantLock();

}

@@ -177,32 +181,46 @@ public void doOutput(final Collection<Record<Event>> records) {
             LOG.warn("LambdaSink doOutput called before initialization");
             return;
         }
-        if (records.isEmpty()) {
-            return;
-        }
-
-        // We'll collect any "full" buffers in a local list, flush them at the end
-        List<Buffer> fullBuffers = new ArrayList<>();
-
-        // Add to the persistent buffer, check threshold
-        for (Record<Event> record : records) {
-            // statefulBuffer is either empty or partially filled (from a previous run)
-            statefulBuffer.addRecord(record);
-
-            if (isThresholdExceeded(statefulBuffer)) {
-                // This buffer is full
-                fullBuffers.add(statefulBuffer);
-                // Create new partial buffer
+        reentrantLock.lock();
+        try {
+            // Check whether the old partial buffer needs to be flushed due to timeout
+            if (statefulBuffer.getEventCount() > 0
+                    && ThresholdCheck.checkTimeoutExceeded(statefulBuffer, maxCollectTime)) {
+                LOG.debug("Flushing partial buffer due to timeout of {}", maxCollectTime);
+                final Buffer bufferToFlush = statefulBuffer;
+                // Create a new partial buffer
                 statefulBuffer = new InMemoryBufferSynchronized(
                         lambdaSinkConfig.getBatchOptions().getKeyName(),
                         outputCodecContext
                 );
-            }
-        }
-
-        // Flush any full buffers
-        if (!fullBuffers.isEmpty()) {
-            flushBuffers(fullBuffers);
-        }
+                flushBuffers(Collections.singletonList(bufferToFlush));
+            }
+
+            // We'll collect any "full" buffers in a local list, flush them at the end
+            List<Buffer> fullBuffers = new ArrayList<>();
+
+            // Add to the persistent buffer, check threshold
+            for (Record<Event> record : records) {
+                // statefulBuffer is either empty or partially filled (from a previous run)
+                statefulBuffer.addRecord(record);
+
+                if (isThresholdExceeded(statefulBuffer)) {
+                    // This buffer is full
+                    fullBuffers.add(statefulBuffer);
+                    // Create a new partial buffer
+                    statefulBuffer = new InMemoryBufferSynchronized(
+                            lambdaSinkConfig.getBatchOptions().getKeyName(),
+                            outputCodecContext
+                    );
+                }
+            }
+
+            // Flush any full buffers
+            if (!fullBuffers.isEmpty()) {
+                flushBuffers(fullBuffers);
+            }
+        } finally {
+            reentrantLock.unlock();
+        }
     }

@@ -263,7 +281,7 @@ private void releaseEventHandles(Collection<Record<Event>> records, boolean succ
}
}

-    private void flushBuffers(final List<Buffer> buffersToFlush) {
+    void flushBuffers(final List<Buffer> buffersToFlush) {


Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap;
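Taken together, the doOutput changes above mean every call (including one with an empty batch) now takes the lock and first flushes the leftover partial buffer if it has aged past maxCollectTime; flushBuffers also becomes package-private so tests can stub it. A condensed sketch of that control flow, using simplified stand-in types and illustrative threshold values rather than the plugin's actual Buffer/Record classes and configuration:

    import java.time.Duration;
    import java.time.Instant;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.locks.ReentrantLock;

    public class TimeoutFlushSketch {
        // Hypothetical stand-in for the plugin's Buffer type.
        static final class SketchBuffer {
            final List<Object> events = new ArrayList<>();
            final Instant created = Instant.now();
            int eventCount() { return events.size(); }
            Duration age() { return Duration.between(created, Instant.now()); }
        }

        private final ReentrantLock lock = new ReentrantLock();
        private final Duration maxCollectTime = Duration.ofSeconds(30);  // assumed config value
        private final int maxEvents = 100;                               // assumed threshold
        private SketchBuffer buffer = new SketchBuffer();

        // Mirrors the locking structure of the new doOutput: even an empty batch
        // takes the lock, so a stale partial buffer still gets flushed.
        public void doOutput(List<Object> records) {
            lock.lock();
            try {
                if (buffer.eventCount() > 0 && buffer.age().compareTo(maxCollectTime) > 0) {
                    final SketchBuffer toFlush = buffer;
                    buffer = new SketchBuffer();    // start a fresh partial buffer first
                    flush(List.of(toFlush));        // then flush the timed-out one
                }
                final List<SketchBuffer> full = new ArrayList<>();
                for (Object record : records) {
                    buffer.events.add(record);
                    if (buffer.eventCount() >= maxEvents) {  // size threshold reached
                        full.add(buffer);
                        buffer = new SketchBuffer();
                    }
                }
                if (!full.isEmpty()) {
                    flush(full);
                }
            } finally {
                lock.unlock();
            }
        }

        private void flush(List<SketchBuffer> buffers) {
            // The real sink hands these to the asynchronous Lambda invocation path.
        }
    }

Creating the replacement buffer before flushing keeps the field pointing at a fresh buffer even if the flush path throws.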

LambdaSinkTest.java
@@ -15,11 +15,13 @@
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.sink.OutputCodecContext;
import org.opensearch.dataprepper.model.sink.SinkContext;
import org.opensearch.dataprepper.model.types.ByteCount;
import org.opensearch.dataprepper.plugins.lambda.common.LambdaCommonHandler;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBuffer;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.InMemoryBufferSynchronized;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
@@ -34,22 +36,32 @@

import java.lang.reflect.Field;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;


import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyDouble;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -377,6 +389,80 @@ public void testHandleFailure_WithoutDlq() throws Exception {
verify(dlqPushHandler, never()).perform(anyList());
}

@Test
void testFlushDueToTimeoutWhenSecondCallHasNoRecords() throws Exception {
// Set the maxCollectTime to a very short duration so that the buffer will be considered timed-out.
setPrivateField(lambdaSink, "maxCollectTime", Duration.ofMillis(1));

// Create a spy on lambdaSink to intercept flushBuffers() calls.
LambdaSink spySink = spy(lambdaSink);

// Use an AtomicBoolean flag to record whether flushBuffers() is called.
AtomicBoolean flushCalled = new AtomicBoolean(false);
doAnswer(invocation -> {
flushCalled.set(true);
return null;
}).when(spySink).flushBuffers(anyList());

// First call: add one record so that the buffer is non-empty.
List<Record<Event>> records = Collections.singletonList(new Record<>(mock(Event.class)));
spySink.doOutput(records);

// Wait briefly to allow the buffer's duration to exceed maxCollectTime.
Thread.sleep(10);

// Second call: pass an empty collection; this should trigger the timeout flush.
spySink.doOutput(Collections.emptyList());

// Verify that flushBuffers() was called due to the timeout.
assertTrue(flushCalled.get(), "Expected flushBuffers() to be called on the second call due to timeout.");
}

@Test
void testConcurrentDoOutputWithMultipleThreads() throws Exception {
// Set maxCollectTime to a very short duration to force the timeout flush.
setPrivateField(lambdaSink, "maxCollectTime", Duration.ofMillis(1));

// Create a spy of the sink so we can intercept flushBuffers() calls.
LambdaSink spySink = spy(lambdaSink);

// Use an AtomicInteger to count how many times flushBuffers() is called.
AtomicInteger flushCount = new AtomicInteger(0);
doAnswer(invocation -> {
flushCount.incrementAndGet();
return null;
}).when(spySink).flushBuffers(anyList());

// Create a thread pool with multiple threads.
int threadCount = 20;
ExecutorService executor = Executors.newFixedThreadPool(threadCount);
int iterations = 1000;
List<Future<?>> futures = new ArrayList<>();

// Each task calls doOutput() with one record.
for (int i = 0; i < iterations; i++) {
futures.add(executor.submit(() -> {
spySink.doOutput(Collections.singletonList(new Record<>(mock(Event.class))));
}));
}

// Wait for all tasks to complete.
for (Future<?> future : futures) {
future.get();
}

// Additionally, call doOutput() with an empty collection to trigger any pending timeout flush.
spySink.doOutput(Collections.emptyList());

executor.shutdown();

        // Assert that flushBuffers() was called at least once under concurrent load.
        // With a very short timeout the partial buffer eventually becomes "old" and is
        // flushed; if the locking works correctly, flushCount should be > 0.
assertTrue(flushCount.get() > 0, "Expected at least one flush due to timeout under concurrent calls, but flushCount is " + flushCount.get());
}



// Utility to set private fields
private static void setPrivateField(Object target, String fieldName, Object value) {
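The body of setPrivateField is collapsed in this view; a conventional implementation, using the java.lang.reflect.Field import already present in this file, would look like the following sketch (the commit's actual body may differ):

    // Sketch only - the real body is collapsed in the diff above.
    private static void setPrivateField(Object target, String fieldName, Object value) {
        try {
            Field field = target.getClass().getDeclaredField(fieldName);
            field.setAccessible(true);   // bypass private access for test setup
            field.set(target, value);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }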
