Skip to content

Commit

Permalink
added additional unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jeremy Michael <[email protected]>
  • Loading branch information
Jeremy Michael committed Dec 26, 2024
1 parent 29701d0 commit b103f7e
Show file tree
Hide file tree
Showing 9 changed files with 341 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,31 @@
import jakarta.validation.constraints.Size;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.regions.Region;

import java.util.Map;
import java.util.Optional;

public class AwsAuthenticationOptions {
private static final String AWS_IAM_ROLE = "role";
private static final String AWS_IAM = "iam";

@JsonProperty("region")
@Size(min = 1, message = "Region cannot be empty string")
private String awsRegion;

@JsonProperty("sts_role_arn")
@Size(min = 20, max = 2048, message = "awsStsRoleArn length should be between 1 and 2048 characters")
private String awsStsRoleArn;
protected String awsStsRoleArn;

@JsonProperty("sts_external_id")
@Size(min = 2, max = 1224, message = "awsStsExternalId length should be between 2 and 1224 characters")
private String awsStsExternalId;

@JsonProperty("sts_header_overrides")
@Size(max = 5, message = "sts_header_overrides supports a maximum of 5 headers to override")
private Map<String, String> awsStsHeaderOverrides;
private void validateStsRoleArn() {

void validateStsRoleArn() {
final Arn arn = getArn();
if (!AWS_IAM.equals(arn.service())) {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
Expand All @@ -43,27 +43,27 @@ private void validateStsRoleArn() {
throw new IllegalArgumentException("sts_role_arn must be an IAM Role");
}
}

private Arn getArn() {
try {
return Arn.fromString(awsStsRoleArn);
} catch (final Exception e) {
throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
}
}

public String getAwsStsRoleArn() {
return awsStsRoleArn;
}

public String getAwsStsExternalId() {
return awsStsExternalId;
}

public Region getAwsRegion() {
return awsRegion != null ? Region.of(awsRegion) : null;
}

public Map<String, String> getAwsStsHeaderOverrides() {
return awsStsHeaderOverrides;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public void handleMessage(final Message message,
}
}

private JsonNode parseMessageBody(String messageBody) {
JsonNode parseMessageBody(String messageBody) {
try {
return objectMapper.readTree(messageBody);
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import software.amazon.awssdk.arns.Arn;
import software.amazon.awssdk.regions.Region;

import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;

Expand All @@ -34,7 +38,7 @@ void getAwsRegion_returns_Region_of() throws NoSuchFieldException, IllegalAccess
final Region expectedRegionObject = mock(Region.class);
reflectivelySetField(awsAuthenticationOptions, "awsRegion", regionString);
final Region actualRegion;
try(final MockedStatic<Region> regionMockedStatic = mockStatic(Region.class)) {
try (final MockedStatic<Region> regionMockedStatic = mockStatic(Region.class)) {
regionMockedStatic.when(() -> Region.of(regionString)).thenReturn(expectedRegionObject);
actualRegion = awsAuthenticationOptions.getAwsRegion();
}
Expand All @@ -60,6 +64,66 @@ void getStsExternalId_Null() throws NoSuchFieldException, IllegalAccessException
assertThat(awsAuthenticationOptions.getAwsStsExternalId(), nullValue());
}

@Test
void getAwsStsRoleArn_returns_correct_arn() throws NoSuchFieldException, IllegalAccessException {
final String stsRoleArn = "arn:aws:iam::123456789012:role/SampleRole";
reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", stsRoleArn);
assertThat(awsAuthenticationOptions.getAwsStsRoleArn(), equalTo(stsRoleArn));
}

@Test
void getAwsStsHeaderOverrides_returns_correct_map() throws NoSuchFieldException, IllegalAccessException {
Map<String, String> headerOverrides = new HashMap<>();
headerOverrides.put("header1", "value1");
headerOverrides.put("header2", "value2");
reflectivelySetField(awsAuthenticationOptions, "awsStsHeaderOverrides", headerOverrides);
assertThat(awsAuthenticationOptions.getAwsStsHeaderOverrides(), equalTo(headerOverrides));
}

@Test
void validateStsRoleArn_with_invalid_format_throws_exception() throws NoSuchFieldException, IllegalAccessException {
final String invalidFormatArn = "invalid-arn-format";
reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", invalidFormatArn);

try (final MockedStatic<Arn> arnMockedStatic = mockStatic(Arn.class)) {
arnMockedStatic.when(() -> Arn.fromString(invalidFormatArn))
.thenThrow(new IllegalArgumentException("Invalid ARN format for awsStsRoleArn. Check the format of " + invalidFormatArn));

IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
awsAuthenticationOptions.validateStsRoleArn();
});
assertThat(exception.getMessage(), equalTo("Invalid ARN format for awsStsRoleArn. Check the format of " + invalidFormatArn));
}
}

@Test
void validateStsRoleArn_does_not_throw_for_valid_role_Arn() throws NoSuchFieldException, IllegalAccessException {
final String validRoleArn = "arn:aws:iam::123456789012:role/SampleRole";
reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", validRoleArn);
try {
awsAuthenticationOptions.validateStsRoleArn();
} catch (Exception e) {
throw new AssertionError("Exception should not be thrown for a valid role ARN", e);
}
}

@Test
void validateStsRoleArn_throws_exception_for_non_role_resource() throws NoSuchFieldException, IllegalAccessException {
final String nonRoleResourceArn = "arn:aws:iam::123456789012:group/MyGroup";
reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", nonRoleResourceArn);
Exception exception = assertThrows(IllegalArgumentException.class, () -> awsAuthenticationOptions.validateStsRoleArn());
assertThat(exception.getMessage(), equalTo("sts_role_arn must be an IAM Role"));
}

@Test
void validateStsRoleArn_throws_exception_when_service_is_not_iam() throws NoSuchFieldException, IllegalAccessException {
final String invalidServiceArn = "arn:aws:s3::123456789012:role/SampleRole";
reflectivelySetField(awsAuthenticationOptions, "awsStsRoleArn", invalidServiceArn);
Exception exception = assertThrows(IllegalArgumentException.class, () -> awsAuthenticationOptions.validateStsRoleArn());
assertThat(exception.getMessage(), equalTo("sts_role_arn must be an IAM Role"));
}


private void reflectivelySetField(final AwsAuthenticationOptions awsAuthenticationOptions, final String fieldName, final Object value) throws NoSuchFieldException, IllegalAccessException {
final Field field = AwsAuthenticationOptions.class.getDeclaredField(fieldName);
try {
Expand All @@ -69,4 +133,4 @@ private void reflectivelySetField(final AwsAuthenticationOptions awsAuthenticati
field.setAccessible(false);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package org.opensearch.dataprepper.plugins.source.sqs;

import org.junit.jupiter.api.Test;
import java.time.Duration;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;

public class QueueConfigTest {

@Test
void testDefaultValues() {
final QueueConfig queueConfig = new QueueConfig();

assertNull(queueConfig.getUrl(), "URL should be null by default");
assertEquals(1, queueConfig.getNumWorkers(), "Number of workers should default to 1");
assertNull(queueConfig.getMaximumMessages(), "Maximum messages should be null by default");
assertEquals(Duration.ofSeconds(0), queueConfig.getPollDelay(), "Poll delay should default to 0 seconds");
assertNull(queueConfig.getVisibilityTimeout(), "Visibility timeout should be null by default");
assertFalse(queueConfig.getVisibilityDuplicateProtection(), "Visibility duplicate protection should default to false");
assertEquals(Duration.ofHours(2), queueConfig.getVisibilityDuplicateProtectionTimeout(),
"Visibility duplicate protection timeout should default to 2 hours");
assertEquals(Duration.ofSeconds(20), queueConfig.getWaitTime(), "Wait time should default to 20 seconds");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package org.opensearch.dataprepper.plugins.source.sqs;

import com.fasterxml.jackson.databind.JsonNode;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import software.amazon.awssdk.services.sqs.model.Message;

import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

class RawSqsMessageHandlerTest {

private final RawSqsMessageHandler rawSqsMessageHandler = new RawSqsMessageHandler();
private BufferAccumulator<Record<Event>> mockBufferAccumulator;

@BeforeEach
void setUp() {
mockBufferAccumulator = mock(BufferAccumulator.class);
}

@Test
void parseMessageBody_validJsonString_returnsJsonNode() {
String validJson = "{\"key\":\"value\"}";
JsonNode result = rawSqsMessageHandler.parseMessageBody(validJson);
assertTrue(result.isObject(), "Result should be a JSON object");
assertEquals("value", result.get("key").asText(), "The value of 'key' should be 'value'");
}

@Test
void handleMessage_callsBufferAccumulatorAddOnce() throws Exception {
Message message = Message.builder().body("{\"key\":\"value\"}").build();
String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue";
rawSqsMessageHandler.handleMessage(message, queueUrl, mockBufferAccumulator, null);
ArgumentCaptor<Record<Event>> argumentCaptor = ArgumentCaptor.forClass(Record.class);
verify(mockBufferAccumulator, times(1)).add(argumentCaptor.capture());
Record<Event> capturedRecord = argumentCaptor.getValue();
assertEquals("sqs-event", capturedRecord.getData().getMetadata().getEventType(), "Event type should be 'sqs-event'");
}

@Test
void handleMessage_handlesInvalidJsonBodyGracefully() throws Exception {
String invalidJsonBody = "Invalid JSON string";
Message message = Message.builder().body(invalidJsonBody).build();
String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue";
rawSqsMessageHandler.handleMessage(message, queueUrl, mockBufferAccumulator, null);
ArgumentCaptor<Record<Event>> argumentCaptor = ArgumentCaptor.forClass(Record.class);
verify(mockBufferAccumulator, times(1)).add(argumentCaptor.capture());
Record<Event> capturedRecord = argumentCaptor.getValue();
Map<String, Object> eventData = capturedRecord.getData().toMap();
assertEquals(invalidJsonBody, eventData.get("message"), "The message should be a text node containing the original string");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.opensearch.dataprepper.plugins.source.sqs;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import software.amazon.awssdk.services.sqs.model.Message;

import java.io.IOException;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

class SqsEventProcessorTest {

private SqsMessageHandler mockSqsMessageHandler;
private SqsEventProcessor sqsEventProcessor;
private BufferAccumulator<Record<Event>> mockBufferAccumulator;
private AcknowledgementSet mockAcknowledgementSet;

@BeforeEach
void setUp() {
mockSqsMessageHandler = Mockito.mock(SqsMessageHandler.class);
mockBufferAccumulator = Mockito.mock(BufferAccumulator.class);
mockAcknowledgementSet = Mockito.mock(AcknowledgementSet.class);
sqsEventProcessor = new SqsEventProcessor(mockSqsMessageHandler);
}

@Test
void addSqsObject_callsHandleMessageWithCorrectParameters() throws IOException {
Message message = Message.builder().body("Test Message Body").build();
String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue";
sqsEventProcessor.addSqsObject(message, queueUrl, mockBufferAccumulator, mockAcknowledgementSet);
verify(mockSqsMessageHandler, times(1)).handleMessage(message, queueUrl, mockBufferAccumulator, mockAcknowledgementSet);
}

@Test
void addSqsObject_propagatesIOExceptionThrownByHandleMessage() throws IOException {
Message message = Message.builder().body("Test Message Body").build();
String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/test-queue";
doThrow(new IOException("Handle message failed")).when(mockSqsMessageHandler).handleMessage(message, queueUrl, mockBufferAccumulator, mockAcknowledgementSet);
IOException thrownException = assertThrows(IOException.class, () ->
sqsEventProcessor.addSqsObject(message, queueUrl, mockBufferAccumulator, mockAcknowledgementSet)
);
assert(thrownException.getMessage().equals("Handle message failed"));
verify(mockSqsMessageHandler, times(1)).handleMessage(message, queueUrl, mockBufferAccumulator, mockAcknowledgementSet);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.opensearch.dataprepper.plugins.source.sqs;

import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;

public class SqsSourceConfigTest {

private final ObjectMapper objectMapper = new ObjectMapper();

@Test
void testDefaultValues() {
final SqsSourceConfig config = new SqsSourceConfig();
assertNull(config.getAwsAuthenticationOptions(), "AWS Authentication Options should be null by default");
assertFalse(config.getAcknowledgements(), "Acknowledgments should be false by default");
assertEquals(SqsSourceConfig.DEFAULT_BUFFER_TIMEOUT, config.getBufferTimeout(), "Buffer timeout should default to 10 seconds");
assertEquals(SqsSourceConfig.DEFAULT_NUMBER_OF_RECORDS_TO_ACCUMULATE, config.getNumberOfRecordsToAccumulate(),
"Number of records to accumulate should default to 100");
assertNull(config.getQueues(), "Queues should be null by default");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void setUp() {
acknowledgementSetManager = mock(AcknowledgementSetManager.class);
awsCredentialsSupplier = mock(AwsCredentialsSupplier.class);
sqsSource = new SqsSource(pluginMetrics, sqsSourceConfig, acknowledgementSetManager, awsCredentialsSupplier);
buffer = mock(Buffer.class);
}

@Test
Expand All @@ -48,7 +49,6 @@ void start_should_throw_IllegalStateException_when_buffer_is_null() {

@Test
void start_should_not_throw_when_buffer_is_not_null() {
Buffer<Record<Event>> buffer = mock(Buffer.class);
AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class);
when(awsAuthenticationOptions.getAwsStsRoleArn()).thenReturn("arn:aws:iam::123456789012:role/example-role");
when(sqsSourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions);
Expand Down
Loading

0 comments on commit b103f7e

Please sign in to comment.