Skip to content

Commit

Permalink
Merge branch 'main' into lambda-sink-stateful
Browse files Browse the repository at this point in the history
Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg authored Jan 31, 2025
2 parents 3d9fb37 + 414a7c9 commit 0795cdc
Show file tree
Hide file tree
Showing 30 changed files with 785 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
Expand All @@ -42,24 +45,34 @@
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec;
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CountingRetryCondition;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -373,4 +386,143 @@ private List<Record<Event>> createRecords(int numRecords) {
}
return records;
}

@Test
void testRetryLogicWithThrottlingUsingMultipleThreads() throws Exception {
/*
* This test tries to create multiple parallel Lambda invocations
* while concurrency=1. The first invocation "occupies" the single concurrency slot
* The subsequent invocations should then get a 429 TooManyRequestsException,
* triggering our CountingRetryCondition.
*/

/* Lambda handler function looks like this:
def lambda_handler(event, context):
# Simulate a slow operation so that
# if concurrency = 1, multiple parallel invocations
# will result in TooManyRequestsException for the second+ invocation.
time.sleep(10)
# Return a simple success response
return {
"statusCode": 200,
"body": "Hello from concurrency-limited Lambda!"
}
*/

functionName = "lambdaExceptionSimulation";
// Create a CountingRetryCondition
CountingRetryCondition countingRetryCondition = new CountingRetryCondition();

// Configure a LambdaProcessorConfig

// We'll set invocation type to RequestResponse
InvocationType invocationType = mock(InvocationType.class);
when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
when(lambdaProcessorConfig.getInvocationType()).thenReturn(invocationType);

when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName);
// If your code uses "responseEventsMatch", you can set it:
when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);

// Set up mock ClientOptions for concurrency + small retries
ClientOptions clientOptions = mock(ClientOptions.class);
when(clientOptions.getMaxConnectionRetries()).thenReturn(3); // up to 3 retries
when(clientOptions.getMaxConcurrency()).thenReturn(5);
when(clientOptions.getConnectionTimeout()).thenReturn(Duration.ofSeconds(5));
when(clientOptions.getApiCallTimeout()).thenReturn(Duration.ofSeconds(30));
when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions);

// AWS auth
AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(lambdaRegion));
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(role);
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(null);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(null);
when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);

// Setup the mock for getProvider
when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider);

// Mock the factory to inject our CountingRetryCondition into the LambdaAsyncClient
try (MockedStatic<LambdaClientFactory> mockedFactory = mockStatic(LambdaClientFactory.class)) {

LambdaAsyncClient clientWithCountingCondition = LambdaAsyncClient.builder()
.region(Region.of(lambdaRegion))
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(ClientOverrideConfiguration.builder()
.retryPolicy(
RetryPolicy.builder()
.retryCondition(countingRetryCondition)
.numRetries(3)
.build()
)
.build())
// netty concurrency = 5 to allow parallel requests
.httpClient(NettyNioAsyncHttpClient.builder()
.maxConcurrency(5)
.build())
.build();

mockedFactory.when(() ->
LambdaClientFactory.createAsyncLambdaClient(
any(AwsAuthenticationOptions.class),
any(AwsCredentialsSupplier.class),
any(ClientOptions.class)))
.thenReturn(clientWithCountingCondition);

// 7) Instantiate the real LambdaProcessor
when(pluginSetting.getName()).thenReturn("lambda-processor");
when(pluginSetting.getPipelineName()).thenReturn("test-pipeline");
lambdaProcessor = new LambdaProcessor(
pluginFactory,
pluginSetting,
lambdaProcessorConfig,
awsCredentialsSupplier,
expressionEvaluator
);

// Create multiple parallel tasks to call doExecute(...)
// Each doExecute() invocation sends records to Lambda in an async manner.
int parallelInvocations = 5;
ExecutorService executor = Executors.newFixedThreadPool(parallelInvocations);

List<Future<Collection<Record<Event>>>> futures = new ArrayList<>();
for (int i = 0; i < parallelInvocations; i++) {
// Each subset of records calls the processor
List<Record<Event>> records = createRecords(2);
Future<Collection<Record<Event>>> future = executor.submit(() -> {
return lambdaProcessor.doExecute(records);
});
futures.add(future);
}

// Wait for all tasks to complete
executor.shutdown();
boolean finishedInTime = executor.awaitTermination(5, TimeUnit.MINUTES);
if (!finishedInTime) {
throw new RuntimeException("Test timed out waiting for executor tasks to complete.");
}

// Check results or handle exceptions
for (Future<Collection<Record<Event>>> f : futures) {
try {
Collection<Record<Event>> out = f.get();
} catch (ExecutionException ee) {
// A 429 from AWS will be thrown as TooManyRequestsException
// If all retries failed, we might see an exception here.
}
}

// Finally, check that we had at least one retry
// If concurrency=1 is truly enforced, at least some calls should have gotten a 429
// -> triggered CountingRetryCondition
int retryCount = countingRetryCondition.getRetryCount();
assertTrue(
retryCount > 0,
"Should have at least one retry due to concurrency-based throttling (429)."
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ private LambdaCommonHandler() {
}

public static boolean isSuccess(InvokeResponse response) {
if(response == null) {
return false;
}
int statusCode = response.statusCode();
return statusCode >= 200 && statusCode < 300;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ public void addRecord(Record<Event> record) {
eventCount++;
}

void completeCodec() {
if (eventCount > 0) {
try {
requestCodec.complete(this.byteArrayOutputStream);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}

@Override
public List<Record<Event>> getRecords() {
return records;
Expand Down Expand Up @@ -101,11 +111,7 @@ public InvokeRequest getRequestPayload(String functionName, String invocationTyp
return null;
}

try {
requestCodec.complete(this.byteArrayOutputStream);
} catch (IOException e) {
throw new RuntimeException(e);
}
completeCodec();

SdkBytes payload = getPayload();
payloadRequestSize = payload.asByteArray().length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition;
import org.opensearch.dataprepper.plugins.metricpublisher.MicrometerMetricPublisher;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
Expand Down Expand Up @@ -48,13 +49,14 @@ private static ClientOverrideConfiguration createOverrideConfiguration(
.maxBackoffTime(clientOptions.getMaxBackoff())
.build();

final RetryPolicy retryPolicy = RetryPolicy.builder()
final RetryPolicy customRetryPolicy = RetryPolicy.builder()
.retryCondition(new CustomLambdaRetryCondition())
.numRetries(clientOptions.getMaxConnectionRetries())
.backoffStrategy(backoffStrategy)
.build();

return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.retryPolicy(customRetryPolicy)
.addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics))
.apiCallTimeout(clientOptions.getApiCallTimeout())
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.RetryPolicyContext;

import java.util.concurrent.atomic.AtomicInteger;

//Used ONLY for tests
public class CountingRetryCondition extends CustomLambdaRetryCondition {
private final AtomicInteger retryCount = new AtomicInteger(0);

@Override
public boolean shouldRetry(RetryPolicyContext context) {
boolean shouldRetry = super.shouldRetry(context);
if (shouldRetry) {
retryCount.incrementAndGet();
}
return shouldRetry;
}

public int getRetryCount() {
return retryCount.get();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.conditions.RetryCondition;
import software.amazon.awssdk.core.retry.RetryPolicyContext;

public class CustomLambdaRetryCondition implements RetryCondition {

@Override
public boolean shouldRetry(RetryPolicyContext context) {
Throwable exception = context.exception();
if (exception != null) {
return LambdaRetryStrategy.isRetryableException(exception);
}

return false;
}
}
Loading

0 comments on commit 0795cdc

Please sign in to comment.