Skip to content

Commit

Permalink
lambda processor should retry for certain class of exceptions (#5320)
Browse files Browse the repository at this point in the history
* lambda processor should retry for certain class of exceptions

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address Comment on complete codec

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Add retryCondition to Lambda client

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address comments

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address comments and add UT and IT

Signed-off-by: Srikanth Govindarajan <[email protected]>

* Address comment on completeCodec

Signed-off-by: Srikanth Govindarajan <[email protected]>

---------

Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg authored Jan 29, 2025
1 parent 8152c57 commit 70e8c8b
Show file tree
Hide file tree
Showing 13 changed files with 718 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
Expand All @@ -42,24 +45,34 @@
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec;
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CountingRetryCondition;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -373,4 +386,143 @@ private List<Record<Event>> createRecords(int numRecords) {
}
return records;
}

/**
 * Integration test: fires several parallel {@code doExecute} calls at a Lambda function
 * and asserts that the {@link CountingRetryCondition} injected into the client recorded
 * at least one retry.
 *
 * <p>The test relies on the target function being deployed with concurrency = 1 so that
 * the first invocation occupies the single slot and later invocations receive a 429
 * (TooManyRequestsException), which is expected to trigger the retry condition.
 *
 * <p>The Lambda handler is expected to look like this:
 * <pre>
 * def lambda_handler(event, context):
 *     # Simulate a slow operation so that, with concurrency = 1, multiple parallel
 *     # invocations result in TooManyRequestsException for the second+ invocation.
 *     time.sleep(10)
 *     # Return a simple success response
 *     return {
 *         "statusCode": 200,
 *         "body": "Hello from concurrency-limited Lambda!"
 *     }
 * </pre>
 *
 * @throws Exception if the test is interrupted or times out while waiting for the workers
 */
@Test
void testRetryLogicWithThrottlingUsingMultipleThreads() throws Exception {
    functionName = "lambdaExceptionSimulation";

    // Retry condition that counts how many times a retry was approved.
    final CountingRetryCondition countingRetryCondition = new CountingRetryCondition();

    // Configure the LambdaProcessorConfig mock; invocation type is RequestResponse.
    InvocationType invocationType = mock(InvocationType.class);
    when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
    when(lambdaProcessorConfig.getInvocationType()).thenReturn(invocationType);

    when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName);
    when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);

    // Mock ClientOptions: parallel connections allowed, small number of retries.
    ClientOptions clientOptions = mock(ClientOptions.class);
    when(clientOptions.getMaxConnectionRetries()).thenReturn(3); // up to 3 retries
    when(clientOptions.getMaxConcurrency()).thenReturn(5);
    when(clientOptions.getConnectionTimeout()).thenReturn(Duration.ofSeconds(5));
    when(clientOptions.getApiCallTimeout()).thenReturn(Duration.ofSeconds(30));
    when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions);

    // AWS auth
    AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class);
    when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(lambdaRegion));
    when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(role);
    when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(null);
    when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(null);
    when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);

    // Credentials supplier hands back the test's credentials provider.
    when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider);

    // Mock the factory so the processor receives a client built with our CountingRetryCondition.
    try (MockedStatic<LambdaClientFactory> mockedFactory = mockStatic(LambdaClientFactory.class)) {

        LambdaAsyncClient clientWithCountingCondition = LambdaAsyncClient.builder()
                .region(Region.of(lambdaRegion))
                .credentialsProvider(awsCredentialsProvider)
                .overrideConfiguration(ClientOverrideConfiguration.builder()
                        .retryPolicy(
                                RetryPolicy.builder()
                                        .retryCondition(countingRetryCondition)
                                        .numRetries(3)
                                        .build()
                        )
                        .build())
                // netty concurrency = 5 to allow parallel requests
                .httpClient(NettyNioAsyncHttpClient.builder()
                        .maxConcurrency(5)
                        .build())
                .build();

        mockedFactory.when(() ->
                        LambdaClientFactory.createAsyncLambdaClient(
                                any(AwsAuthenticationOptions.class),
                                any(AwsCredentialsSupplier.class),
                                any(ClientOptions.class)))
                .thenReturn(clientWithCountingCondition);

        // Instantiate the real LambdaProcessor
        when(pluginSetting.getName()).thenReturn("lambda-processor");
        when(pluginSetting.getPipelineName()).thenReturn("test-pipeline");
        lambdaProcessor = new LambdaProcessor(
                pluginFactory,
                pluginSetting,
                lambdaProcessorConfig,
                awsCredentialsSupplier,
                expressionEvaluator
        );

        // Create multiple parallel tasks, each calling doExecute(...) with its own records.
        int parallelInvocations = 5;
        ExecutorService executor = Executors.newFixedThreadPool(parallelInvocations);
        List<Future<Collection<Record<Event>>>> futures = new ArrayList<>();
        try {
            for (int i = 0; i < parallelInvocations; i++) {
                List<Record<Event>> records = createRecords(2);
                Future<Collection<Record<Event>>> future = executor.submit(() -> {
                    return lambdaProcessor.doExecute(records);
                });
                futures.add(future);
            }

            // Wait for all tasks to complete
            executor.shutdown();
            boolean finishedInTime = executor.awaitTermination(5, TimeUnit.MINUTES);
            if (!finishedInTime) {
                throw new RuntimeException("Test timed out waiting for executor tasks to complete.");
            }
        } finally {
            // No-op after a clean termination; on timeout or failure it stops the
            // remaining workers so their threads do not outlive the test.
            executor.shutdownNow();
        }

        // Drain every future so all tasks are known to have finished before asserting.
        for (Future<Collection<Record<Event>>> f : futures) {
            try {
                f.get();
            } catch (ExecutionException ignored) {
                // Deliberately ignored: a 429 from AWS will be thrown as
                // TooManyRequestsException, and if all retries failed it may surface here.
                // The test only asserts on the retry count below.
            }
        }

        // If concurrency=1 is truly enforced, at least some calls should have gotten a 429
        // and therefore triggered the CountingRetryCondition.
        int retryCount = countingRetryCondition.getRetryCount();
        assertTrue(
                retryCount > 0,
                "Should have at least one retry due to concurrency-based throttling (429)."
        );
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ private LambdaCommonHandler() {
}

/**
 * Reports whether a Lambda invocation response carries a 2xx status code.
 *
 * @param response the invocation response; may be {@code null}
 * @return {@code true} when {@code response} is non-null and its status code is in
 *         the range 200 (inclusive) to 300 (exclusive); {@code false} otherwise
 */
public static boolean isSuccess(InvokeResponse response) {
    if (response != null) {
        final int code = response.statusCode();
        return 200 <= code && code <= 299;
    }
    return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ public void addRecord(Record<Event> record) {
eventCount++;
}

/**
 * Completes the request codec against this instance's byte-array output stream,
 * but only when at least one event has been counted.
 *
 * @throws RuntimeException wrapping the {@link IOException} raised by the codec
 */
void completeCodec() {
    if (eventCount <= 0) {
        // Nothing was counted, so the codec is left untouched.
        return;
    }
    try {
        requestCodec.complete(byteArrayOutputStream);
    } catch (IOException ioException) {
        throw new RuntimeException(ioException);
    }
}

/**
 * Returns the {@code records} list held by this instance.
 * The list is handed back directly, not copied.
 *
 * @return the records list
 */
public List<Record<Event>> getRecords() {
    return this.records;
}
Expand All @@ -98,11 +108,7 @@ public InvokeRequest getRequestPayload(String functionName, String invocationTyp
return null;
}

try {
requestCodec.complete(this.byteArrayOutputStream);
} catch (IOException e) {
throw new RuntimeException(e);
}
completeCodec();

SdkBytes payload = getPayload();
payloadRequestSize = payload.asByteArray().length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition;
import org.opensearch.dataprepper.plugins.metricpublisher.MicrometerMetricPublisher;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
Expand Down Expand Up @@ -48,13 +49,14 @@ private static ClientOverrideConfiguration createOverrideConfiguration(
.maxBackoffTime(clientOptions.getMaxBackoff())
.build();

final RetryPolicy retryPolicy = RetryPolicy.builder()
final RetryPolicy customRetryPolicy = RetryPolicy.builder()
.retryCondition(new CustomLambdaRetryCondition())
.numRetries(clientOptions.getMaxConnectionRetries())
.backoffStrategy(backoffStrategy)
.build();

return ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.retryPolicy(customRetryPolicy)
.addMetricPublisher(new MicrometerMetricPublisher(awsSdkMetrics))
.apiCallTimeout(clientOptions.getApiCallTimeout())
.build();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.RetryPolicyContext;

import java.util.concurrent.atomic.AtomicInteger;

/**
 * Used ONLY for tests.
 *
 * <p>Wraps {@link CustomLambdaRetryCondition} and counts how many times the parent
 * decided that a retry should happen.
 */
public class CountingRetryCondition extends CustomLambdaRetryCondition {
    /** Number of {@code shouldRetry} calls that returned {@code true}. */
    private final AtomicInteger retryCount = new AtomicInteger();

    /**
     * Delegates the decision to the parent condition and bumps the counter
     * whenever the parent approves a retry.
     *
     * @param context the retry policy context passed in by the SDK
     * @return the parent's decision, unchanged
     */
    @Override
    public boolean shouldRetry(RetryPolicyContext context) {
        if (!super.shouldRetry(context)) {
            return false;
        }
        retryCount.incrementAndGet();
        return true;
    }

    /**
     * @return how many retries have been approved so far
     */
    public int getRetryCount() {
        return retryCount.get();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.conditions.RetryCondition;
import software.amazon.awssdk.core.retry.RetryPolicyContext;

/**
 * Retry condition that hands the exception carried by the retry context to
 * {@code LambdaRetryStrategy.isRetryableException} and returns its verdict.
 */
public class CustomLambdaRetryCondition implements RetryCondition {

    /**
     * @param context the retry policy context passed in by the SDK
     * @return {@code false} when the context carries no exception; otherwise the
     *         result of {@code LambdaRetryStrategy.isRetryableException} for it
     */
    @Override
    public boolean shouldRetry(RetryPolicyContext context) {
        final Throwable failure = context.exception();
        return failure != null && LambdaRetryStrategy.isRetryableException(failure);
    }
}
Loading

0 comments on commit 70e8c8b

Please sign in to comment.