Skip to content

Commit

Permalink
Add stateful buffer for lambda sink (#5354)
Browse files Browse the repository at this point in the history
* Fix merge conflict

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address concurrency/synchronization comment

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Fix InMemoryBufferSynchronized and Add IT

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address timeout threshold comment

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Add IT for timeout threshold

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Fix checkstyle

Signed-off-by: Srikanth Govindarajan <[email protected]>

---------

Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg authored Feb 11, 2025
1 parent b3b6c65 commit c2a0972
Show file tree
Hide file tree
Showing 9 changed files with 1,006 additions and 215 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,326 @@
package org.opensearch.dataprepper.plugins.lambda.sink;

import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.Timer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.MockedStatic;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.opensearch.dataprepper.aws.api.AwsCredentialsOptions;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.expression.ExpressionEvaluator;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.configuration.PluginSetting;
import org.opensearch.dataprepper.model.event.DefaultEventMetadata;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventMetadata;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.model.types.ByteCount;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
import org.opensearch.dataprepper.plugins.lambda.sink.dlq.DlqPushHandler;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.mockito.ArgumentMatchers.anyDouble;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory.convertToCredentialsOptions;

/**
* Demonstrates testing threshold-based partial flush logic in LambdaSink
*/
@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public class LambdaSinkIT {
private String lambdaRegion;
private String functionName;
private String roleArn;
private ClientOptions clientOptions;
private MockedStatic<PluginMetrics> pluginMetricsMock;
private MockedStatic<LambdaClientFactory> factoryMock;

// Mocks for config objects
@Mock
private LambdaSinkConfig lambdaSinkConfig;
@Mock
private BatchOptions batchOptions;
@Mock
private ThresholdOptions thresholdOptions;
@Mock
private AwsAuthenticationOptions awsAuthenticationOptions;
@Mock
private AwsCredentialsSupplier awsCredentialsSupplier;
@Mock
private PluginSetting pluginSetting;
@Mock
private PluginFactory pluginFactory;
@Mock
private ExpressionEvaluator expressionEvaluator;
@Mock
private DlqPushHandler dlqPushHandler;

// Mock pluginMetrics and the counters/timers
@Mock
private PluginMetrics pluginMetrics;
@Mock
private Counter numberOfRecordsSuccessCounter;
@Mock
private Counter numberOfRecordsFailedCounter;
@Mock
private Counter numberOfRequestsSuccessCounter;
@Mock
private Counter numberOfRequestsFailedCounter;

// The sink under test
private LambdaSink lambdaSink;

@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);

lambdaRegion = System.getProperty("tests.lambda.sink.region", "us-east-1");
functionName = System.getProperty("tests.lambda.sink.functionName", "testFunctionName");
roleArn = System.getProperty("tests.lambda.sink.sts_role_arn", "someRole");

// Mock pluginSetting
when(pluginSetting.getName()).thenReturn("aws_lambda");
when(pluginSetting.getPipelineName()).thenReturn("lambdaSinkITMultiBatch");

// Configure pluginMetrics so that each named counter/timer returns a mock
when(pluginMetrics.counter(LambdaSink.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_SUCCESS))
.thenReturn(numberOfRecordsSuccessCounter);
when(pluginMetrics.counter(LambdaSink.NUMBER_OF_RECORDS_FLUSHED_TO_LAMBDA_FAILED))
.thenReturn(numberOfRecordsFailedCounter);
when(pluginMetrics.counter(LambdaSink.NUMBER_OF_SUCCESSFUL_REQUESTS_TO_LAMBDA))
.thenReturn(numberOfRequestsSuccessCounter);
when(pluginMetrics.counter(LambdaSink.NUMBER_OF_FAILED_REQUESTS_TO_LAMBDA))
.thenReturn(numberOfRequestsFailedCounter);
Timer genericTimer = mock(Timer.class);
DistributionSummary genericSummary = mock(DistributionSummary.class);
when(pluginMetrics.timer(anyString())).thenReturn(genericTimer);
when(pluginMetrics.summary(anyString())).thenReturn(genericSummary);

// Threshold config
when(thresholdOptions.getEventCount()).thenReturn(5);
when(thresholdOptions.getMaximumSize()).thenReturn(ByteCount.parse("1mb"));
when(thresholdOptions.getEventCollectTimeOut()).thenReturn(Duration.ofSeconds(120));

// BatchOptions
when(batchOptions.getKeyName()).thenReturn("lambdaSinkITKey");
when(batchOptions.getThresholdOptions()).thenReturn(thresholdOptions);

// AWS Auth
Region regionObj = Region.of(lambdaRegion);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(regionObj);
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(roleArn);

// LambdaSinkConfig
when(lambdaSinkConfig.getFunctionName()).thenReturn(functionName);
when(lambdaSinkConfig.getInvocationType()).thenReturn(InvocationType.REQUEST_RESPONSE);
when(lambdaSinkConfig.getBatchOptions()).thenReturn(batchOptions);
when(lambdaSinkConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
when(lambdaSinkConfig.getDlqPluginSetting()).thenReturn(null);
clientOptions = new ClientOptions();
when(lambdaSinkConfig.getClientOptions()).thenReturn(clientOptions);

// Now mock the static method PluginMetrics.fromPluginSetting(...) from AbstractSink
pluginMetricsMock = mockStatic(PluginMetrics.class);
pluginMetricsMock.when(() -> PluginMetrics.fromPluginSetting(pluginSetting))
.thenReturn(pluginMetrics);

factoryMock = mockStatic(LambdaClientFactory.class);
createLambdaClient(factoryMock, clientOptions);

lambdaSink = objectUnderTest();
lambdaSink.doInitialize();
}

@AfterEach
void tearDown() {
pluginMetricsMock.close();
factoryMock.close();
}

private LambdaSink objectUnderTest() {
// Build the sink
lambdaSink = new LambdaSink(
pluginSetting,
lambdaSinkConfig,
pluginFactory,
null, // real or mock SinkContext
awsCredentialsSupplier,
expressionEvaluator
);
return lambdaSink;
}

private void createLambdaClient(MockedStatic<LambdaClientFactory> factoryMock, ClientOptions clientOptions) {

// Tell the mock to skip attaching MicrometerMetricPublisher:
factoryMock.when(() ->
LambdaClientFactory.createAsyncLambdaClient(eq(awsAuthenticationOptions),
eq(awsCredentialsSupplier), eq(clientOptions))
).thenAnswer(inv -> {
// Build a normal client but omit .addMetricPublisher(...)
// or just return a fully mock client.

// a) If you want a near-real client
NettyNioAsyncHttpClient httpClient = (NettyNioAsyncHttpClient) NettyNioAsyncHttpClient.builder()
.maxConcurrency(clientOptions.getMaxConcurrency())
.connectionTimeout(clientOptions.getConnectionTimeout())
.build();

ClientOverrideConfiguration overrideConfig = ClientOverrideConfiguration.builder()
// .addMetricPublisher(...) => skip
.apiCallTimeout(clientOptions.getApiCallTimeout())
.build();

AwsCredentialsOptions awsCredentialsOptions = convertToCredentialsOptions(awsAuthenticationOptions);
AwsCredentialsProvider awsCredentialsProvider = awsCredentialsSupplier.getProvider(awsCredentialsOptions);

return LambdaAsyncClient.builder()
.region(awsAuthenticationOptions.getAwsRegion())
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(overrideConfig)
.httpClient(httpClient)
.build();
});
}

@Test
void testMultiBatchPartialAndFullFlushes() {
// threshold=5
// doOutput(3) => partial => no success
List<Record<Event>> part1 = createEvents(3, "Batch1");
lambdaSink.doOutput(part1);

verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble());
verify(numberOfRequestsSuccessCounter, never()).increment();

// doOutput(3) => total=6 => flush=5 => leftover=1 => success=5
List<Record<Event>> part2 = createEvents(3, "Batch2");
lambdaSink.doOutput(part2);

verify(numberOfRecordsSuccessCounter).increment(5.0);
verify(numberOfRequestsSuccessCounter).increment();

// leftover=1

// doOutput(4) => leftover(1)+4=5 => flush => success=5 => total=10
List<Record<Event>> part3 = createEvents(4, "Batch3");
lambdaSink.doOutput(part3);

verify(numberOfRecordsSuccessCounter, times(2)).increment(5.0);
verify(numberOfRequestsSuccessCounter, times(2)).increment();
}

@Test
void testFinalPartialFlushOnShutdown() {

// doOutput(3) => partial => success=0
List<Record<Event>> smallList = createEvents(3, "PartialShutdown");
lambdaSink.doOutput(smallList);

verify(numberOfRecordsSuccessCounter, never()).increment(anyDouble());

// shutdown => leftover=3 => flush => success=3
lambdaSink.shutdown();
verify(numberOfRecordsSuccessCounter).increment(3.0);
verify(numberOfRequestsSuccessCounter).increment();
}

@Test
void testSingleBatchFlushExceedThreshold() {
// pass 6 => threshold=5 => flush=5 => leftover=1 => success=5
List<Record<Event>> events = createEvents(6, "SingleBatch");
lambdaSink.doOutput(events);

verify(numberOfRecordsSuccessCounter).increment(5.0);
verify(numberOfRequestsSuccessCounter).increment();

lambdaSink.shutdown();
// leftover=1 => flush => success=1 => total=6
verify(numberOfRecordsSuccessCounter).increment(1.0);
verify(numberOfRequestsSuccessCounter, times(2)).increment();
}

@Test
void testTimeBasedThresholdFlush() throws InterruptedException {
// Send 3 events (below the event count threshold)
List<Record<Event>> events = createEvents(3, "TimeBatch1");
lambdaSink.doOutput(events);

// Wait for slightly less than the timeout
Thread.sleep(400);

// Send 2 more events
events = createEvents(2, "TimeBatch2");
lambdaSink.doOutput(events);

// Wait for the timeout to be exceeded
Thread.sleep(200);

// Send an empty batch to trigger the time-based flush
lambdaSink.doOutput(Collections.emptyList());

// Verify that 5 events were flushed due to time-based threshold
verify(numberOfRecordsSuccessCounter).increment(5.0);
verify(numberOfRequestsSuccessCounter).increment();

// Send 1 more event
events = createEvents(1, "TimeBatch3");
lambdaSink.doOutput(events);

// Shutdown to flush any remaining events
lambdaSink.shutdown();

// Verify that the final event was flushed
verify(numberOfRecordsSuccessCounter).increment(1.0);
verify(numberOfRequestsSuccessCounter, times(2)).increment();
}

private List<Record<Event>> createEvents(int count, String prefix) {
List<Record<Event>> records = new ArrayList<>(count);
for (int i = 0; i < count; i++) {
Map<String, Object> data = Map.of("id", i, "msg", prefix + i);
EventMetadata metadata = DefaultEventMetadata.builder()
.withEventType("ITTest")
.build();
Event event = JacksonEvent.builder()
.withData(data)
.withEventMetadata(metadata)
.build();
records.add(new Record<>(event));
}
return records;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ public static Map<Buffer, CompletableFuture<InvokeResponse>> sendRecords(
List<Buffer> batchedBuffers = createBufferBatches(records, config.getBatchOptions(),
outputCodecContext);

Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = invokeLambdaAndGetFutureMap(config, lambdaAsyncClient, batchedBuffers);
return bufferToFutureMap;
}

public static Map<Buffer, CompletableFuture<InvokeResponse>> invokeLambdaAndGetFutureMap(LambdaCommonConfig config, LambdaAsyncClient lambdaAsyncClient, List<Buffer> batchedBuffers) {
Map<Buffer, CompletableFuture<InvokeResponse>> bufferToFutureMap = new HashMap<>();
LOG.debug("Batch Chunks created after threshold check: {}", batchedBuffers.size());
for (Buffer buffer : batchedBuffers) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@

/**
* A buffer can hold in memory data and flushing it.
* NOT Thread-safe by default
*/
public class InMemoryBuffer implements Buffer {

private final ByteArrayOutputStream byteArrayOutputStream;

private final List<Record<Event>> records;
protected List<Record<Event>> records;
private final StopWatch bufferWatch;
private final StopWatch lambdaLatencyWatch;
private final OutputCodec requestCodec;
Expand Down Expand Up @@ -58,6 +59,7 @@ public InMemoryBuffer(String batchOptionKeyName, OutputCodecContext outputCodecC
this.outputCodecContext = outputCodecContext;
}

@Override
public void addRecord(Record<Event> record) {
records.add(record);
Event event = record.getData();
Expand All @@ -72,16 +74,7 @@ public void addRecord(Record<Event> record) {
eventCount++;
}

void completeCodec() {
if (eventCount > 0) {
try {
requestCodec.complete(this.byteArrayOutputStream);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

@Override
public List<Record<Event>> getRecords() {
return records;
}
Expand All @@ -108,7 +101,11 @@ public InvokeRequest getRequestPayload(String functionName, String invocationTyp
return null;
}

completeCodec();
try {
requestCodec.complete(this.byteArrayOutputStream);
} catch (IOException e) {
throw new RuntimeException(e);
}

SdkBytes payload = getPayload();
payloadRequestSize = payload.asByteArray().length;
Expand Down
Loading

0 comments on commit c2a0972

Please sign in to comment.