Skip to content

Commit

Permalink
Address comments and add UT and IT
Browse files Browse the repository at this point in the history
Signed-off-by: Srikanth Govindarajan <[email protected]>
  • Loading branch information
srikanthjg committed Jan 28, 2025
1 parent 4c391d9 commit 1f4c5fa
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
Expand All @@ -42,35 +45,34 @@
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodec;
import org.opensearch.dataprepper.plugins.codec.json.JsonInputCodecConfig;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.client.LambdaClientFactory;
import org.opensearch.dataprepper.plugins.lambda.common.config.AwsAuthenticationOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.BatchOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.ClientOptions;
import org.opensearch.dataprepper.plugins.lambda.common.config.InvocationType;
import org.opensearch.dataprepper.plugins.lambda.common.config.ThresholdOptions;
import org.opensearch.dataprepper.plugins.lambda.common.util.CustomLambdaRetryCondition;
import org.opensearch.dataprepper.plugins.lambda.utils.CountingHttpClient;
import org.opensearch.dataprepper.plugins.lambda.common.util.CountingRetryCondition;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.retry.RetryPolicy;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeResponse;
import software.amazon.awssdk.services.lambda.model.TooManyRequestsException;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

@ExtendWith(MockitoExtension.class)
Expand Down Expand Up @@ -108,6 +110,10 @@ public void setup() {
lambdaRegion = System.getProperty("tests.lambda.processor.region");
functionName = System.getProperty("tests.lambda.processor.functionName");
role = System.getProperty("tests.lambda.processor.sts_role_arn");
lambdaRegion = "us-west-2";
functionName = "lambdaNoReturn";
role = "arn:aws:iam::176893235612:role/osis-s3-opensearch-role";


pluginMetrics = mock(PluginMetrics.class);
pluginSetting = mock(PluginSetting.class);
Expand Down Expand Up @@ -385,79 +391,142 @@ private List<Record<Event>> createRecords(int numRecords) {
return records;
}

/*
* For this test, set concurrency limit to 1
*/
@Test
void testTooManyRequestsExceptionWithCustomRetryCondition() {
//Note lambda function for this test looks like this:
/*def lambda_handler(event, context):
# Simulate a slow operation so that
# if concurrency = 1, multiple parallel invocations
# will result in TooManyRequestsException for the second+ invocation.
time.sleep(10)
# Return a simple success response
return {
"statusCode": 200,
"body": "Hello from concurrency-limited Lambda!"
}
*/
void testRetryLogicWithThrottlingUsingMultipleThreads() throws Exception {
/*
* This test tries to create multiple parallel Lambda invocations
* while concurrency=1. The first invocation "occupies" the single concurrency slot
* The subsequent invocations should then get a 429 TooManyRequestsException,
* triggering our CountingRetryCondition.
*/

/* Lambda handler function looks like this:
def lambda_handler(event, context):
# Simulate a slow operation so that
# if concurrency = 1, multiple parallel invocations
# will result in TooManyRequestsException for the second+ invocation.
time.sleep(10)
# Return a simple success response
return {
"statusCode": 200,
"body": "Hello from concurrency-limited Lambda!"
}
// Wrap the default HTTP client to count requests
CountingHttpClient countingHttpClient = new CountingHttpClient(
NettyNioAsyncHttpClient.builder().build()
);
*/

// Configure a custom retry policy with 3 retries and your custom condition
RetryPolicy retryPolicy = RetryPolicy.builder()
.numRetries(3)
.retryCondition(new CustomLambdaRetryCondition())
.build();
functionName = "lambdaExceptionSimulation";
// Create a CountingRetryCondition
CountingRetryCondition countingRetryCondition = new CountingRetryCondition();

// Build the real Lambda client
LambdaAsyncClient client = LambdaAsyncClient.builder()
.overrideConfiguration(
ClientOverrideConfiguration.builder()
.retryPolicy(retryPolicy)
.build()
)
.region(Region.of(lambdaRegion))
.httpClient(countingHttpClient)
.build();
// Configure a LambdaProcessorConfig

// We'll set invocation type to RequestResponse
InvocationType invocationType = mock(InvocationType.class);
when(invocationType.getAwsLambdaValue()).thenReturn(InvocationType.REQUEST_RESPONSE.getAwsLambdaValue());
when(lambdaProcessorConfig.getInvocationType()).thenReturn(invocationType);

when(lambdaProcessorConfig.getFunctionName()).thenReturn(functionName);
// If your code uses "responseEventsMatch", you can set it:
when(lambdaProcessorConfig.getResponseEventsMatch()).thenReturn(true);

// Parallel invocations to force concurrency=1 to throw TooManyRequestsException
int parallelInvocations = 10;
CompletableFuture<?>[] futures = new CompletableFuture[parallelInvocations];
for (int i = 0; i < parallelInvocations; i++) {
InvokeRequest request = InvokeRequest.builder()
.functionName(functionName)
// Set up mock ClientOptions for concurrency + small retries
ClientOptions clientOptions = mock(ClientOptions.class);
when(clientOptions.getMaxConnectionRetries()).thenReturn(3); // up to 3 retries
when(clientOptions.getMaxConcurrency()).thenReturn(5);
when(clientOptions.getConnectionTimeout()).thenReturn(Duration.ofSeconds(5));
when(clientOptions.getApiCallTimeout()).thenReturn(Duration.ofSeconds(30));
when(lambdaProcessorConfig.getClientOptions()).thenReturn(clientOptions);

// AWS auth
AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class);
when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.of(lambdaRegion));
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn(role);
when(awsAuthenticationOptions.getAwsStsExternalId()).thenReturn(null);
when(awsAuthenticationOptions.getAwsStsHeaderOverrides()).thenReturn(null);
when(lambdaProcessorConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);

// Setup the mock for getProvider
when(awsCredentialsSupplier.getProvider(any())).thenReturn(awsCredentialsProvider);

// Mock the factory to inject our CountingRetryCondition into the LambdaAsyncClient
try (MockedStatic<LambdaClientFactory> mockedFactory = mockStatic(LambdaClientFactory.class)) {

LambdaAsyncClient clientWithCountingCondition = LambdaAsyncClient.builder()
.region(Region.of(lambdaRegion))
.credentialsProvider(awsCredentialsProvider)
.overrideConfiguration(ClientOverrideConfiguration.builder()
.retryPolicy(
RetryPolicy.builder()
.retryCondition(countingRetryCondition)
.numRetries(3)
.build()
)
.build())
// netty concurrency = 5 to allow parallel requests
.httpClient(NettyNioAsyncHttpClient.builder()
.maxConcurrency(5)
.build())
.build();

futures[i] = client.invoke(request);
}
mockedFactory.when(() ->
LambdaClientFactory.createAsyncLambdaClient(
any(AwsAuthenticationOptions.class),
any(AwsCredentialsSupplier.class),
any(ClientOptions.class)))
.thenReturn(clientWithCountingCondition);

// 7) Instantiate the real LambdaProcessor
when(pluginSetting.getName()).thenReturn("lambda-processor");
when(pluginSetting.getPipelineName()).thenReturn("test-pipeline");
lambdaProcessor = new LambdaProcessor(
pluginFactory,
pluginSetting,
lambdaProcessorConfig,
awsCredentialsSupplier,
expressionEvaluator
);

// Create multiple parallel tasks to call doExecute(...)
// Each doExecute() invocation sends records to Lambda in an async manner.
int parallelInvocations = 5;
ExecutorService executor = Executors.newFixedThreadPool(parallelInvocations);

List<Future<Collection<Record<Event>>>> futures = new ArrayList<>();
for (int i = 0; i < parallelInvocations; i++) {
// Each subset of records calls the processor
List<Record<Event>> records = createRecords(2);
Future<Collection<Record<Event>>> future = executor.submit(() -> {
return lambdaProcessor.doExecute(records);
});
futures.add(future);
}

// 5) Wait for all to complete
CompletableFuture.allOf(futures).join();

// 6) Check how many had TooManyRequestsException
long tooManyRequestsCount = Arrays.stream(futures)
.filter(f -> {
try {
f.join();
return false; // no error => no TMR
} catch (CompletionException e) {
return e.getCause() instanceof TooManyRequestsException;
}
})
.count();

// 7) Observe how many total network requests occurred (including SDK retries)
int totalRequests = countingHttpClient.getRequestCount();

// Optionally: If you want to confirm the EXACT number,
// this might vary depending on how many parallel calls and how your TMR throttles them.
// For example, if all 5 calls are blocked, you might see 5*(numRetries + 1) in worst case.
assertTrue(totalRequests >= parallelInvocations,
"Should be at least one request per initial invocation, plus retries.");
// Wait for all tasks to complete
executor.shutdown();
boolean finishedInTime = executor.awaitTermination(5, TimeUnit.MINUTES);
if (!finishedInTime) {
throw new RuntimeException("Test timed out waiting for executor tasks to complete.");
}

// Check results or handle exceptions
for (Future<Collection<Record<Event>>> f : futures) {
try {
Collection<Record<Event>> out = f.get();
} catch (ExecutionException ee) {
// A 429 from AWS will be thrown as TooManyRequestsException
// If all retries failed, we might see an exception here.
}
}

// 11) Finally, check that we had at least one retry
// If concurrency=1 is truly enforced, at least some calls should have gotten a 429
// -> triggered CountingRetryCondition
int retryCount = countingRetryCondition.getRetryCount();
assertTrue(
retryCount > 0,
"Should have at least one retry due to concurrency-based throttling (429)."
);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.opensearch.dataprepper.plugins.lambda.common.util;

import software.amazon.awssdk.core.retry.RetryPolicyContext;

import java.util.concurrent.atomic.AtomicInteger;

/**
 * Test-only retry condition that records how many retries were approved.
 *
 * <p>The retry decision itself is delegated unchanged to
 * {@link CustomLambdaRetryCondition}; this subclass only tallies the calls for
 * which the parent answered {@code true}, so a test can assert that at least
 * one retry was triggered.
 */
public class CountingRetryCondition extends CustomLambdaRetryCondition {
    /** Number of times the parent condition has approved a retry so far. */
    private final AtomicInteger approvedRetries = new AtomicInteger();

    /**
     * Delegates the decision to the parent condition and counts approvals.
     *
     * @param context the retry context supplied by the SDK for the failed attempt
     * @return the parent's decision, unmodified
     */
    @Override
    public boolean shouldRetry(RetryPolicyContext context) {
        // Rejected attempts are not counted; only approved retries bump the tally.
        if (!super.shouldRetry(context)) {
            return false;
        }
        approvedRetries.incrementAndGet();
        return true;
    }

    /**
     * @return how many retries have been approved since this instance was created
     */
    public int getRetryCount() {
        return approvedRetries.get();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import software.amazon.awssdk.services.lambda.model.TooManyRequestsException;
import software.amazon.awssdk.services.lambda.model.ServiceException;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;


Expand All @@ -23,49 +21,41 @@ private LambdaRetryStrategy() {
/**
* Status codes for “bad request” style errors, i.e. failures attributed to the request itself.
*/
private static final Set<Integer> BAD_REQUEST_ERRORS = new HashSet<>(
Arrays.asList(
private static final Set<Integer> BAD_REQUEST_ERRORS = Set.of(
400, // Bad Request
422, // Unprocessable Entity
417, // Expectation Failed
406 // Not Acceptable
)
);

/**
* Status codes which may indicate a security or policy problem, so we don't retry.
*/
private static final Set<Integer> NOT_ALLOWED_ERRORS = new HashSet<>(
Arrays.asList(
private static final Set<Integer> NOT_ALLOWED_ERRORS = Set.of(
401, // Unauthorized
403, // Forbidden
405 // Method Not Allowed
)
);
);

/**
* Examples of input or payload errors that are likely not retryable
* unless the pipeline itself corrects them.
*/
private static final Set<Integer> INVALID_INPUT_ERRORS = new HashSet<>(
Arrays.asList(
private static final Set<Integer> INVALID_INPUT_ERRORS = Set.of(
413, // Payload Too Large
414, // URI Too Long
416 // Range Not Satisfiable
)
);
);

/**
* Example of a “timeout” scenario. Lambda can return 429 for "Too Many Requests" or
* 408 (if applicable) for timeouts in some contexts.
* This can be considered retryable if you want to handle the throttling scenario.
*/
private static final Set<Integer> TIMEOUT_ERRORS = new HashSet<>(
Arrays.asList(
private static final Set<Integer> TIMEOUT_ERRORS = Set.of(
408, // Request Timeout
429 // Too Many Requests (often used as "throttling" for Lambda)
)
);
);

public static boolean isRetryableStatusCode(final int statusCode) {
return TIMEOUT_ERRORS.contains(statusCode) || (statusCode >= 500 && statusCode < 600);
Expand Down
Loading

0 comments on commit 1f4c5fa

Please sign in to comment.