diff --git a/MAINTAINERS.md b/MAINTAINERS.md index 85b7e816cf..0bc07143a7 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -14,7 +14,6 @@ This document contains a list of maintainers in this repo. See [opensearch-proje | Taylor Gray | [graytaylor0](https://github.com/graytaylor0) | Amazon | | Dinu John | [dinujoh](https://github.com/dinujoh) | Amazon | | Krishna Kondaka | [kkondaka](https://github.com/kkondaka) | Amazon | -| Asif Sohail Mohammed | [asifsmohammed](https://github.com/asifsmohammed) | Amazon | | Karsten Schnitter | [KarstenSchnitter](https://github.com/KarstenSchnitter) | SAP | | David Venable | [dlvenable](https://github.com/dlvenable) | Amazon | | Hai Yan | [oeyh](https://github.com/oeyh) | Amazon | @@ -22,10 +21,11 @@ This document contains a list of maintainers in this repo. See [opensearch-proje ## Emeritus -| Maintainer | GitHub ID | Affiliation | -| -------------------- | ----------------------------------------------------- | ----------- | -| Steven Bayer | [sbayer55](https://github.com/sbayer55) | Amazon | -| Christopher Manning | [cmanning09](https://github.com/cmanning09) | Amazon | -| David Powers | [dapowers87](https://github.com/dapowers87) | Amazon | -| Shivani Shukla | [sshivanii](https://github.com/sshivanii) | Amazon | -| Phill Treddenick | [treddeni-amazon](https://github.com/treddeni-amazon) | Amazon | +| Maintainer | GitHub ID | Affiliation | +| ---------------------- | ----------------------------------------------------- | ----------- | +| Steven Bayer | [sbayer55](https://github.com/sbayer55) | Amazon | +| Christopher Manning | [cmanning09](https://github.com/cmanning09) | Amazon | +| Asif Sohail Mohammed | [asifsmohammed](https://github.com/asifsmohammed) | Amazon | +| David Powers | [dapowers87](https://github.com/dapowers87) | Amazon | +| Shivani Shukla | [sshivanii](https://github.com/sshivanii) | Amazon | +| Phill Treddenick | [treddeni-amazon](https://github.com/treddeni-amazon) | Amazon | diff --git a/data-prepper-plugins/aggregate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorConfig.java b/data-prepper-plugins/aggregate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorConfig.java index e1522b23ed..3f7fedb539 100644 --- a/data-prepper-plugins/aggregate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorConfig.java +++ b/data-prepper-plugins/aggregate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/aggregate/AggregateProcessorConfig.java @@ -31,7 +31,7 @@ public class AggregateProcessorConfig { @JsonPropertyDescription("An unordered list by which to group events. Events with the same values as these keys are put into the same group. " + "If an event does not contain one of the identification_keys, then the value of that key is considered to be equal to null. " + - "At least one identification_key is required. 
An example configuration is [\"sourceIp\", \"destinationIp\", \"port\"].") + "At least one identification_key is required.") @JsonProperty("identification_keys") @NotEmpty @ExampleValues({ diff --git a/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/AnomalyDetectorProcessorConfig.java b/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/AnomalyDetectorProcessorConfig.java index 578e22a351..a8c45316ea 100644 --- a/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/AnomalyDetectorProcessorConfig.java +++ b/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/AnomalyDetectorProcessorConfig.java @@ -34,7 +34,7 @@ public class AnomalyDetectorProcessorConfig { @UsesDataPrepperPlugin(pluginType = AnomalyDetectorMode.class) private PluginModel detectorMode; - @JsonPropertyDescription("If provided, anomalies will be detected within each unique instance of these keys. For example, if you provide the ip field, anomalies will be detected separately for each unique IP address.") + @JsonPropertyDescription("If provided, anomalies will be detected within each unique instance of these keys. For example, if you provide the IP field, anomalies will be detected separately for each unique IP address.") @JsonProperty("identification_keys") @ExampleValues({ @Example(value = "ip_address", description = "Anomalies will be detected separately for each unique IP address from the existing ip_address key of the Event.") diff --git a/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/modes/RandomCutForestModeConfig.java b/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/modes/RandomCutForestModeConfig.java index 74c90fca50..46657946d4 100644 --- a/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/modes/RandomCutForestModeConfig.java +++ b/data-prepper-plugins/anomaly-detector-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/anomalydetector/modes/RandomCutForestModeConfig.java @@ -51,7 +51,7 @@ public class RandomCutForestModeConfig { @JsonProperty(value = "time_decay", defaultValue = "" + DEFAULT_TIME_DECAY) private double timeDecay = DEFAULT_TIME_DECAY; - @JsonPropertyDescription("Output after indicates the number of events to consume before outputting anomalies. Default is 32.") + @JsonPropertyDescription("The number of events to train the model before generating an anomaly event. 
Default is 32.") @JsonProperty(value = "output_after", defaultValue = "" + DEFAULT_OUTPUT_AFTER) private int outputAfter = DEFAULT_OUTPUT_AFTER; diff --git a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java index c9675ed931..49b35ebf08 100644 --- a/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java +++ b/data-prepper-plugins/aws-lambda/src/main/java/org/opensearch/dataprepper/plugins/lambda/processor/LambdaProcessorConfig.java @@ -17,7 +17,7 @@ import org.opensearch.dataprepper.plugins.lambda.common.config.LambdaCommonConfig; @JsonPropertyOrder -@JsonClassDescription("The aws_lambda processor enables invocation of an AWS Lambda function within your Data Prepper pipeline in order to process events." + +@JsonClassDescription("The aws_lambda processor enables invocation of an AWS Lambda function within your pipeline in order to process events. " + "It supports both synchronous and asynchronous invocations based on your use case.") public class LambdaProcessorConfig extends LambdaCommonConfig { static final String DEFAULT_INVOCATION_TYPE = "request-response"; @@ -40,7 +40,7 @@ public class LambdaProcessorConfig extends LambdaCommonConfig { @JsonPropertyDescription("Defines a condition for event to use this processor.") @ExampleValues({ - @Example(value = "event['status'] == 'process'", description = "The processor will only run on this condition.") + @Example(value = "/some_key == null", description = "The processor will only run on events where this condition evaluates to true.") }) @JsonProperty("lambda_when") private String whenCondition; diff --git a/data-prepper-plugins/cloudwatch-logs/build.gradle b/data-prepper-plugins/cloudwatch-logs/build.gradle index 3bbb24f443..348275f298 100644 --- a/data-prepper-plugins/cloudwatch-logs/build.gradle +++ b/data-prepper-plugins/cloudwatch-logs/build.gradle @@ -1,8 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + plugins { id 'java' id 'java-library' } +sourceSets { + integrationTest { + java { + compileClasspath += main.output + test.output + runtimeClasspath += main.output + test.output + srcDir file('src/integrationTest/java') + } + } +} + +configurations { + integrationTestImplementation.extendsFrom testImplementation + integrationTestRuntime.extendsFrom testRuntime +} + dependencies { implementation project(':data-prepper-plugins:aws-plugin-api') implementation project(path: ':data-prepper-plugins:common') @@ -33,4 +53,20 @@ jacocoTestCoverageVerification { test { useJUnitPlatform() -} \ No newline at end of file +} + +task integrationTest(type: Test) { + group = 'verification' + testClassesDirs = sourceSets.integrationTest.output.classesDirs + + useJUnitPlatform() + + classpath = sourceSets.integrationTest.runtimeClasspath + systemProperty 'tests.cloudwatch.log_group', System.getProperty('tests.cloudwatch.log_group') + systemProperty 'tests.cloudwatch.log_stream', System.getProperty('tests.cloudwatch.log_stream') + systemProperty 'tests.aws.region', System.getProperty('tests.aws.region') + systemProperty 'tests.aws.role', System.getProperty('tests.aws.role') + filter { + includeTestsMatching '*IT' + } +} diff --git 
a/data-prepper-plugins/cloudwatch-logs/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/CouldWatchLogsIT.java b/data-prepper-plugins/cloudwatch-logs/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/CouldWatchLogsIT.java new file mode 100644 index 0000000000..5158efb18e --- /dev/null +++ b/data-prepper-plugins/cloudwatch-logs/src/integrationTest/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/CouldWatchLogsIT.java @@ -0,0 +1,292 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.dataprepper.plugins.sink.cloudwatch_logs; + +import io.micrometer.core.instrument.Counter; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.configuration.PluginSetting; +import org.opensearch.dataprepper.plugins.sink.cloudwatch_logs.config.ThresholdConfig; +import org.opensearch.dataprepper.plugins.sink.cloudwatch_logs.config.AwsConfig; +import org.opensearch.dataprepper.plugins.sink.cloudwatch_logs.config.CloudWatchLogsSinkConfig; +import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier; +import org.opensearch.dataprepper.plugins.sink.cloudwatch_logs.client.CloudWatchLogsClientFactory; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.GetLogEventsRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.OutputLogEvent; +import software.amazon.awssdk.services.cloudwatchlogs.model.CreateLogStreamRequest; +import software.amazon.awssdk.services.cloudwatchlogs.model.CreateLogStreamResponse; +import software.amazon.awssdk.services.cloudwatchlogs.model.DeleteLogStreamRequest; +import software.amazon.awssdk.services.cloudwatchlogs.CloudWatchLogsClient; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.log.JacksonLog; +import org.opensearch.dataprepper.model.record.Record; +import software.amazon.awssdk.regions.Region; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.awaitility.Awaitility.await; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.lenient; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.List; +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicInteger; + +@ExtendWith(MockitoExtension.class) +public class CouldWatchLogsIT { + static final int NUM_RECORDS = 2; + @Mock + private PluginSetting pluginSetting; + + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private AwsCredentialsSupplier awsCredentialsSupplier; + + @Mock + private AwsConfig awsConfig; + + @Mock + private ThresholdConfig thresholdConfig; + + @Mock + private 
CloudWatchLogsSinkConfig cloudWatchLogsSinkConfig; + + @Mock + private Counter counter; + + private String awsRegion; + private String awsRole; + private String logGroupName; + private String logStreamName; + private CloudWatchLogsSink sink; + private AtomicInteger count; + private CloudWatchLogsClient cloudWatchLogsClient; + private ObjectMapper objectMapper; + + @BeforeEach + void setUp() { + count = new AtomicInteger(0); + objectMapper = new ObjectMapper(); + pluginSetting = mock(PluginSetting.class); + when(pluginSetting.getPipelineName()).thenReturn("pipeline"); + when(pluginSetting.getName()).thenReturn("name"); + awsRegion = System.getProperty("tests.aws.region"); + awsRole = System.getProperty("tests.aws.role"); + awsConfig = mock(AwsConfig.class); + when(awsConfig.getAwsRegion()).thenReturn(Region.of(awsRegion)); + when(awsConfig.getAwsStsRoleArn()).thenReturn(awsRole); + when(awsConfig.getAwsStsExternalId()).thenReturn(null); + when(awsConfig.getAwsStsHeaderOverrides()).thenReturn(null); + when(awsCredentialsSupplier.getProvider(any())).thenAnswer(options -> DefaultCredentialsProvider.create()); + cloudWatchLogsClient = CloudWatchLogsClientFactory.createCwlClient(awsConfig, awsCredentialsSupplier); + logGroupName = System.getProperty("tests.cloudwatch.log_group"); + logStreamName = createLogStream(logGroupName); + pluginMetrics = mock(PluginMetrics.class); + counter = mock(Counter.class); + lenient().doAnswer((a)-> { + int v = (int)(double)(a.getArgument(0)); + count.addAndGet(v); + return null; + }).when(counter).increment(any(Double.class)); + lenient().doAnswer((a)-> { + count.addAndGet(1); + return null; + }).when(counter).increment(); + when(pluginMetrics.counter(anyString())).thenReturn(counter); + cloudWatchLogsSinkConfig = mock(CloudWatchLogsSinkConfig.class); + when(cloudWatchLogsSinkConfig.getLogGroup()).thenReturn(logGroupName); + when(cloudWatchLogsSinkConfig.getLogStream()).thenReturn(logStreamName); + when(cloudWatchLogsSinkConfig.getAwsConfig()).thenReturn(awsConfig); + when(cloudWatchLogsSinkConfig.getBufferType()).thenReturn(CloudWatchLogsSinkConfig.DEFAULT_BUFFER_TYPE); + + thresholdConfig = mock(ThresholdConfig.class); + when(thresholdConfig.getBackOffTime()).thenReturn(500L); + when(thresholdConfig.getLogSendInterval()).thenReturn(60L); + when(thresholdConfig.getRetryCount()).thenReturn(10); + when(thresholdConfig.getMaxEventSizeBytes()).thenReturn(1000L); + when(cloudWatchLogsSinkConfig.getThresholdConfig()).thenReturn(thresholdConfig); + } + + @AfterEach + void tearDown() { + DeleteLogStreamRequest deleteRequest = DeleteLogStreamRequest + .builder() + .logGroupName(logGroupName) + .logStreamName(logStreamName) + .build(); + cloudWatchLogsClient.deleteLogStream(deleteRequest); + } + + private CloudWatchLogsSink createObjectUnderTest() { + return new CloudWatchLogsSink(pluginSetting, pluginMetrics, cloudWatchLogsSinkConfig, awsCredentialsSupplier); + } + + private String createLogStream(final String logGroupName) { + final String newLogStreamName = "CouldWatchLogsIT_"+RandomStringUtils.randomAlphabetic(6); + CreateLogStreamRequest createRequest = CreateLogStreamRequest + .builder() + .logGroupName(logGroupName) + .logStreamName(newLogStreamName) + .build(); + CreateLogStreamResponse response = cloudWatchLogsClient.createLogStream(createRequest); + return newLogStreamName; + + } + + @Test + void TestSinkOperationWithLogSendInterval() throws Exception { + long startTime = Instant.now().toEpochMilli(); + when(thresholdConfig.getBatchSize()).thenReturn(10); + 
when(thresholdConfig.getLogSendInterval()).thenReturn(10L); + when(thresholdConfig.getMaxRequestSizeBytes()).thenReturn(1000L); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(NUM_RECORDS); + sink.doOutput(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + sink.doOutput(Collections.emptyList()); + long endTime = Instant.now().toEpochMilli(); + GetLogEventsRequest getRequest = GetLogEventsRequest + .builder() + .logGroupName(logGroupName) + .logStreamName(logStreamName) + .startTime(startTime) + .endTime(endTime) + .build(); + GetLogEventsResponse response = cloudWatchLogsClient.getLogEvents(getRequest); + List events = response.events(); + assertThat(events.size(), equalTo(NUM_RECORDS)); + for (int i = 0; i < events.size(); i++) { + String message = events.get(i).message(); + Map event = objectMapper.readValue(message, Map.class); + assertThat(event.get("name"), equalTo("Person"+i)); + assertThat(event.get("age"), equalTo(Integer.toString(i))); + } + }); + // NUM_RECORDS success + // 1 request success + assertThat(count.get(), equalTo(NUM_RECORDS+1)); + + } + + @Test + void TestSinkOperationWithBatchSize() throws Exception { + long startTime = Instant.now().toEpochMilli(); + when(thresholdConfig.getBatchSize()).thenReturn(1); + when(thresholdConfig.getMaxEventSizeBytes()).thenReturn(1000L); + when(thresholdConfig.getMaxRequestSizeBytes()).thenReturn(1000L); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(NUM_RECORDS); + sink.doOutput(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + long endTime = Instant.now().toEpochMilli(); + GetLogEventsRequest getRequest = GetLogEventsRequest + .builder() + .logGroupName(logGroupName) + .logStreamName(logStreamName) + .startTime(startTime) + .endTime(endTime) + .build(); + GetLogEventsResponse response = cloudWatchLogsClient.getLogEvents(getRequest); + List events = response.events(); + assertThat(events.size(), equalTo(NUM_RECORDS)); + for (int i = 0; i < events.size(); i++) { + String message = events.get(i).message(); + Map event = objectMapper.readValue(message, Map.class); + assertThat(event.get("name"), equalTo("Person"+i)); + assertThat(event.get("age"), equalTo(Integer.toString(i))); + } + }); + // NUM_RECORDS success + // NUM_RECORDS request success + assertThat(count.get(), equalTo(NUM_RECORDS*2)); + + } + + @Test + void TestSinkOperationWithMaxRequestSize() throws Exception { + long startTime = Instant.now().toEpochMilli(); + when(thresholdConfig.getBatchSize()).thenReturn(20); + when(thresholdConfig.getMaxRequestSizeBytes()).thenReturn(108L); + + sink = createObjectUnderTest(); + Collection> records = getRecordList(NUM_RECORDS); + sink.doOutput(records); + await().atMost(Duration.ofSeconds(30)) + .untilAsserted(() -> { + long endTime = Instant.now().toEpochMilli(); + GetLogEventsRequest getRequest = GetLogEventsRequest + .builder() + .logGroupName(logGroupName) + .logStreamName(logStreamName) + .startTime(startTime) + .endTime(endTime) + .build(); + GetLogEventsResponse response = cloudWatchLogsClient.getLogEvents(getRequest); + List events = response.events(); + assertThat(events.size(), equalTo(NUM_RECORDS)); + for (int i = 0; i < events.size(); i++) { + String message = events.get(i).message(); + Map event = objectMapper.readValue(message, Map.class); + assertThat(event.get("name"), equalTo("Person"+i)); + assertThat(event.get("age"), equalTo(Integer.toString(i))); + } + }); + // NUM_RECORDS success + // 1 request success + 
assertThat(count.get(), equalTo(NUM_RECORDS+1)); + + } + + private Collection> getRecordList(int numberOfRecords) { + final Collection> recordList = new ArrayList<>(); + List records = generateRecords(numberOfRecords); + for (int i = 0; i < numberOfRecords; i++) { + final Event event = JacksonLog.builder().withData(records.get(i)).build(); + recordList.add(new Record<>(event)); + } + return recordList; + } + + private static List generateRecords(int numberOfRecords) { + + List recordList = new ArrayList<>(); + + for (int rows = 0; rows < numberOfRecords; rows++) { + + HashMap eventData = new HashMap<>(); + + eventData.put("name", "Person" + rows); + eventData.put("age", Integer.toString(rows)); + recordList.add((eventData)); + + } + return recordList; + } +} diff --git a/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/CloudWatchLogsSink.java b/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/CloudWatchLogsSink.java index f17d79c7af..ddc5654d27 100644 --- a/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/CloudWatchLogsSink.java +++ b/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/CloudWatchLogsSink.java @@ -87,10 +87,6 @@ public void doInitialize() { @Override public void doOutput(Collection> records) { - if (records.isEmpty()) { - return; - } - cloudWatchLogsService.processLogEvents(records); } @@ -98,4 +94,4 @@ public void doOutput(Collection> records) { public boolean isReady() { return isInitialized; } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/client/CloudWatchLogsService.java b/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/client/CloudWatchLogsService.java index fc9963ab46..af54d19267 100644 --- a/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/client/CloudWatchLogsService.java +++ b/data-prepper-plugins/cloudwatch-logs/src/main/java/org/opensearch/dataprepper/plugins/sink/cloudwatch_logs/client/CloudWatchLogsService.java @@ -60,6 +60,19 @@ public CloudWatchLogsService(final Buffer buffer, */ public void processLogEvents(final Collection> logs) { sinkStopWatch.startIfNotRunning(); + if (logs.isEmpty() && buffer.getEventCount() > 0) { + processLock.lock(); + try { + if (cloudWatchLogsLimits.isGreaterThanLimitReached(sinkStopWatch.getElapsedTimeInSeconds(), + buffer.getBufferSize(), buffer.getEventCount())) { + stageLogEvents(); + } + } finally { + processLock.unlock(); + } + return; + } + for (Record log : logs) { String logString = log.getData().toJsonString(); int logLength = logString.length(); diff --git a/data-prepper-plugins/date-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/date/DateProcessorConfig.java b/data-prepper-plugins/date-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/date/DateProcessorConfig.java index 2e3114d979..05deb5aa07 100644 --- a/data-prepper-plugins/date-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/date/DateProcessorConfig.java +++ b/data-prepper-plugins/date-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/date/DateProcessorConfig.java @@ -10,6 +10,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import 
com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.annotation.JsonAlias; import jakarta.validation.constraints.AssertTrue; import org.opensearch.dataprepper.model.annotations.AlsoRequired; import org.opensearch.dataprepper.model.annotations.ConditionalRequired; @@ -133,8 +134,7 @@ public static boolean isValidPattern(final String pattern) { private Boolean fromTimeReceived = DEFAULT_FROM_TIME_RECEIVED; @JsonProperty("match") - @JsonPropertyDescription("The date match configuration. " + - "This option cannot be defined at the same time as from_time_received. " + + @JsonPropertyDescription("This option cannot be defined at the same time as from_time_received. " + "The date processor will use the first pattern that matches each event's timestamp field. " + "You must provide at least one pattern unless you have from_time_received.") @AlsoRequired(values = { @@ -155,9 +155,11 @@ public static boolean isValidPattern(final String pattern) { }) private String outputFormat = DEFAULT_OUTPUT_FORMAT; - @JsonProperty("to_origination_metadata") - @JsonPropertyDescription("When true, the matched time is also added to the event's metadata as an instance of " + - "Instant. Default is false.") + @JsonProperty("origination_timestamp_to_metadata") + @JsonAlias("to_origination_metadata") + @JsonPropertyDescription("Include the origination timestamp in the metadata. " + + "Enabling this option will use this timestamp to report the EndToEndLatency metric " + + "when events reach the sink. Default is false.") private Boolean toOriginationMetadata = DEFAULT_TO_ORIGINATION_METADATA; @JsonProperty("source_timezone") @@ -187,7 +189,7 @@ public static boolean isValidPattern(final String pattern) { "or a string representation of the " + "locale object, such as en_US. " + "A full list of locale fields, including language, country, and variant, can be found " + - "here." + + "here. " + "Default is Locale.ROOT.") @ExampleValues({ @Example("en-US"), @@ -196,10 +198,10 @@ public static boolean isValidPattern(final String pattern) { private String locale; @JsonProperty("date_when") - @JsonPropertyDescription("Specifies under what condition the date processor should perform matching. " + + @JsonPropertyDescription("Specifies under what condition the date processor should run. " + "Default is no condition.") @ExampleValues({ - @Example(value = "/some_key == null", description = "Only runs the date processor on the Event if some_key is null or doesn't exist.") + @Example(value = "/some_key == null", description = "The processor will only run on events where this condition evaluates to true.") }) private String dateWhen; diff --git a/data-prepper-plugins/drop-events-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/drop/DropEventProcessorConfig.java b/data-prepper-plugins/drop-events-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/drop/DropEventProcessorConfig.java index dbf1e9b63d..29e5de133d 100644 --- a/data-prepper-plugins/drop-events-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/drop/DropEventProcessorConfig.java +++ b/data-prepper-plugins/drop-events-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/drop/DropEventProcessorConfig.java @@ -27,7 +27,7 @@ public class DropEventProcessorConfig { }) private String dropWhen; - @JsonPropertyDescription("Specifies how exceptions are handled when an exception occurs while evaluating an event. 
Default value is skip, which drops the event so that it is not sent to further processors or sinks.") + @JsonPropertyDescription("Specifies how exceptions are handled when an exception occurs while evaluating an event. Default value is skip, which sends the events to further processors or sinks.") @JsonProperty(value = "handle_failed_events", defaultValue = "skip") private HandleFailedEventsOption handleFailedEventsOption = HandleFailedEventsOption.SKIP; diff --git a/data-prepper-plugins/flatten-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/flatten/FlattenProcessorConfig.java b/data-prepper-plugins/flatten-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/flatten/FlattenProcessorConfig.java index e5f4cb5c1b..747eeb4ab5 100644 --- a/data-prepper-plugins/flatten-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/flatten/FlattenProcessorConfig.java +++ b/data-prepper-plugins/flatten-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/flatten/FlattenProcessorConfig.java @@ -7,9 +7,9 @@ import com.fasterxml.jackson.annotation.JsonClassDescription; import com.fasterxml.jackson.annotation.JsonProperty; - import com.fasterxml.jackson.annotation.JsonPropertyDescription; import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.annotation.JsonAlias; import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.NotNull; import org.opensearch.dataprepper.model.annotations.ExampleValues; @@ -20,7 +20,7 @@ import java.util.List; @JsonPropertyOrder -@JsonClassDescription("The flatten processor transforms nested objects inside of events into flattened structures.") +@JsonClassDescription("The flatten processor transforms nested objects inside of events into flattened structures.") public class FlattenProcessorConfig { static final String REMOVE_LIST_INDICES_KEY = "remove_list_indices"; @@ -44,12 +44,14 @@ public class FlattenProcessorConfig { }) private String target; - @JsonProperty("remove_processed_fields") + @JsonProperty("remove_source_keys") + @JsonAlias("remove_processed_fields") @JsonPropertyDescription("When true, the processor removes all processed fields from the source. " + "The default is false which leaves the source fields.") private boolean removeProcessedFields = false; - @JsonProperty(REMOVE_LIST_INDICES_KEY) + @JsonProperty("remove_list_elements") + @JsonAlias(REMOVE_LIST_INDICES_KEY) @JsonPropertyDescription("When true, the processor converts the fields from the source map into lists and " + "puts the lists into the target field. Default is false.") private boolean removeListIndices = false; diff --git a/data-prepper-plugins/key-value-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/keyvalue/KeyValueProcessorConfig.java b/data-prepper-plugins/key-value-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/keyvalue/KeyValueProcessorConfig.java index 501dbb9d54..cfc1804a30 100644 --- a/data-prepper-plugins/key-value-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/keyvalue/KeyValueProcessorConfig.java +++ b/data-prepper-plugins/key-value-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/keyvalue/KeyValueProcessorConfig.java @@ -195,8 +195,7 @@ public class KeyValueProcessorConfig { "If this flag is enabled, then the content between the delimiters is considered to be one entity and " + "they are not parsed as key-value pairs. 
The following characters are used a group delimiters: " + "{...}, [...], <...>, (...), \"...\", '...', http://... (space), and https:// (space). " + - "Default is false. For example, if value_grouping is true, then " + - "{\"key1=[a=b,c=d]&key2=value2\"} parses to {\"key1\": \"[a=b,c=d]\", \"key2\": \"value2\"}.") + "Default is false.") @AlsoRequired(values = { @AlsoRequired.Required(name = FIELD_DELIMITER_REGEX_KEY, allowedValues = {"null"}) }) diff --git a/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/AddEntryProcessorConfig.java b/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/AddEntryProcessorConfig.java index 5529428ecb..cf46fc13ae 100644 --- a/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/AddEntryProcessorConfig.java +++ b/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/AddEntryProcessorConfig.java @@ -66,7 +66,7 @@ public class AddEntryProcessorConfig { public static class Entry { @JsonPropertyDescription("The key of the new entry to be added. Some examples of keys include my_key, " + "myKey, and object/sub_Key. The key can also be a format expression, for example, ${/key1} to " + - "use the value of field key1 as the key. Either one of key or metadata_key is required.") + "use the value of field key1 as the key. Exactly one of key or metadata_key is required.") @AlsoRequired(values = { @AlsoRequired.Required(name=METADATA_KEY_KEY, allowedValues = {"null"}) }) @@ -79,7 +79,8 @@ public static class Entry { @JsonProperty(METADATA_KEY_KEY) @JsonPropertyDescription("The key for the new metadata attribute. The argument must be a literal string key " + - "and not a JSON Pointer. Either one of key or metadata_key is required.") + "and not a JSON Pointer. Adds an attribute to the Events that will not be sent to the sinks, but can be used for condition expressions and routing with the getMetadata function. " + + "Exactly one of key or metadata_key is required.") @AlsoRequired(values = { @AlsoRequired.Required(name="key", allowedValues = {"null"}) }) @@ -103,7 +104,7 @@ public static class Entry { private Object value; @JsonPropertyDescription("A format string to use as the value of the new entry, for example, " + - "${key1}-${key2}, where key1 and key2 are existing keys in the event. Required if neither" + + "${key1}-${key2}, where key1 and key2 are existing keys in the event. Required if neither " + "value nor value_expression is specified.") @AlsoRequired(values = { @AlsoRequired.Required(name="value", allowedValues = {"null"}), @@ -132,7 +133,7 @@ public static class Entry { @JsonProperty(OVERWRITE_IF_KEY_EXISTS_KEY) @JsonPropertyDescription("When set to true, the existing value is overwritten if key already exists " + - "in the event. The default value is false.") + "in the event. Only one of overwrite_if_key_exists or append_if_key_exists can be true. 
The default value is false.") @AlsoRequired(values = { @AlsoRequired.Required(name=APPEND_IF_KEY_EXISTS_KEY, allowedValues = {"false"}) }) diff --git a/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/ConvertEntryTypeProcessorConfig.java b/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/ConvertEntryTypeProcessorConfig.java index f184e02896..0d938792c1 100644 --- a/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/ConvertEntryTypeProcessorConfig.java +++ b/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/ConvertEntryTypeProcessorConfig.java @@ -9,6 +9,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonPropertyOrder; import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.annotation.JsonAlias; import org.opensearch.dataprepper.model.annotations.AlsoRequired; import org.opensearch.dataprepper.model.annotations.ConditionalRequired; import org.opensearch.dataprepper.model.annotations.ConditionalRequired.IfThenElse; @@ -55,7 +56,8 @@ public class ConvertEntryTypeProcessorConfig implements ConverterArguments { @JsonPropertyDescription("Target type for the values. Default value is integer.") private TargetType type = TargetType.INTEGER; - @JsonProperty("null_values") + @JsonProperty("null_conversion_values") + @JsonAlias("null_values") @JsonPropertyDescription("String representation of what constitutes a null value. If the field value equals one of these strings, then the value is considered null and is converted to null.") private List nullValues; diff --git a/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/CopyValueProcessorConfig.java b/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/CopyValueProcessorConfig.java index e44c1da74e..d0f991fc33 100644 --- a/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/CopyValueProcessorConfig.java +++ b/data-prepper-plugins/mutate-event-processors/src/main/java/org/opensearch/dataprepper/plugins/processor/mutateevent/CopyValueProcessorConfig.java @@ -20,7 +20,7 @@ import java.util.List; @JsonPropertyOrder -@JsonClassDescription("The copy_values processor copies values within an event to other fields within the event.") +@JsonClassDescription("The copy_values processor copies values from an event to another key in an event.") public class CopyValueProcessorConfig { static final String FROM_LIST_KEY = "from_list"; static final String TO_LIST_KEY = "to_list"; @@ -30,15 +30,13 @@ public static class Entry { @NotEmpty @NotNull @JsonProperty("from_key") - @JsonPropertyDescription("The key of the entry to be copied. Either from_key and " + - "to_key or from_list and to_list must be defined.") + @JsonPropertyDescription("The key of the entry to be copied. This must be configured.") private String fromKey; @NotEmpty @NotNull @JsonProperty("to_key") - @JsonPropertyDescription("The key of the new entry to be added. Either from_key and " + - "to_key or from_list and to_list must be defined.") + @JsonPropertyDescription("The key of the new entry to be added. 
This must be configured.") private String toKey; @JsonProperty("overwrite_if_to_key_exists") @@ -88,16 +86,16 @@ public Entry() { private List entries; @JsonProperty(FROM_LIST_KEY) - @JsonPropertyDescription("The key of the list of objects to be copied. Either from_key and " + - "to_key or from_list and to_list must be defined.") + @JsonPropertyDescription("The key of the list of objects to be copied. " + + "Both from_key and to_key must be configured and will be applied on the corresponding list.") @AlsoRequired(values = { @AlsoRequired.Required(name = TO_LIST_KEY) }) private String fromList; @JsonProperty(TO_LIST_KEY) - @JsonPropertyDescription("The key of the new list to be added. Either from_key and " + - "to_key or from_list and to_list must be defined.") + @JsonPropertyDescription("The key of the new list to be added. " + + "Both from_key and to_key must be configured and will be applied on the corresponding list.") @AlsoRequired(values = { @AlsoRequired.Required(name = FROM_LIST_KEY) }) diff --git a/data-prepper-plugins/obfuscate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/obfuscation/ObfuscationProcessorConfig.java b/data-prepper-plugins/obfuscate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/obfuscation/ObfuscationProcessorConfig.java index fa0ddf355c..935d99c219 100644 --- a/data-prepper-plugins/obfuscate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/obfuscation/ObfuscationProcessorConfig.java +++ b/data-prepper-plugins/obfuscate-processor/src/main/java/org/opensearch/dataprepper/plugins/processor/obfuscation/ObfuscationProcessorConfig.java @@ -27,7 +27,7 @@ public class ObfuscationProcessorConfig { @JsonProperty("source") - @JsonPropertyDescription("The source field to obfuscate.") + @JsonPropertyDescription("The source key to obfuscate. 
Default action is to mask with *.") @NotEmpty @NotNull @ExampleValues({ diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java index cb7f060976..4bab68b3f4 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java @@ -83,9 +83,9 @@ public Map getParentTableMap(StreamPartition streamPartitio /** * Detects if a binlog event contains cascading updates and if detected, creates resync partitions - * @param event event + * @param event binlog event * @param parentTableMap parent table map - * @param tableMetadata table meta data + * @param tableMetadata table metadata */ public void detectCascadingUpdates(Event event, Map parentTableMap, TableMetadata tableMetadata) { final UpdateRowsEventData data = event.getData(); @@ -143,9 +143,9 @@ public void detectCascadingUpdates(Event event, Map parentT /** * Detects if a binlog event contains cascading deletes and if detected, creates resync partitions - * @param event event + * @param event binlog event * @param parentTableMap parent table map - * @param tableMetadata table meta data + * @param tableMetadata table metadata */ public void detectCascadingDeletes(Event event, Map parentTableMap, TableMetadata tableMetadata) { if (parentTableMap.containsKey(tableMetadata.getFullTableName())) { diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java index 2bc21ca786..1612e94ec3 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java @@ -130,7 +130,8 @@ public BinlogEventListener(final StreamPartition streamPartition, this.dbTableMetadata = dbTableMetadata; this.streamCheckpointManager = new StreamCheckpointManager( streamCheckpointer, sourceConfig.isAcknowledgmentsEnabled(), - acknowledgementSetManager, this::stopClient, sourceConfig.getStreamAcknowledgmentTimeout()); + acknowledgementSetManager, this::stopClient, sourceConfig.getStreamAcknowledgmentTimeout(), + sourceConfig.getEngine(), pluginMetrics); streamCheckpointManager.start(); this.cascadeActionDetector = cascadeActionDetector; @@ -200,7 +201,7 @@ void handleRotateEvent(com.github.shyiko.mysql.binlog.event.Event event) { // Trigger a checkpoint update for this rotate when there're no row mutation events being processed if (streamCheckpointManager.getChangeEventStatuses().isEmpty()) { - ChangeEventStatus changeEventStatus = streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate); + ChangeEventStatus changeEventStatus = streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate, 0); if (isAcknowledgmentsEnabled) { changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.POSITIVE_ACK); } @@ -347,9 +348,10 @@ void handleRowChangeEvent(com.github.shyiko.mysql.binlog.event.Event event, LOG.debug("Current binlog 
coordinate after receiving a row change event: " + currentBinlogCoordinate); } + final long recordCount = rows.size(); AcknowledgementSet acknowledgementSet = null; if (isAcknowledgmentsEnabled) { - acknowledgementSet = streamCheckpointManager.createAcknowledgmentSet(currentBinlogCoordinate); + acknowledgementSet = streamCheckpointManager.createAcknowledgmentSet(currentBinlogCoordinate, recordCount); } final long bytes = event.toString().getBytes().length; @@ -398,7 +400,7 @@ void handleRowChangeEvent(com.github.shyiko.mysql.binlog.event.Event event, if (isAcknowledgmentsEnabled) { acknowledgementSet.complete(); } else { - streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate); + streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate, recordCount); } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java index f2b70cbe7b..af6ef02362 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java @@ -6,11 +6,14 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; public class ChangeEventStatus { private final BinlogCoordinate binlogCoordinate; + private final LogSequenceNumber logSequenceNumber; private final long timestamp; + private final long recordCount; private volatile AcknowledgmentStatus acknowledgmentStatus; public enum AcknowledgmentStatus { @@ -19,9 +22,19 @@ public enum AcknowledgmentStatus { NO_ACK } - public ChangeEventStatus(final BinlogCoordinate binlogCoordinate, final long timestamp) { + public ChangeEventStatus(final BinlogCoordinate binlogCoordinate, final long timestamp, final long recordCount) { this.binlogCoordinate = binlogCoordinate; + this.logSequenceNumber = null; this.timestamp = timestamp; + this.recordCount = recordCount; + acknowledgmentStatus = AcknowledgmentStatus.NO_ACK; + } + + public ChangeEventStatus(final LogSequenceNumber logSequenceNumber, final long timestamp, final long recordCount) { + this.binlogCoordinate = null; + this.logSequenceNumber = logSequenceNumber; + this.timestamp = timestamp; + this.recordCount = recordCount; acknowledgmentStatus = AcknowledgmentStatus.NO_ACK; } @@ -45,7 +58,15 @@ public BinlogCoordinate getBinlogCoordinate() { return binlogCoordinate; } + public LogSequenceNumber getLogSequenceNumber() { + return logSequenceNumber; + } + public long getTimestamp() { return timestamp; } + + public long getRecordCount() { + return recordCount; + } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java index 22935fc6e3..8eb3b9cde9 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java @@ -47,6 +47,7 @@ public 
LogicalReplicationClient(final ConnectionManager connectionManager, @Override public void connect() { + LOG.debug("Start connecting logical replication stream. "); PGReplicationStream stream; try (Connection conn = connectionManager.getConnection()) { PGConnection pgConnection = conn.unwrap(PGConnection.class); @@ -62,6 +63,7 @@ public void connect() { logicalStreamBuilder.withStartPosition(startLsn); } stream = logicalStreamBuilder.start(); + LOG.debug("Logical replication stream started. "); if (eventProcessor != null) { while (!disconnectRequested) { @@ -88,7 +90,8 @@ public void connect() { } stream.close(); - LOG.info("Replication stream closed successfully."); + disconnectRequested = false; + LOG.debug("Replication stream closed successfully."); } catch (Exception e) { LOG.error("Exception while creating Postgres replication stream. ", e); } @@ -97,6 +100,7 @@ public void connect() { @Override public void disconnect() { disconnectRequested = true; + LOG.debug("Requested to disconnect logical replication stream."); } public void setEventProcessor(LogicalReplicationEventProcessor eventProcessor) { diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java index 3d5c1a04b1..a2a9aa1017 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java @@ -10,7 +10,13 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Timer; import org.opensearch.dataprepper.buffer.common.BufferAccumulator; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.event.JacksonEvent; @@ -23,6 +29,7 @@ import org.opensearch.dataprepper.plugins.source.rds.datatype.postgres.ColumnType; import org.opensearch.dataprepper.plugins.source.rds.model.MessageType; import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; +import org.postgresql.replication.LogSequenceNumber; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +39,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; public class LogicalReplicationEventProcessor { enum TupleDataType { @@ -63,6 +71,14 @@ public static TupleDataType fromValue(char value) { static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(60); static final int DEFAULT_BUFFER_BATCH_SIZE = 1_000; + static final int NUM_OF_RETRIES = 3; + static final int BACKOFF_IN_MILLIS = 500; + static final String CHANGE_EVENTS_PROCESSED_COUNT = "changeEventsProcessed"; + static final String CHANGE_EVENTS_PROCESSING_ERROR_COUNT = "changeEventsProcessingErrors"; + static final String BYTES_RECEIVED = "bytesReceived"; + static final String BYTES_PROCESSED = "bytesProcessed"; + static final String 
REPLICATION_LOG_EVENT_PROCESSING_TIME = "replicationLogEntryProcessingTime"; + static final String REPLICATION_LOG_PROCESSING_ERROR_COUNT = "replicationLogEntryProcessingErrors"; private final StreamPartition streamPartition; private final RdsSourceConfig sourceConfig; @@ -70,24 +86,57 @@ public static TupleDataType fromValue(char value) { private final Buffer> buffer; private final BufferAccumulator> bufferAccumulator; private final List pipelineEvents; + private final PluginMetrics pluginMetrics; + private final AcknowledgementSetManager acknowledgementSetManager; + private final LogicalReplicationClient logicalReplicationClient; + private final StreamCheckpointer streamCheckpointer; + private final StreamCheckpointManager streamCheckpointManager; + + private final Counter changeEventSuccessCounter; + private final Counter changeEventErrorCounter; + private final DistributionSummary bytesReceivedSummary; + private final DistributionSummary bytesProcessedSummary; + private final Timer eventProcessingTimer; + private final Counter eventProcessingErrorCounter; private long currentLsn; private long currentEventTimestamp; + private long bytesReceived; private Map tableMetadataMap; public LogicalReplicationEventProcessor(final StreamPartition streamPartition, final RdsSourceConfig sourceConfig, final Buffer> buffer, - final String s3Prefix) { + final String s3Prefix, + final PluginMetrics pluginMetrics, + final LogicalReplicationClient logicalReplicationClient, + final StreamCheckpointer streamCheckpointer, + final AcknowledgementSetManager acknowledgementSetManager) { this.streamPartition = streamPartition; this.sourceConfig = sourceConfig; recordConverter = new StreamRecordConverter(s3Prefix, sourceConfig.getPartitionCount()); this.buffer = buffer; bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT); + this.pluginMetrics = pluginMetrics; + this.acknowledgementSetManager = acknowledgementSetManager; + this.logicalReplicationClient = logicalReplicationClient; + this.streamCheckpointer = streamCheckpointer; + streamCheckpointManager = new StreamCheckpointManager( + streamCheckpointer, sourceConfig.isAcknowledgmentsEnabled(), + acknowledgementSetManager, this::stopClient, sourceConfig.getStreamAcknowledgmentTimeout(), + sourceConfig.getEngine(), pluginMetrics); + streamCheckpointManager.start(); tableMetadataMap = new HashMap<>(); pipelineEvents = new ArrayList<>(); + + changeEventSuccessCounter = pluginMetrics.counter(CHANGE_EVENTS_PROCESSED_COUNT); + changeEventErrorCounter = pluginMetrics.counter(CHANGE_EVENTS_PROCESSING_ERROR_COUNT); + bytesReceivedSummary = pluginMetrics.summary(BYTES_RECEIVED); + bytesProcessedSummary = pluginMetrics.summary(BYTES_PROCESSED); + eventProcessingTimer = pluginMetrics.timer(REPLICATION_LOG_EVENT_PROCESSING_TIME); + eventProcessingErrorCounter = pluginMetrics.counter(REPLICATION_LOG_PROCESSING_ERROR_COUNT); } public void process(ByteBuffer msg) { @@ -97,20 +146,36 @@ public void process(ByteBuffer msg) { // If it's INSERT/UPDATE/DELETE, prepare events // If it's a COMMIT, convert all prepared events and send to buffer MessageType messageType = MessageType.from((char) msg.get()); - if (messageType == MessageType.BEGIN) { - processBeginMessage(msg); - } else if (messageType == MessageType.RELATION) { - processRelationMessage(msg); - } else if (messageType == MessageType.INSERT) { - processInsertMessage(msg); - } else if (messageType == MessageType.UPDATE) { - processUpdateMessage(msg); - } else if (messageType == 
MessageType.DELETE) { - processDeleteMessage(msg); - } else if (messageType == MessageType.COMMIT) { - processCommitMessage(msg); - } else { - throw new IllegalArgumentException("Replication message type [" + messageType + "] is not supported. "); + switch (messageType) { + case BEGIN: + handleMessageWithRetries(msg, this::processBeginMessage, messageType); + break; + case RELATION: + handleMessageWithRetries(msg, this::processRelationMessage, messageType); + break; + case INSERT: + handleMessageWithRetries(msg, this::processInsertMessage, messageType); + break; + case UPDATE: + handleMessageWithRetries(msg, this::processUpdateMessage, messageType); + break; + case DELETE: + handleMessageWithRetries(msg, this::processDeleteMessage, messageType); + break; + case COMMIT: + handleMessageWithRetries(msg, this::processCommitMessage, messageType); + break; + default: + throw new IllegalArgumentException("Replication message type [" + messageType + "] is not supported. "); + } + } + + public void stopClient() { + try { + logicalReplicationClient.disconnect(); + LOG.info("Binary log client disconnected."); + } catch (Exception e) { + LOG.error("Binary log client failed to disconnect.", e); } } @@ -169,15 +234,28 @@ void processCommitMessage(ByteBuffer msg) { throw new RuntimeException("Commit LSN does not match current LSN, skipping"); } - writeToBuffer(bufferAccumulator); + final long recordCount = pipelineEvents.size(); + AcknowledgementSet acknowledgementSet = null; + if (sourceConfig.isAcknowledgmentsEnabled()) { + acknowledgementSet = streamCheckpointManager.createAcknowledgmentSet(LogSequenceNumber.valueOf(currentLsn), recordCount); + } + + writeToBuffer(bufferAccumulator, acknowledgementSet); + bytesProcessedSummary.record(bytesReceived); LOG.debug("Processed a COMMIT message with Flag: {} CommitLsn: {} EndLsn: {} Timestamp: {}", flag, commitLsn, endLsn, epochMicro); + + if (sourceConfig.isAcknowledgmentsEnabled()) { + acknowledgementSet.complete(); + } else { + streamCheckpointManager.saveChangeEventsStatus(LogSequenceNumber.valueOf(currentLsn), recordCount); + } } void processInsertMessage(ByteBuffer msg) { int tableId = msg.getInt(); char n_char = (char) msg.get(); // Skip the 'N' character - final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final TableMetadata tableMetadata = tableMetadataMap.get((long) tableId); final List columnNames = tableMetadata.getColumnNames(); final List primaryKeys = tableMetadata.getPrimaryKeys(); final long eventTimestampMillis = currentEventTimestamp; @@ -189,7 +267,7 @@ void processInsertMessage(ByteBuffer msg) { void processUpdateMessage(ByteBuffer msg) { final int tableId = msg.getInt(); - final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final TableMetadata tableMetadata = tableMetadataMap.get((long) tableId); final List columnNames = tableMetadata.getColumnNames(); final List primaryKeys = tableMetadata.getPrimaryKeys(); final long eventTimestampMillis = currentEventTimestamp; @@ -231,7 +309,7 @@ void processDeleteMessage(ByteBuffer msg) { int tableId = msg.getInt(); char n_char = (char) msg.get(); // Skip the 'N' character - final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final TableMetadata tableMetadata = tableMetadataMap.get((long) tableId); final List columnNames = tableMetadata.getColumnNames(); final List primaryKeys = tableMetadata.getPrimaryKeys(); final long eventTimestampMillis = currentEventTimestamp; @@ -242,6 +320,8 @@ void processDeleteMessage(ByteBuffer msg) { 
private void doProcess(ByteBuffer msg, List columnNames, TableMetadata tableMetadata, List primaryKeys, long eventTimestampMillis, OpenSearchBulkActions bulkAction) { + bytesReceived = msg.capacity(); + bytesReceivedSummary.record(bytesReceived); Map rowDataMap = getRowDataMap(msg, columnNames); createPipelineEvent(rowDataMap, tableMetadata, primaryKeys, eventTimestampMillis, bulkAction); @@ -284,9 +364,12 @@ private void createPipelineEvent(Map rowDataMap, TableMetadata t pipelineEvents.add(pipelineEvent); } - private void writeToBuffer(BufferAccumulator> bufferAccumulator) { + private void writeToBuffer(BufferAccumulator> bufferAccumulator, AcknowledgementSet acknowledgementSet) { for (Event pipelineEvent : pipelineEvents) { addToBufferAccumulator(bufferAccumulator, new Record<>(pipelineEvent)); + if (acknowledgementSet != null) { + acknowledgementSet.add(pipelineEvent); + } } flushBufferAccumulator(bufferAccumulator, pipelineEvents.size()); @@ -304,10 +387,12 @@ private void addToBufferAccumulator(final BufferAccumulator> buffe private void flushBufferAccumulator(BufferAccumulator> bufferAccumulator, int eventCount) { try { bufferAccumulator.flush(); + changeEventSuccessCounter.increment(eventCount); } catch (Exception e) { // this will only happen if writing to buffer gets interrupted from shutdown, // otherwise bufferAccumulator will keep retrying with backoff LOG.error("Failed to flush buffer", e); + changeEventErrorCounter.increment(eventCount); } } @@ -333,4 +418,28 @@ private List getPrimaryKeys(String schemaName, String tableName) { return progressState.getPrimaryKeyMap().get(databaseName + "." + schemaName + "." + tableName); } + + private void handleMessageWithRetries(ByteBuffer message, Consumer function, MessageType messageType) { + int retry = 0; + while (retry <= NUM_OF_RETRIES) { + try { + eventProcessingTimer.record(() -> function.accept(message)); + return; + } catch (Exception e) { + LOG.warn("Error when processing change event of type {}, will retry", messageType, e); + applyBackoff(); + } + retry++; + } + LOG.error("Failed to process change event of type {} after {} retries", messageType, NUM_OF_RETRIES); + eventProcessingErrorCounter.increment(); + } + + private void applyBackoff() { + try { + Thread.sleep(BACKOFF_IN_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java index 3827f2b822..3880707e21 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java @@ -5,9 +5,13 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; +import io.micrometer.core.instrument.Counter; +import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import org.slf4j.Logger; import 
org.slf4j.LoggerFactory; @@ -21,6 +25,11 @@ public class StreamCheckpointManager { private static final Logger LOG = LoggerFactory.getLogger(StreamCheckpointManager.class); static final int REGULAR_CHECKPOINT_INTERVAL_MILLIS = 60_000; static final int CHANGE_EVENT_COUNT_PER_CHECKPOINT_BATCH = 1000; + static final String POSITIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME = "positiveAcknowledgementSets"; + static final String NEGATIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME = "negativeAcknowledgementSets"; + static final String CHECKPOINT_COUNT = "checkpointCount"; + static final String NO_DATA_EXTEND_LEASE_COUNT = "noDataExtendLeaseCount"; + static final String GIVE_UP_PARTITION_COUNT = "giveupPartitionCount"; private final ConcurrentLinkedQueue changeEventStatuses = new ConcurrentLinkedQueue<>(); private final StreamCheckpointer streamCheckpointer; @@ -29,18 +38,35 @@ public class StreamCheckpointManager { private final boolean isAcknowledgmentEnabled; private final AcknowledgementSetManager acknowledgementSetManager; private final Duration acknowledgmentTimeout; + private final EngineType engineType; + private final PluginMetrics pluginMetrics; + private final Counter positiveAcknowledgementSets; + private final Counter negativeAcknowledgementSets; + private final Counter checkpointCount; + private final Counter noDataExtendLeaseCount; + private final Counter giveupPartitionCount; public StreamCheckpointManager(final StreamCheckpointer streamCheckpointer, final boolean isAcknowledgmentEnabled, final AcknowledgementSetManager acknowledgementSetManager, final Runnable stopStreamRunnable, - final Duration acknowledgmentTimeout) { + final Duration acknowledgmentTimeout, + final EngineType engineType, + final PluginMetrics pluginMetrics) { this.acknowledgementSetManager = acknowledgementSetManager; this.streamCheckpointer = streamCheckpointer; this.isAcknowledgmentEnabled = isAcknowledgmentEnabled; this.stopStreamRunnable = stopStreamRunnable; this.acknowledgmentTimeout = acknowledgmentTimeout; + this.engineType = engineType; + this.pluginMetrics = pluginMetrics; executorService = Executors.newSingleThreadExecutor(); + + this.positiveAcknowledgementSets = pluginMetrics.counter(POSITIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME); + this.negativeAcknowledgementSets = pluginMetrics.counter(NEGATIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME); + this.checkpointCount = pluginMetrics.counter(CHECKPOINT_COUNT); + this.noDataExtendLeaseCount = pluginMetrics.counter(NO_DATA_EXTEND_LEASE_COUNT); + this.giveupPartitionCount = pluginMetrics.counter(GIVE_UP_PARTITION_COUNT); } public void start() { @@ -54,6 +80,7 @@ void runCheckpointing() { try { if (changeEventStatuses.isEmpty()) { LOG.debug("No records processed. Extend the lease on stream partition."); + noDataExtendLeaseCount.increment(); streamCheckpointer.extendLease(); } else { if (isAcknowledgmentEnabled) { @@ -65,13 +92,14 @@ void runCheckpointing() { } if (lastChangeEventStatus != null) { - streamCheckpointer.checkpoint(lastChangeEventStatus.getBinlogCoordinate()); + checkpoint(engineType, lastChangeEventStatus); } // If negative ack is seen, give up partition and exit loop to stop processing stream if (currentChangeEventStatus != null && currentChangeEventStatus.isNegativeAcknowledgment()) { LOG.info("Received negative acknowledgement for change event at {}. 
Will restart from most recent checkpoint", currentChangeEventStatus.getBinlogCoordinate()); streamCheckpointer.giveUpPartition(); + giveupPartitionCount.increment(); break; } } else { @@ -81,10 +109,10 @@ void runCheckpointing() { changeEventCount++; // In case queue are populated faster than the poll, checkpoint when reaching certain count if (changeEventCount % CHANGE_EVENT_COUNT_PER_CHECKPOINT_BATCH == 0) { - streamCheckpointer.checkpoint(currentChangeEventStatus.getBinlogCoordinate()); + checkpoint(engineType, currentChangeEventStatus); } } while (!changeEventStatuses.isEmpty()); - streamCheckpointer.checkpoint(currentChangeEventStatus.getBinlogCoordinate()); + checkpoint(engineType, currentChangeEventStatus); } } } catch (Exception e) { @@ -107,25 +135,52 @@ public void stop() { executorService.shutdownNow(); } - public ChangeEventStatus saveChangeEventsStatus(BinlogCoordinate binlogCoordinate) { - final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli()); + public ChangeEventStatus saveChangeEventsStatus(BinlogCoordinate binlogCoordinate, long recordCount) { + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli(), recordCount); changeEventStatuses.add(changeEventStatus); return changeEventStatus; } - public AcknowledgementSet createAcknowledgmentSet(BinlogCoordinate binlogCoordinate) { + public ChangeEventStatus saveChangeEventsStatus(LogSequenceNumber logSequenceNumber, long recordCount) { + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(logSequenceNumber, Instant.now().toEpochMilli(), recordCount); + changeEventStatuses.add(changeEventStatus); + return changeEventStatus; + } + + public AcknowledgementSet createAcknowledgmentSet(BinlogCoordinate binlogCoordinate, long recordCount) { LOG.debug("Create acknowledgment set for events receive prior to {}", binlogCoordinate); - final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli()); + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli(), recordCount); changeEventStatuses.add(changeEventStatus); + return getAcknowledgementSet(changeEventStatus); + } + + public AcknowledgementSet createAcknowledgmentSet(LogSequenceNumber logSequenceNumber, long recordCount) { + LOG.debug("Create acknowledgment set for events receive prior to {}", logSequenceNumber); + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(logSequenceNumber, Instant.now().toEpochMilli(), recordCount); + changeEventStatuses.add(changeEventStatus); + return getAcknowledgementSet(changeEventStatus); + } + + private AcknowledgementSet getAcknowledgementSet(ChangeEventStatus changeEventStatus) { return acknowledgementSetManager.create((result) -> { - if (result) { - changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.POSITIVE_ACK); - } else { - changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.NEGATIVE_ACK); - } + if (result) { + positiveAcknowledgementSets.increment(); + changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.POSITIVE_ACK); + } else { + negativeAcknowledgementSets.increment(); + changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.NEGATIVE_ACK); + } }, acknowledgmentTimeout); } + private void checkpoint(final EngineType engineType, final ChangeEventStatus changeEventStatus) { + 
LOG.debug("Checkpoint at {} with record count {}. ", engineType == EngineType.MYSQL ? + changeEventStatus.getBinlogCoordinate() : changeEventStatus.getLogSequenceNumber(), + changeEventStatus.getRecordCount()); + streamCheckpointer.checkpoint(engineType, changeEventStatus); + checkpointCount.increment(); + } + //VisibleForTesting ConcurrentLinkedQueue getChangeEventStatuses() { return changeEventStatuses; diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java index 1f60f9715f..2875bf5544 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java @@ -8,9 +8,11 @@ import io.micrometer.core.instrument.Counter; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,7 +40,17 @@ public StreamCheckpointer(final EnhancedSourceCoordinator sourceCoordinator, checkpointCounter = pluginMetrics.counter(CHECKPOINT_COUNT); } - public void checkpoint(final BinlogCoordinate binlogCoordinate) { + public void checkpoint(final EngineType engineType, final ChangeEventStatus changeEventStatus) { + if (engineType == EngineType.MYSQL) { + checkpoint(changeEventStatus.getBinlogCoordinate()); + } else if (engineType == EngineType.POSTGRES) { + checkpoint(changeEventStatus.getLogSequenceNumber()); + } else { + throw new IllegalArgumentException("Unsupported engine type " + engineType); + } + } + + private void checkpoint(final BinlogCoordinate binlogCoordinate) { LOG.debug("Checkpointing stream partition {} with binlog coordinate {}", streamPartition.getPartitionKey(), binlogCoordinate); Optional progressState = streamPartition.getProgressState(); progressState.get().getMySqlStreamState().setCurrentPosition(binlogCoordinate); @@ -46,6 +58,14 @@ public void checkpoint(final BinlogCoordinate binlogCoordinate) { checkpointCounter.increment(); } + private void checkpoint(final LogSequenceNumber logSequenceNumber) { + LOG.debug("Checkpointing stream partition {} with log sequence number {}", streamPartition.getPartitionKey(), logSequenceNumber); + Optional progressState = streamPartition.getProgressState(); + progressState.get().getPostgresStreamState().setCurrentLsn(logSequenceNumber.asString()); + sourceCoordinator.saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); + checkpointCounter.increment(); + } + public void extendLease() { LOG.debug("Extending lease of stream partition {}", streamPartition.getPartitionKey()); sourceCoordinator.saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); diff --git 
a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java index acd8d0535f..7d89855365 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java @@ -130,8 +130,8 @@ private void refreshTask(RdsSourceConfig sourceConfig) { } else { final LogicalReplicationClient logicalReplicationClient = (LogicalReplicationClient) replicationLogClient; logicalReplicationClient.setEventProcessor(new LogicalReplicationEventProcessor( - streamPartition, sourceConfig, buffer, s3Prefix - )); + streamPartition, sourceConfig, buffer, s3Prefix, pluginMetrics, logicalReplicationClient, + streamCheckpointer, acknowledgementSetManager)); } final StreamWorker streamWorker = StreamWorker.create(sourceCoordinator, replicationLogClient, pluginMetrics); executorService.submit(() -> streamWorker.processStream(streamPartition)); @@ -150,4 +150,3 @@ private DbTableMetadata getDBTableMetadata(final StreamPartition streamPartition return DbTableMetadata.fromMap(globalState.getProgressState().get()); } } - diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java index 6eeedfcd0f..efc831acfd 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java @@ -140,7 +140,7 @@ void test_run_success() throws Exception { ParquetReader parquetReader = mock(ParquetReader.class); BufferAccumulator> bufferAccumulator = mock(BufferAccumulator.class); when(builder.build()).thenReturn(parquetReader); - when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), null); + when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), (GenericRecord) null); try (MockedStatic readerMockedStatic = mockStatic(AvroParquetReader.class); MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { @@ -191,7 +191,7 @@ void test_flush_failure_then_error_metric_updated() throws Exception { BufferAccumulator> bufferAccumulator = mock(BufferAccumulator.class); doThrow(new RuntimeException("testing")).when(bufferAccumulator).flush(); when(builder.build()).thenReturn(parquetReader); - when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), null); + when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), (GenericRecord) null); try (MockedStatic readerMockedStatic = mockStatic(AvroParquetReader.class); MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { readerMockedStatic.when(() -> AvroParquetReader.builder(any(InputFile.class), any())).thenReturn(builder); diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java 
b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java index 9cd410ee44..45897335b5 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java @@ -33,7 +33,9 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -87,6 +89,92 @@ void test_connect() throws SQLException, InterruptedException { verify(stream).setFlushedLSN(lsn); } + @Test + void test_disconnect() throws SQLException, InterruptedException { + final Connection connection = mock(Connection.class); + final PGConnection pgConnection = mock(PGConnection.class, RETURNS_DEEP_STUBS); + final ChainedLogicalStreamBuilder logicalStreamBuilder = mock(ChainedLogicalStreamBuilder.class); + final PGReplicationStream stream = mock(PGReplicationStream.class); + final ByteBuffer message = ByteBuffer.allocate(0); + final LogSequenceNumber lsn = mock(LogSequenceNumber.class); + + when(connectionManager.getConnection()).thenReturn(connection); + when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection); + when(pgConnection.getReplicationAPI().replicationStream().logical()).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotName(anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotOption(anyString(), anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.start()).thenReturn(stream); + when(stream.readPending()).thenReturn(message).thenReturn(null); + when(stream.getLastReceiveLSN()).thenReturn(lsn); + + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + executorService.submit(() -> logicalReplicationClient.connect()); + + await().atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> verify(eventProcessor).process(message)); + Thread.sleep(20); + verify(stream).setAppliedLSN(lsn); + verify(stream).setFlushedLSN(lsn); + + logicalReplicationClient.disconnect(); + Thread.sleep(20); + verify(stream).close(); + verifyNoMoreInteractions(stream, eventProcessor); + + executorService.shutdownNow(); + } + + @Test + void test_connect_disconnect_cycles() throws SQLException, InterruptedException { + final Connection connection = mock(Connection.class); + final PGConnection pgConnection = mock(PGConnection.class, RETURNS_DEEP_STUBS); + final ChainedLogicalStreamBuilder logicalStreamBuilder = mock(ChainedLogicalStreamBuilder.class); + final PGReplicationStream stream = mock(PGReplicationStream.class); + final ByteBuffer message = ByteBuffer.allocate(0); + final LogSequenceNumber lsn = mock(LogSequenceNumber.class); + + when(connectionManager.getConnection()).thenReturn(connection); + when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection); + when(pgConnection.getReplicationAPI().replicationStream().logical()).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotName(anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotOption(anyString(), 
anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.start()).thenReturn(stream); + when(stream.readPending()).thenReturn(message).thenReturn(null); + when(stream.getLastReceiveLSN()).thenReturn(lsn); + + // First connect + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + executorService.submit(() -> logicalReplicationClient.connect()); + await().atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> verify(eventProcessor, times(1)).process(message)); + Thread.sleep(20); + verify(stream).setAppliedLSN(lsn); + verify(stream).setFlushedLSN(lsn); + + // First disconnect + logicalReplicationClient.disconnect(); + Thread.sleep(20); + verify(stream).close(); + verifyNoMoreInteractions(stream, eventProcessor); + + // Second connect + when(stream.readPending()).thenReturn(message).thenReturn(null); + executorService.submit(() -> logicalReplicationClient.connect()); + await().atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> verify(eventProcessor, times(2)).process(message)); + Thread.sleep(20); + verify(stream, times(2)).setAppliedLSN(lsn); + verify(stream, times(2)).setFlushedLSN(lsn); + + // Second disconnect + logicalReplicationClient.disconnect(); + Thread.sleep(20); + verify(stream, times(2)).close(); + verifyNoMoreInteractions(stream, eventProcessor); + + executorService.shutdownNow(); + } + private LogicalReplicationClient createObjectUnderTest() { return new LogicalReplicationClient(connectionManager, replicationSlotName, publicationName); } diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java index 31ec9618a2..90e8149319 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java @@ -10,12 +10,15 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; +import io.micrometer.core.instrument.Metrics; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Answers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -28,9 +31,11 @@ import java.util.UUID; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class LogicalReplicationEventProcessorTest { @@ -44,6 +49,18 @@ class LogicalReplicationEventProcessorTest { @Mock private Buffer> buffer; + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private LogicalReplicationClient logicalReplicationClient; + + @Mock + private StreamCheckpointer streamCheckpointer; + + @Mock + private AcknowledgementSetManager 
acknowledgementSetManager; + private ByteBuffer message; private String s3Prefix; @@ -56,6 +73,8 @@ class LogicalReplicationEventProcessorTest { void setUp() { s3Prefix = UUID.randomUUID().toString(); random = new Random(); + when(pluginMetrics.timer(anyString())).thenReturn(Metrics.timer("test-timer")); + when(pluginMetrics.counter(anyString())).thenReturn(Metrics.counter("test-counter")); objectUnderTest = spy(createObjectUnderTest()); } @@ -129,8 +148,15 @@ void test_unsupported_message_type_throws_exception() { assertThrows(IllegalArgumentException.class, () -> objectUnderTest.process(message)); } + @Test + void test_stopClient() { + objectUnderTest.stopClient(); + verify(logicalReplicationClient).disconnect(); + } + private LogicalReplicationEventProcessor createObjectUnderTest() { - return new LogicalReplicationEventProcessor(streamPartition, sourceConfig, buffer, s3Prefix); + return new LogicalReplicationEventProcessor(streamPartition, sourceConfig, buffer, s3Prefix, pluginMetrics, + logicalReplicationClient, streamCheckpointer, acknowledgementSetManager); } private void setMessageType(MessageType messageType) { diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java index deddb45e32..1b32639daf 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java @@ -11,10 +11,14 @@ import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import java.time.Duration; +import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.Consumer; @@ -42,11 +46,16 @@ class StreamCheckpointManagerTest { @Mock private Runnable stopStreamRunnable; + @Mock + private PluginMetrics pluginMetrics; + private boolean isAcknowledgmentEnabled = false; + private EngineType engineType = EngineType.MYSQL; + private Random random; @BeforeEach void setUp() { - + random = new Random(); } @Test @@ -76,29 +85,65 @@ void test_shutdown() { } @Test - void test_saveChangeEventsStatus() { + void test_saveChangeEventsStatus_mysql() { final BinlogCoordinate binlogCoordinate = mock(BinlogCoordinate.class); + final long recordCount = random.nextLong(); + final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); + + streamCheckpointManager.saveChangeEventsStatus(binlogCoordinate, recordCount); + + assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); + final ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); + assertThat(changeEventStatus.getBinlogCoordinate(), is(binlogCoordinate)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); + } + + @Test + void test_saveChangeEventsStatus_postgres() { + final LogSequenceNumber 
logSequenceNumber = mock(LogSequenceNumber.class); + engineType = EngineType.POSTGRES; + final long recordCount = random.nextLong(); final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); - streamCheckpointManager.saveChangeEventsStatus(binlogCoordinate); + + streamCheckpointManager.saveChangeEventsStatus(logSequenceNumber, recordCount); assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); - assertThat(streamCheckpointManager.getChangeEventStatuses().peek().getBinlogCoordinate(), is(binlogCoordinate)); + final ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); + assertThat(changeEventStatus.getLogSequenceNumber(), is(logSequenceNumber)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); } @Test - void test_createAcknowledgmentSet() { + void test_createAcknowledgmentSet_mysql() { final BinlogCoordinate binlogCoordinate = mock(BinlogCoordinate.class); + final long recordCount = random.nextLong(); final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); - streamCheckpointManager.createAcknowledgmentSet(binlogCoordinate); + streamCheckpointManager.createAcknowledgmentSet(binlogCoordinate, recordCount); assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); assertThat(changeEventStatus.getBinlogCoordinate(), is(binlogCoordinate)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); + verify(acknowledgementSetManager).create(any(Consumer.class), eq(ACK_TIMEOUT)); + } + + @Test + void test_createAcknowledgmentSet_postgres() { + final LogSequenceNumber logSequenceNumber = mock(LogSequenceNumber.class); + engineType = EngineType.POSTGRES; + final long recordCount = random.nextLong(); + final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); + streamCheckpointManager.createAcknowledgmentSet(logSequenceNumber, recordCount); + + assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); + ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); + assertThat(changeEventStatus.getLogSequenceNumber(), is(logSequenceNumber)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); verify(acknowledgementSetManager).create(any(Consumer.class), eq(ACK_TIMEOUT)); } private StreamCheckpointManager createObjectUnderTest() { return new StreamCheckpointManager( - streamCheckpointer, isAcknowledgmentEnabled, acknowledgementSetManager, stopStreamRunnable, ACK_TIMEOUT); + streamCheckpointer, isAcknowledgmentEnabled, acknowledgementSetManager, stopStreamRunnable, ACK_TIMEOUT, engineType, pluginMetrics); } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java index 3327e847f5..75f16ac9fc 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java @@ -13,10 +13,13 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.metrics.PluginMetrics; import 
org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.state.MySqlStreamState; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.PostgresStreamState; import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import java.util.Optional; @@ -39,12 +42,18 @@ class StreamCheckpointerTest { @Mock private MySqlStreamState mySqlStreamState; + @Mock + private PostgresStreamState postgresStreamState; + @Mock private PluginMetrics pluginMetrics; @Mock private Counter checkpointCounter; + @Mock + private ChangeEventStatus changeEventStatus; + private StreamCheckpointer streamCheckpointer; @@ -55,19 +64,35 @@ void setUp() { } @Test - void test_checkpoint() { + void test_checkpoint_mysql() { final BinlogCoordinate binlogCoordinate = mock(BinlogCoordinate.class); final StreamProgressState streamProgressState = mock(StreamProgressState.class); when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState)); when(streamProgressState.getMySqlStreamState()).thenReturn(mySqlStreamState); + when(changeEventStatus.getBinlogCoordinate()).thenReturn(binlogCoordinate); - streamCheckpointer.checkpoint(binlogCoordinate); + streamCheckpointer.checkpoint(EngineType.MYSQL, changeEventStatus); verify(mySqlStreamState).setCurrentPosition(binlogCoordinate); verify(sourceCoordinator).saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); verify(checkpointCounter).increment(); } + @Test + void test_checkpoint_postgres() { + final LogSequenceNumber logSequenceNumber = mock(LogSequenceNumber.class); + final StreamProgressState streamProgressState = mock(StreamProgressState.class); + when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState)); + when(streamProgressState.getPostgresStreamState()).thenReturn(postgresStreamState); + when(changeEventStatus.getLogSequenceNumber()).thenReturn(logSequenceNumber); + + streamCheckpointer.checkpoint(EngineType.POSTGRES, changeEventStatus); + + verify(postgresStreamState).setCurrentLsn(logSequenceNumber.asString()); + verify(sourceCoordinator).saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); + verify(checkpointCounter).increment(); + } + @Test void test_extendLease() { streamCheckpointer.extendLease(); diff --git a/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/OnErrorOption.java b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/OnErrorOption.java new file mode 100644 index 0000000000..f3244f9014 --- /dev/null +++ b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/OnErrorOption.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + */ + +package org.opensearch.dataprepper.plugins.source.sqs.common; + +import com.fasterxml.jackson.annotation.JsonCreator; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; + +public enum OnErrorOption { + DELETE_MESSAGES("delete_messages"), + RETAIN_MESSAGES("retain_messages"); + + private static final Map OPTIONS_MAP = Arrays.stream(OnErrorOption.values()) + .collect(Collectors.toMap( + value -> value.option, + value -> value + )); + + private final String option; + + OnErrorOption(final String option) { + this.option = option; + } + + @JsonCreator + static OnErrorOption fromOptionValue(final String option) { + return OPTIONS_MAP.get(option); + } +} diff --git a/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommon.java b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommon.java index 9301574237..bda8318928 100644 --- a/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommon.java +++ b/data-prepper-plugins/sqs-common/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/common/SqsWorkerCommon.java @@ -42,7 +42,6 @@ public class SqsWorkerCommon { public static final String SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangedCount"; public static final String SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME = "sqsVisibilityTimeoutChangeFailedCount"; - private final SqsClient sqsClient; private final Backoff standardBackoff; private final PluginMetrics pluginMetrics; private final AcknowledgementSetManager acknowledgementSetManager; @@ -56,18 +55,17 @@ public class SqsWorkerCommon { private final Counter sqsVisibilityTimeoutChangedCount; private final Counter sqsVisibilityTimeoutChangeFailedCount; - public SqsWorkerCommon(final SqsClient sqsClient, - final Backoff standardBackoff, + public SqsWorkerCommon(final Backoff standardBackoff, final PluginMetrics pluginMetrics, final AcknowledgementSetManager acknowledgementSetManager) { - this.sqsClient = sqsClient; this.standardBackoff = standardBackoff; this.pluginMetrics = pluginMetrics; this.acknowledgementSetManager = acknowledgementSetManager; this.isStopped = false; this.failedAttemptCount = 0; + sqsMessagesReceivedCounter = pluginMetrics.counter(SQS_MESSAGES_RECEIVED_METRIC_NAME); sqsMessagesDeletedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETED_METRIC_NAME); sqsMessagesFailedCounter = pluginMetrics.counter(SQS_MESSAGES_FAILED_METRIC_NAME); @@ -78,6 +76,7 @@ public SqsWorkerCommon(final SqsClient sqsClient, } public List pollSqsMessages(final String queueUrl, + final SqsClient sqsClient, final Integer maxNumberOfMessages, final Duration waitTime, final Duration visibilityTimeout) { @@ -134,7 +133,7 @@ public void applyBackoff() { } } - public void deleteSqsMessages(final String queueUrl, final List entries) { + public void deleteSqsMessages(final String queueUrl, final SqsClient sqsClient, final List entries) { if (entries == null || entries.isEmpty() || isStopped) { return; } @@ -146,7 +145,6 @@ public void deleteSqsMessages(final String queueUrl, final List messages = Collections.singletonList(message); + ReceiveMessageResponse response = ReceiveMessageResponse.builder().messages(messages).build(); + when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); + List result = sqsWorkerCommon.pollSqsMessages(queueUrl, sqsClient, 10, + 
Duration.ofSeconds(5), Duration.ofSeconds(30)); + + verify(sqsMessagesReceivedCounter).increment(messages.size()); + assertThat(result, equalTo(messages)); } @Test - void testPollSqsMessages_handlesEmptyList() { + void testPollSqsMessages_exception() { when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))) - .thenReturn(ReceiveMessageResponse.builder() - .messages(Collections.emptyList()) - .build()); - var messages = sqsWorkerCommon.pollSqsMessages( - "testQueueUrl", - 10, - Duration.ofSeconds(5), - Duration.ofSeconds(30) - ); - - assertNotNull(messages); - assertTrue(messages.isEmpty()); - Mockito.verify(sqsClient).receiveMessage(any(ReceiveMessageRequest.class)); - Mockito.verify(backoff, Mockito.never()).nextDelayMillis(Mockito.anyInt()); + .thenThrow(SqsException.builder().message("Error").build()); + when(backoff.nextDelayMillis(anyInt())).thenReturn(100L); + List result = sqsWorkerCommon.pollSqsMessages(queueUrl, sqsClient, 10, + Duration.ofSeconds(5), Duration.ofSeconds(30)); + assertThat(result, is(empty())); + verify(sqsMessagesReceivedCounter, never()).increment(anyDouble()); } @Test - void testDeleteSqsMessages_callsClientWhenNotStopped() { - var entries = Collections.singletonList( - DeleteMessageBatchRequestEntry.builder() - .id("msg-id") - .receiptHandle("receipt-handle") - .build() - ); + void testApplyBackoff_negativeDelayThrowsException() { + when(backoff.nextDelayMillis(anyInt())).thenReturn(-1L); + assertThrows(SqsRetriesExhaustedException.class, () -> sqsWorkerCommon.applyBackoff()); + } - when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))) - .thenReturn(DeleteMessageBatchResponse.builder().build()); - - sqsWorkerCommon.deleteSqsMessages("testQueueUrl", entries); - ArgumentCaptor captor = - ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); - Mockito.verify(sqsClient).deleteMessageBatch(captor.capture()); - assertEquals("testQueueUrl", captor.getValue().queueUrl()); - assertEquals(1, captor.getValue().entries().size()); + @Test + void testApplyBackoff_positiveDelay() { + when(backoff.nextDelayMillis(anyInt())).thenReturn(50L); + assertDoesNotThrow(() -> sqsWorkerCommon.applyBackoff()); + verify(backoff).nextDelayMillis(anyInt()); } @Test - void testStop_skipsFurtherOperations() { + void testDeleteSqsMessages_noEntries() { + sqsWorkerCommon.deleteSqsMessages(queueUrl, sqsClient, null); + verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); + + sqsWorkerCommon.deleteSqsMessages(queueUrl, sqsClient, Collections.emptyList()); + verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); + sqsWorkerCommon.stop(); - sqsWorkerCommon.deleteSqsMessages("testQueueUrl", Collections.singletonList( - DeleteMessageBatchRequestEntry.builder() - .id("msg-id") - .receiptHandle("receipt-handle") - .build() - )); - Mockito.verify(sqsClient, Mockito.never()).deleteMessageBatch((DeleteMessageBatchRequest) any()); + DeleteMessageBatchRequestEntry entry = DeleteMessageBatchRequestEntry.builder() + .id("id").receiptHandle("rh").build(); + sqsWorkerCommon.deleteSqsMessages(queueUrl, sqsClient, Collections.singletonList(entry)); + verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); + } + + @Test + void testDeleteSqsMessages_successfulDeletion_withConsumerBuilder() { + DeleteMessageBatchRequestEntry entry = DeleteMessageBatchRequestEntry.builder() + .id("id") + .receiptHandle("rh") + .build(); + List entries = Collections.singletonList(entry); + 
DeleteMessageBatchResponse response = DeleteMessageBatchResponse.builder() + .successful(builder -> builder.id("id")) + .failed(Collections.emptyList()) + .build(); + when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(response); + sqsWorkerCommon.deleteSqsMessages(queueUrl, sqsClient, entries); + verify(sqsMessagesDeletedCounter).increment(1.0); + } + + + @Test + void testDeleteSqsMessages_failedDeletion() { + DeleteMessageBatchRequestEntry entry = DeleteMessageBatchRequestEntry.builder() + .id("id").receiptHandle("rh").build(); + List entries = Collections.singletonList(entry); + + DeleteMessageBatchResponse response = DeleteMessageBatchResponse.builder() + .successful(Collections.emptyList()) + .failed(Collections.singletonList( + BatchResultErrorEntry.builder().id("id").message("Failure").build())) + .build(); + + when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(response); + sqsWorkerCommon.deleteSqsMessages(queueUrl, sqsClient, entries); + verify(sqsMessagesDeleteFailedCounter).increment(1.0); + } + + @Test + void testDeleteSqsMessages_sdkException() { + DeleteMessageBatchRequestEntry entry1 = DeleteMessageBatchRequestEntry.builder() + .id("id-1").receiptHandle("rh-1").build(); + DeleteMessageBatchRequestEntry entry2 = DeleteMessageBatchRequestEntry.builder() + .id("id-2").receiptHandle("rh-2").build(); + List entries = Arrays.asList(entry1, entry2); + + when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))) + .thenThrow(SdkException.class); + + sqsWorkerCommon.deleteSqsMessages(queueUrl, sqsClient, entries); + verify(sqsMessagesDeleteFailedCounter).increment(entries.size()); + } + + @Test + void testIncreaseVisibilityTimeout_successful() { + String receiptHandle = "rh"; + int newVisibilityTimeout = 45; + String messageId = "msg"; + when(sqsClient.changeMessageVisibility(any(ChangeMessageVisibilityRequest.class))) + .thenReturn(ChangeMessageVisibilityResponse.builder().build()); + + sqsWorkerCommon.increaseVisibilityTimeout(queueUrl, sqsClient, receiptHandle, + newVisibilityTimeout, messageId); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(ChangeMessageVisibilityRequest.class); + verify(sqsClient).changeMessageVisibility(requestCaptor.capture()); + ChangeMessageVisibilityRequest request = requestCaptor.getValue(); + assertThat(request.queueUrl(), equalTo(queueUrl)); + assertThat(request.receiptHandle(), equalTo(receiptHandle)); + assertThat(request.visibilityTimeout(), equalTo(newVisibilityTimeout)); + verify(sqsVisibilityTimeoutChangedCount).increment(); + } + + @Test + void testIncreaseVisibilityTimeout_whenException() { + String receiptHandle = "rh"; + int newVisibilityTimeout = 45; + String messageId = "msg"; + doThrow(new RuntimeException("failure")) + .when(sqsClient).changeMessageVisibility(any(ChangeMessageVisibilityRequest.class)); + + sqsWorkerCommon.increaseVisibilityTimeout(queueUrl, sqsClient, receiptHandle, + newVisibilityTimeout, messageId); + + verify(sqsVisibilityTimeoutChangeFailedCount).increment(); } } diff --git a/data-prepper-plugins/sqs-source/README.md b/data-prepper-plugins/sqs-source/README.md index ff4313605f..c3898b396d 100644 --- a/data-prepper-plugins/sqs-source/README.md +++ b/data-prepper-plugins/sqs-source/README.md @@ -2,7 +2,7 @@ This source allows Data Prepper to use SQS as a source. It reads messages from specified SQS queues and processes them into events. 
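The multi-queue support described in this README is backed by per-region client reuse: the SqsService changes later in this patch derive the AWS region from each queue URL and cache one SqsClient per region. A minimal standalone sketch of that derivation, assuming the standard `https://sqs.<region>.amazonaws.com/<account-id>/<queue-name>` URL format (the class and method names below are illustrative, not the plugin's actual API):

```java
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch only: derives the region token from a standard SQS queue URL and caches
// one client-like object per region, mirroring the SqsService change elsewhere in this patch.
public class QueueRegionResolver {

    private final Map<String, String> clientCacheByRegion = new HashMap<>();

    // The region is the second dot-separated token of the queue URL, which is the same
    // assumption the patch's extractRegionFromQueueUrl makes.
    static String extractRegion(final String queueUrl) {
        return queueUrl.split("\\.")[1];
    }

    // Returns a cached per-region "client" (a String stands in for SqsClient here), creating it on first use.
    String resolveClientFor(final String queueUrl) {
        return clientCacheByRegion.computeIfAbsent(extractRegion(queueUrl), region -> "client-for-" + region);
    }

    public static void main(String[] args) {
        final QueueRegionResolver resolver = new QueueRegionResolver();
        System.out.println(resolver.resolveClientFor("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"));
        System.out.println(resolver.resolveClientFor("https://sqs.eu-west-1.amazonaws.com/123456789012/OtherQueue"));
    }
}
```

This split-based parsing assumes standard queue URLs; custom or VPC endpoint URLs with a different host layout would need a different strategy.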
-## Example Configuration ```yaml sqs-pipeline: source: sqs: queues: - url: - batch_size: 10 - workers: 1 + aws: + region: + sts_role_arn: + ``` +## Full Configuration + +```yaml +sqs-pipeline: + source: + sqs: + queues: + - url: + workers: 2 + maximum_messages: 10 + poll_delay: 0s + wait_time: 20s + visibility_timeout: 30s + visibility_duplication_protection: true + visibility_duplicate_protection_timeout: "PT1H" + on_error: "retain_messages" + codec: + json: + key_name: "events" - url: - batch_size: 10 - workers: 1 + # This queue will use the defaults for optional properties. + acknowledgments: true aws: region: sts_role_arn: - sink: - - stdout: +``` +## Key Features + +- **Multi-Queue Support:** + Process messages from multiple SQS queues simultaneously. + + +- **Configurable Polling:** + Customize batch size, poll delay, wait time, and visibility timeout per queue. + + +- **Error Handling:** + Use the `on_error` option to control behavior on errors (e.g., delete or retain messages). + + +- **Codec Support:** + Configure codecs (e.g., JSON, CSV, newline-delimited) to parse incoming messages. + + +## IAM Permissions + +Ensure that the following AWS permissions are granted on the SQS queues: +- `sqs:ReceiveMessage` + +- `sqs:DeleteMessageBatch` + +- `sqs:ChangeMessageVisibility` diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java index 47e417bc27..d0715ee8c1 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/QueueConfig.java @@ -17,7 +17,7 @@ import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; import java.time.Duration; - +import org.opensearch.dataprepper.plugins.source.sqs.common.OnErrorOption; import org.hibernate.validator.constraints.time.DurationMax; import org.hibernate.validator.constraints.time.DurationMin; import org.opensearch.dataprepper.model.configuration.PluginModel; @@ -71,6 +71,9 @@ public class QueueConfig { @JsonProperty("codec") private PluginModel codec = null; + @JsonProperty("on_error") + private OnErrorOption onErrorOption = OnErrorOption.RETAIN_MESSAGES; + public String getUrl() { return url; } @@ -83,9 +86,7 @@ public int getNumWorkers() { return numWorkers; } - public Duration getVisibilityTimeout() { - return visibilityTimeout; - } + public Duration getVisibilityTimeout() { return visibilityTimeout; } public boolean getVisibilityDuplicateProtection() { return visibilityDuplicateProtection; @@ -107,5 +108,8 @@ public PluginModel getCodec() { return codec; } + public OnErrorOption getOnErrorOption() { + return onErrorOption; + } } diff --git a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java index 5c3e3ba2d2..3c48b76f4e 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsService.java @@ -20,16 +20,20 @@ import org.opensearch.dataprepper.model.plugin.PluginFactory; import 
org.opensearch.dataprepper.plugins.source.sqs.common.SqsBackoff; import org.opensearch.dataprepper.plugins.source.sqs.common.SqsClientFactory; + import org.opensearch.dataprepper.plugins.source.sqs.common.SqsWorkerCommon; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; + import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.sqs.SqsClient; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import java.util.ArrayList; + import java.util.HashMap; import java.util.List; + import java.util.Map; import java.util.concurrent.TimeUnit; import java.util.concurrent.Executors; import java.util.concurrent.ExecutorService; @@ -40,14 +44,15 @@ public class SqsService { private static final Logger LOG = LoggerFactory.getLogger(SqsService.class); static final long SHUTDOWN_TIMEOUT = 30L; private final SqsSourceConfig sqsSourceConfig; - private final SqsClient sqsClient; private final PluginMetrics pluginMetrics; private final PluginFactory pluginFactory; private final AcknowledgementSetManager acknowledgementSetManager; private final List allSqsUrlExecutorServices; private final List sqsWorkers; private final Buffer> buffer; - private final Backoff backoff; + private final Map sqsClientMap = new HashMap<>(); + private final AwsCredentialsProvider credentialsProvider; + public SqsService(final Buffer> buffer, final AcknowledgementSetManager acknowledgementSetManager, @@ -59,19 +64,23 @@ public SqsService(final Buffer> buffer, this.sqsSourceConfig = sqsSourceConfig; this.pluginMetrics = pluginMetrics; this.pluginFactory = pluginFactory; + this.credentialsProvider = credentialsProvider; this.acknowledgementSetManager = acknowledgementSetManager; this.allSqsUrlExecutorServices = new ArrayList<>(); this.sqsWorkers = new ArrayList<>(); - this.sqsClient = SqsClientFactory.createSqsClient(sqsSourceConfig.getAwsAuthenticationOptions().getAwsRegion(), credentialsProvider); this.buffer = buffer; - backoff = SqsBackoff.createExponentialBackoff(); } public void start() { LOG.info("Starting SqsService"); sqsSourceConfig.getQueues().forEach(queueConfig -> { String queueUrl = queueConfig.getUrl(); + String region = extractRegionFromQueueUrl(queueUrl); + SqsClient sqsClient = sqsClientMap.computeIfAbsent(region, + r -> SqsClientFactory.createSqsClient(Region.of(r), credentialsProvider)); String queueName = queueUrl.substring(queueUrl.lastIndexOf('/') + 1); + Backoff backoff = SqsBackoff.createExponentialBackoff(); + SqsWorkerCommon sqsWorkerCommon = new SqsWorkerCommon(backoff, pluginMetrics, acknowledgementSetManager); int numWorkers = queueConfig.getNumWorkers(); SqsEventProcessor sqsEventProcessor; MessageFieldStrategy strategy; @@ -93,11 +102,11 @@ public void start() { buffer, acknowledgementSetManager, sqsClient, + sqsWorkerCommon, sqsSourceConfig, queueConfig, pluginMetrics, - sqsEventProcessor, - backoff)) + sqsEventProcessor)) .collect(Collectors.toList()); sqsWorkers.addAll(workers); @@ -122,10 +131,15 @@ public void stop() { Thread.currentThread().interrupt(); } }); - - sqsClient.close(); + + sqsClientMap.values().forEach(SqsClient::close); LOG.info("SqsService shutdown completed."); } + + private String extractRegionFromQueueUrl(final String queueUrl) { + String[] split = queueUrl.split("\\."); + return split[1]; + } } \ No newline at end of file diff --git 
a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java index cb4c168345..ee82e8279d 100644 --- a/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java +++ b/data-prepper-plugins/sqs-source/src/main/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorker.java @@ -10,22 +10,26 @@ package org.opensearch.dataprepper.plugins.source.sqs; -import com.linecorp.armeria.client.retry.Backoff; import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; import org.opensearch.dataprepper.model.buffer.Buffer; +import org.opensearch.dataprepper.plugins.source.sqs.common.OnErrorOption; import org.opensearch.dataprepper.plugins.source.sqs.common.SqsWorkerCommon; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import software.amazon.awssdk.services.sqs.SqsClient; import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeName; +import software.amazon.awssdk.services.sqs.model.SqsException; import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -34,9 +38,12 @@ public class SqsWorker implements Runnable { private static final Logger LOG = LoggerFactory.getLogger(SqsWorker.class); + static final String SQS_MESSAGE_DELAY_METRIC_NAME = "sqsMessageDelay"; + private final Timer sqsMessageDelayTimer; static final String ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME = "acknowledgementSetCallbackCounter"; private final SqsWorkerCommon sqsWorkerCommon; private final SqsEventProcessor sqsEventProcessor; + private final SqsClient sqsClient; private final QueueConfig queueConfig; private final boolean endToEndAcknowledgementsEnabled; private final Buffer> buffer; @@ -50,21 +57,25 @@ public class SqsWorker implements Runnable { public SqsWorker(final Buffer> buffer, final AcknowledgementSetManager acknowledgementSetManager, final SqsClient sqsClient, + final SqsWorkerCommon sqsWorkerCommon, final SqsSourceConfig sqsSourceConfig, final QueueConfig queueConfig, final PluginMetrics pluginMetrics, - final SqsEventProcessor sqsEventProcessor, - final Backoff backoff) { - this.sqsWorkerCommon = new SqsWorkerCommon(sqsClient, backoff, pluginMetrics, acknowledgementSetManager); + final SqsEventProcessor sqsEventProcessor) { + + this.sqsWorkerCommon = sqsWorkerCommon; this.queueConfig = queueConfig; this.acknowledgementSetManager = acknowledgementSetManager; this.sqsEventProcessor = sqsEventProcessor; + this.sqsClient = sqsClient; this.buffer = buffer; this.bufferTimeoutMillis = (int) sqsSourceConfig.getBufferTimeout().toMillis(); this.endToEndAcknowledgementsEnabled = sqsSourceConfig.getAcknowledgements(); this.messageVisibilityTimesMap = new HashMap<>(); this.failedAttemptCount = 0; - this.acknowledgementSetCallbackCounter = pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME); + acknowledgementSetCallbackCounter = 
pluginMetrics.counter(ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME); + sqsMessageDelayTimer = pluginMetrics.timer(SQS_MESSAGE_DELAY_METRIC_NAME); + } @Override @@ -90,17 +101,27 @@ public void run() { } int processSqsMessages() { - List messages = sqsWorkerCommon.pollSqsMessages(queueConfig.getUrl(), - queueConfig.getMaximumMessages(), - queueConfig.getWaitTime(), - queueConfig.getVisibilityTimeout()); - if (!messages.isEmpty()) { - final List deleteMessageBatchRequestEntries = processSqsEvents(messages); - if (!deleteMessageBatchRequestEntries.isEmpty()) { - sqsWorkerCommon.deleteSqsMessages(queueConfig.getUrl(), deleteMessageBatchRequestEntries); + try { + List messages = sqsWorkerCommon.pollSqsMessages( + queueConfig.getUrl(), + sqsClient, + queueConfig.getMaximumMessages(), + queueConfig.getWaitTime(), + queueConfig.getVisibilityTimeout()); + + if (!messages.isEmpty()) { + final List deleteEntries = processSqsEvents(messages); + if (!deleteEntries.isEmpty()) { + sqsWorkerCommon.deleteSqsMessages(queueConfig.getUrl(), sqsClient, deleteEntries); + } + } else { + sqsMessageDelayTimer.record(Duration.ZERO); } + return messages.size(); + } catch (SqsException e) { + sqsWorkerCommon.applyBackoff(); + return 0; } - return messages.size(); } private List processSqsEvents(final List messages) { @@ -128,7 +149,7 @@ private List processSqsEvents(final List processSqsEvents(final List processSqsEvents(final List waitingForAcknowledgements = messageWaitingForAcknowledgementsMap.get(message); final Optional deleteEntry = processSqsObject(message, acknowledgementSet); @@ -174,7 +203,11 @@ private Optional processSqsObject(final Message sqsWorkerCommon.getSqsMessagesFailedCounter().increment(); LOG.error("Error processing from SQS: {}. Retrying with exponential backoff.", e.getMessage()); sqsWorkerCommon.applyBackoff(); - return Optional.empty(); + if (queueConfig.getOnErrorOption().equals(OnErrorOption.DELETE_MESSAGES)) { + return Optional.of(sqsWorkerCommon.buildDeleteMessageBatchRequestEntry(message.messageId(), message.receiptHandle())); + } else { + return Optional.empty(); + } } } diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AttributeHandlerTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AttributeHandlerTest.java new file mode 100644 index 0000000000..be7c772620 --- /dev/null +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/AttributeHandlerTest.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ * + */ + +package org.opensearch.dataprepper.plugins.source.sqs; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.opensearch.dataprepper.model.event.Event; +import org.opensearch.dataprepper.model.event.EventMetadata; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; +import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeName; + +import org.mockito.Mockito; + +public class AttributeHandlerTest { + + @Test + void testCollectMetadataAttributes() { + final Map systemAttributes = new HashMap<>(); + systemAttributes.put(MessageSystemAttributeName.SENT_TIMESTAMP, "1234567890"); + systemAttributes.put(MessageSystemAttributeName.APPROXIMATE_RECEIVE_COUNT, "5"); + final Map messageAttributes = new HashMap<>(); + + messageAttributes.put("CustomKey", MessageAttributeValue.builder() + .stringValue("customValue") + .dataType("String") + .build()); + + final Message message = Message.builder() + .messageId("id-1") + .receiptHandle("rh-1") + .body("Test message") + .attributes(systemAttributes) + .messageAttributes(messageAttributes) + .build(); + + final String queueUrl = "https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"; + final Map metadata = AttributeHandler.collectMetadataAttributes(message, queueUrl); + assertThat(metadata.get("queueUrl"), equalTo(queueUrl)); + assertThat(metadata.get("sentTimestamp"), equalTo("1234567890")); + assertThat(metadata.get("approximateReceiveCount"), equalTo("5")); + assertThat(metadata.get("customKey"), equalTo("customValue")); + } + + @Test + void testApplyMetadataAttributes() { + final Event event = Mockito.mock(Event.class); + final EventMetadata metadata = Mockito.mock(EventMetadata.class); + when(event.getMetadata()).thenReturn(metadata); + final Map attributes = new HashMap<>(); + attributes.put("key1", "value1"); + attributes.put("key2", "value2"); + AttributeHandler.applyMetadataAttributes(event, attributes); + verify(metadata).setAttribute("key1", "value1"); + verify(metadata).setAttribute("key2", "value2"); + } +} diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java index 83a12e5940..4b4777d434 100644 --- a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsServiceTest.java @@ -20,18 +20,15 @@ import org.opensearch.dataprepper.model.record.Record; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.sqs.SqsClient; -import java.util.List; +import java.util.List; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.when; -import static org.mockito.Mockito.withSettings; class SqsServiceTest { private SqsSourceConfig sqsSourceConfig; - private SqsClient sqsClient; private PluginMetrics pluginMetrics; private PluginFactory pluginFactory; private AcknowledgementSetManager acknowledgementSetManager; @@ -41,13 +38,11 @@ class 
SqsServiceTest { @BeforeEach void setUp() { sqsSourceConfig = mock(SqsSourceConfig.class); - sqsClient = mock(SqsClient.class, withSettings()); pluginMetrics = mock(PluginMetrics.class); pluginFactory = mock(PluginFactory.class); acknowledgementSetManager = mock(AcknowledgementSetManager.class); buffer = mock(Buffer.class); credentialsProvider = mock(AwsCredentialsProvider.class); - AwsAuthenticationOptions awsAuthenticationOptions = mock(AwsAuthenticationOptions.class); when(awsAuthenticationOptions.getAwsRegion()).thenReturn(Region.US_EAST_1); when(sqsSourceConfig.getAwsAuthenticationOptions()).thenReturn(awsAuthenticationOptions); @@ -66,10 +61,9 @@ void start_with_single_queue_starts_workers() { @Test void stop_should_shutdown_executors_and_workers() throws InterruptedException { QueueConfig queueConfig = mock(QueueConfig.class); - when(queueConfig.getUrl()).thenReturn("MyQueue"); + when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); when(queueConfig.getNumWorkers()).thenReturn(1); when(sqsSourceConfig.getQueues()).thenReturn(List.of(queueConfig)); - SqsClient sqsClient = mock(SqsClient.class); SqsService sqsService = new SqsService(buffer, acknowledgementSetManager, sqsSourceConfig, pluginMetrics, pluginFactory, credentialsProvider) {}; sqsService.start(); sqsService.stop(); // again assuming that if no exception is thrown here, then workers and client have been stopped diff --git a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java index e7339543c2..9f7a7f1f30 100644 --- a/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java +++ b/data-prepper-plugins/sqs-source/src/test/java/org/opensearch/dataprepper/plugins/source/sqs/SqsWorkerTest.java @@ -10,8 +10,8 @@ package org.opensearch.dataprepper.plugins.source.sqs; -import com.linecorp.armeria.client.retry.Backoff; import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Timer; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -24,47 +24,42 @@ import org.opensearch.dataprepper.model.acknowledgements.ProgressCheck; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; -import org.opensearch.dataprepper.model.plugin.PluginFactory; import org.opensearch.dataprepper.model.record.Record; +import org.opensearch.dataprepper.plugins.source.sqs.common.OnErrorOption; import org.opensearch.dataprepper.plugins.source.sqs.common.SqsWorkerCommon; -import org.opensearch.dataprepper.plugins.source.sqs.common.SqsRetriesExhaustedException; import software.amazon.awssdk.services.sqs.SqsClient; -import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; -import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; -import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; -import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResultEntry; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; import software.amazon.awssdk.services.sqs.model.Message; -import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeName; import 
software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; import software.amazon.awssdk.services.sqs.model.SqsException; -import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry; import java.io.IOException; import java.time.Duration; +import java.time.Instant; import java.util.Collections; -import java.util.UUID; +import java.util.Map; import java.util.function.Consumer; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyDouble; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class SqsWorkerTest { - @Mock private Buffer<Record<Event>> buffer; @Mock @@ -72,6 +67,8 @@ class SqsWorkerTest { @Mock private SqsClient sqsClient; @Mock + private SqsWorkerCommon sqsWorkerCommon; + @Mock private SqsEventProcessor sqsEventProcessor; @Mock private SqsSourceConfig sqsSourceConfig; @@ -80,125 +77,119 @@ class SqsWorkerTest { @Mock private PluginMetrics pluginMetrics; @Mock - private PluginFactory pluginFactory; - @Mock - private Backoff backoff; - @Mock - private Counter sqsMessagesReceivedCounter; - @Mock - private Counter sqsMessagesDeletedCounter; + private Timer sqsMessageDelayTimer; @Mock private Counter sqsMessagesFailedCounter; - @Mock - private Counter sqsMessagesDeleteFailedCounter; - @Mock - private Counter acknowledgementSetCallbackCounter; - @Mock - private Counter sqsVisibilityTimeoutChangedCount; - @Mock - private Counter sqsVisibilityTimeoutChangeFailedCount; - private final int mockBufferTimeoutMillis = 10000; + private final int bufferTimeoutMillis = 10000; + private SqsWorker sqsWorker; private SqsWorker createObjectUnderTest() { return new SqsWorker( buffer, acknowledgementSetManager, sqsClient, + sqsWorkerCommon, sqsSourceConfig, queueConfig, pluginMetrics, - sqsEventProcessor, - backoff); + sqsEventProcessor + ); } @BeforeEach void setUp() { - when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_RECEIVED_METRIC_NAME)) - .thenReturn(sqsMessagesReceivedCounter); - when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_DELETED_METRIC_NAME)) - .thenReturn(sqsMessagesDeletedCounter); - when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_FAILED_METRIC_NAME)) - .thenReturn(sqsMessagesFailedCounter); - when(pluginMetrics.counter(SqsWorkerCommon.SQS_MESSAGES_DELETE_FAILED_METRIC_NAME)) - .thenReturn(sqsMessagesDeleteFailedCounter); - when(pluginMetrics.counter(SqsWorkerCommon.ACKNOWLEDGEMENT_SET_CALLACK_METRIC_NAME)) - .thenReturn(acknowledgementSetCallbackCounter); - when(pluginMetrics.counter(SqsWorkerCommon.SQS_VISIBILITY_TIMEOUT_CHANGED_COUNT_METRIC_NAME)) - .thenReturn(sqsVisibilityTimeoutChangedCount); - when(pluginMetrics.counter(SqsWorkerCommon.SQS_VISIBILITY_TIMEOUT_CHANGE_FAILED_COUNT_METRIC_NAME)) -
.thenReturn(sqsVisibilityTimeoutChangeFailedCount); when(sqsSourceConfig.getAcknowledgements()).thenReturn(false); when(sqsSourceConfig.getBufferTimeout()).thenReturn(Duration.ofSeconds(10)); - when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); - when(queueConfig.getWaitTime()).thenReturn(Duration.ofSeconds(1)); - } - - @Test - void processSqsMessages_should_return_number_of_messages_processed_and_increment_counters() throws IOException { + lenient().when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); + lenient().when(queueConfig.getWaitTime()).thenReturn(Duration.ofSeconds(10)); + lenient().when(queueConfig.getMaximumMessages()).thenReturn(10); + lenient().when(queueConfig.getVisibilityTimeout()).thenReturn(Duration.ofSeconds(30)); + when(pluginMetrics.timer(SqsWorker.SQS_MESSAGE_DELAY_METRIC_NAME)).thenReturn(sqsMessageDelayTimer); + lenient().doNothing().when(sqsMessageDelayTimer).record(any(Duration.class)); + sqsWorker = new SqsWorker( + buffer, + acknowledgementSetManager, + sqsClient, + sqsWorkerCommon, + sqsSourceConfig, + queueConfig, + pluginMetrics, + sqsEventProcessor + ); final Message message = Message.builder() - .messageId(UUID.randomUUID().toString()) - .receiptHandle(UUID.randomUUID().toString()) + .messageId("msg-1") + .receiptHandle("rh-1") .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") + .attributes(Map.of( + MessageSystemAttributeName.SENT_TIMESTAMP, "1234567890", + MessageSystemAttributeName.APPROXIMATE_RECEIVE_COUNT, "0" + )) .build(); - final ReceiveMessageResponse response = ReceiveMessageResponse.builder().messages(message).build(); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); - final DeleteMessageBatchResultEntry successfulDelete = DeleteMessageBatchResultEntry.builder().id(message.messageId()).build(); - final DeleteMessageBatchResponse deleteResponse = DeleteMessageBatchResponse.builder().successful(successfulDelete).build(); - when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(deleteResponse); - - int messagesProcessed = createObjectUnderTest().processSqsMessages(); - assertThat(messagesProcessed, equalTo(1)); - - verify(sqsMessagesReceivedCounter).increment(1); - verify(sqsMessagesDeletedCounter).increment(1); - verify(sqsMessagesDeleteFailedCounter, never()).increment(anyDouble()); + lenient().when(sqsWorkerCommon.pollSqsMessages( + anyString(), + eq(sqsClient), + any(), + any(), + any() + )).thenReturn(Collections.singletonList(message)); + + lenient().when(sqsWorkerCommon.buildDeleteMessageBatchRequestEntry(anyString(), anyString())) + .thenAnswer(invocation -> { + String messageId = invocation.getArgument(0); + String receiptHandle = invocation.getArgument(1); + return DeleteMessageBatchRequestEntry.builder() + .id(messageId) + .receiptHandle(receiptHandle) + .build(); + }); } @Test - void processSqsMessages_should_invoke_processSqsEvent_and_deleteSqsMessages_when_entries_non_empty() throws IOException { - final Message message = Message.builder() - .messageId(UUID.randomUUID().toString()) - .receiptHandle(UUID.randomUUID().toString()) - .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") - .build(); - - final ReceiveMessageResponse response = ReceiveMessageResponse.builder() - .messages(message) - .build(); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); - - final 
DeleteMessageBatchResultEntry successfulDelete = DeleteMessageBatchResultEntry.builder() - .id(message.messageId()) - .build(); - final DeleteMessageBatchResponse deleteResponse = DeleteMessageBatchResponse.builder() - .successful(successfulDelete) - .build(); - when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(deleteResponse); - - SqsWorker sqsWorker = createObjectUnderTest(); + void processSqsMessages_should_call_addSqsObject_and_deleteSqsMessages_for_valid_message() throws IOException { int messagesProcessed = sqsWorker.processSqsMessages(); - assertThat(messagesProcessed, equalTo(1)); - verify(sqsEventProcessor, times(1)).addSqsObject(eq(message), eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), eq(buffer), eq(mockBufferTimeoutMillis), isNull()); - verify(sqsClient, times(1)).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); - verify(sqsMessagesReceivedCounter).increment(1); - verify(sqsMessagesDeletedCounter).increment(1); - verify(sqsMessagesDeleteFailedCounter, never()).increment(anyDouble()); + verify(sqsEventProcessor, times(1)).addSqsObject( + any(), + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(buffer), + anyInt(), + isNull()); + verify(sqsWorkerCommon, atLeastOnce()).deleteSqsMessages( + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(sqsClient), + anyList() + ); + verify(sqsMessageDelayTimer, times(1)).record(any(Duration.class)); } + @Test void processSqsMessages_should_not_invoke_processSqsEvent_and_deleteSqsMessages_when_entries_are_empty() throws IOException { - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))) - .thenReturn(ReceiveMessageResponse.builder().messages(Collections.emptyList()).build()); - SqsWorker sqsWorker = createObjectUnderTest(); + when(sqsWorkerCommon.pollSqsMessages( + anyString(), + eq(sqsClient), + any(), + any(), + any() + )).thenReturn(ReceiveMessageResponse.builder().messages(Collections.emptyList()).build().messages()); + int messagesProcessed = sqsWorker.processSqsMessages(); + assertThat(messagesProcessed, equalTo(0)); - verify(sqsEventProcessor, never()).addSqsObject(any(), anyString(), any(), anyInt(), any()); - verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); - verify(sqsMessagesReceivedCounter, never()).increment(anyDouble()); - verify(sqsMessagesDeletedCounter, never()).increment(anyDouble()); + verify(sqsEventProcessor, times(0)).addSqsObject( + any(), + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(buffer), + anyInt(), + isNull()); + verify(sqsWorkerCommon, times(0)).deleteSqsMessages( + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(sqsClient), + anyList() + ); + verify(sqsMessageDelayTimer, times(1)).record(any(Duration.class)); } @@ -208,188 +199,176 @@ void processSqsMessages_should_not_delete_messages_if_acknowledgements_enabled_u AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); when(acknowledgementSetManager.create(any(), any())).thenReturn(acknowledgementSet); when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); - - final Message message = Message.builder() - .messageId("msg-1") - .receiptHandle("rh-1") - .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") - .build(); - - final ReceiveMessageResponse response = ReceiveMessageResponse.builder().messages(message).build(); - 
when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); int messagesProcessed = createObjectUnderTest().processSqsMessages(); assertThat(messagesProcessed, equalTo(1)); - verify(sqsEventProcessor, times(1)).addSqsObject(eq(message), + verify(sqsEventProcessor, times(1)).addSqsObject(any(), eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), eq(buffer), - eq(mockBufferTimeoutMillis), + eq(bufferTimeoutMillis), eq(acknowledgementSet)); - verify(sqsMessagesReceivedCounter).increment(1); - verifyNoInteractions(sqsMessagesDeletedCounter); + verify(sqsWorkerCommon, times(0)).deleteSqsMessages( + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(sqsClient), + anyList() + ); + verify(sqsMessageDelayTimer, times(1)).record(any(Duration.class)); } @Test void acknowledgementsEnabled_and_visibilityDuplicateProtectionEnabled_should_create_ack_sets_and_progress_check() { when(sqsSourceConfig.getAcknowledgements()).thenReturn(true); when(queueConfig.getVisibilityDuplicateProtection()).thenReturn(true); - - SqsWorker worker = new SqsWorker(buffer, acknowledgementSetManager, sqsClient, sqsSourceConfig, queueConfig, pluginMetrics, sqsEventProcessor, backoff); - Message message = Message.builder().messageId("msg-dup").receiptHandle("handle-dup").build(); - ReceiveMessageResponse response = ReceiveMessageResponse.builder().messages(message).build(); - when(sqsClient.receiveMessage((ReceiveMessageRequest) any())).thenReturn(response); - AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); when(acknowledgementSetManager.create(any(), any())).thenReturn(acknowledgementSet); - - int processed = worker.processSqsMessages(); - assertThat(processed, equalTo(1)); - + createObjectUnderTest().processSqsMessages(); verify(acknowledgementSetManager).create(any(), any()); verify(acknowledgementSet).addProgressCheck(any(), any()); } @Test void processSqsMessages_should_return_zero_messages_with_backoff_when_a_SqsException_is_thrown() { - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); + when(sqsWorkerCommon.pollSqsMessages( + anyString(), + eq(sqsClient), + any(), + any(), + any() + )).thenThrow(SqsException.class); final int messagesProcessed = createObjectUnderTest().processSqsMessages(); - verify(backoff).nextDelayMillis(1); + verify(sqsWorkerCommon, times(1)).applyBackoff(); assertThat(messagesProcessed, equalTo(0)); } - @Test - void processSqsMessages_should_throw_when_a_SqsException_is_thrown_with_max_retries() { - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenThrow(SqsException.class); - when(backoff.nextDelayMillis(anyInt())).thenReturn((long) -1); - SqsWorker objectUnderTest = createObjectUnderTest(); - assertThrows(SqsRetriesExhaustedException.class, objectUnderTest::processSqsMessages); - } @Test void processSqsMessages_should_update_visibility_timeout_when_progress_changes() throws IOException { - AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); - when(queueConfig.getVisibilityDuplicateProtection()).thenReturn(true); - when(queueConfig.getVisibilityTimeout()).thenReturn(Duration.ofMillis(1)); - when(acknowledgementSetManager.create(any(), any(Duration.class))).thenReturn(acknowledgementSet); when(sqsSourceConfig.getAcknowledgements()).thenReturn(true); - final Message message = mock(Message.class); - final String testReceiptHandle = UUID.randomUUID().toString(); - when(message.messageId()).thenReturn(testReceiptHandle); - 
when(message.receiptHandle()).thenReturn(testReceiptHandle); - - final ReceiveMessageResponse receiveMessageResponse = mock(ReceiveMessageResponse.class); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveMessageResponse); - when(receiveMessageResponse.messages()).thenReturn(Collections.singletonList(message)); - - final int messagesProcessed = createObjectUnderTest().processSqsMessages(); + when(queueConfig.getVisibilityDuplicateProtection()).thenReturn(true); + when(queueConfig.getVisibilityDuplicateProtectionTimeout()).thenReturn(Duration.ofSeconds(60)); + when(queueConfig.getVisibilityTimeout()).thenReturn(Duration.ofSeconds(30)); + when(queueConfig.getUrl()).thenReturn("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"); + AcknowledgementSet acknowledgementSet = mock(AcknowledgementSet.class); + when(acknowledgementSetManager.create(any(), any(Duration.class))) + .thenReturn(acknowledgementSet); + final String testMessageId = "msg-1"; + final String testReceiptHandle = "rh-1"; + SqsWorker sqsWorker = createObjectUnderTest(); + final int messagesProcessed = sqsWorker.processSqsMessages(); assertThat(messagesProcessed, equalTo(1)); - verify(sqsEventProcessor).addSqsObject(any(), anyString(), any(), anyInt(), any()); + + verify(sqsEventProcessor).addSqsObject( + any(), + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(buffer), + eq(bufferTimeoutMillis), + eq(acknowledgementSet) + ); verify(acknowledgementSetManager).create(any(), any(Duration.class)); - ArgumentCaptor<Consumer<ProgressCheck>> progressConsumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class); - verify(acknowledgementSet).addProgressCheck(progressConsumerArgumentCaptor.capture(), any(Duration.class)); - final Consumer<ProgressCheck> actualConsumer = progressConsumerArgumentCaptor.getValue(); - final ProgressCheck progressCheck = mock(ProgressCheck.class); + @SuppressWarnings("unchecked") + ArgumentCaptor<Consumer<ProgressCheck>> progressConsumerCaptor = ArgumentCaptor.forClass(Consumer.class); + verify(acknowledgementSet).addProgressCheck(progressConsumerCaptor.capture(), any(Duration.class)); + final Consumer<ProgressCheck> actualConsumer = progressConsumerCaptor.getValue(); + ProgressCheck progressCheck = mock(ProgressCheck.class); actualConsumer.accept(progressCheck); - - ArgumentCaptor<ChangeMessageVisibilityRequest> changeMessageVisibilityRequestArgumentCaptor = ArgumentCaptor.forClass(ChangeMessageVisibilityRequest.class); - verify(sqsClient).changeMessageVisibility(changeMessageVisibilityRequestArgumentCaptor.capture()); - ChangeMessageVisibilityRequest actualChangeVisibilityRequest = changeMessageVisibilityRequestArgumentCaptor.getValue(); - assertThat(actualChangeVisibilityRequest.queueUrl(), equalTo(queueConfig.getUrl())); - assertThat(actualChangeVisibilityRequest.receiptHandle(), equalTo(testReceiptHandle)); - verify(sqsMessagesReceivedCounter).increment(1); + verify(sqsWorkerCommon).increaseVisibilityTimeout( + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(sqsClient), + eq(testReceiptHandle), + eq(30), + eq(testMessageId) + ); } @Test - void increaseVisibilityTimeout_doesNothing_whenIsStopped() throws IOException { - when(sqsSourceConfig.getAcknowledgements()).thenReturn(true); - when(queueConfig.getVisibilityDuplicateProtection()).thenReturn(false); - when(queueConfig.getVisibilityTimeout()).thenReturn(Duration.ofSeconds(30)); - AcknowledgementSet mockAcknowledgementSet = mock(AcknowledgementSet.class); - when(acknowledgementSetManager.create(any(), any())).thenReturn(mockAcknowledgementSet); - Message
message = Message.builder() - .messageId(UUID.randomUUID().toString()) - .receiptHandle(UUID.randomUUID().toString()) - .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") - .build(); - ReceiveMessageResponse response = ReceiveMessageResponse.builder() - .messages(message) - .build(); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(response); - SqsWorker sqsWorker = createObjectUnderTest(); - sqsWorker.stop(); - int messagesProcessed = sqsWorker.processSqsMessages(); + void processSqsMessages_should_return_delete_message_entry_when_exception_thrown_and_onErrorOption_is_DELETE_MESSAGES() throws IOException { + when(queueConfig.getOnErrorOption()).thenReturn(OnErrorOption.DELETE_MESSAGES); + doThrow(new RuntimeException("Processing error")) + .when(sqsEventProcessor).addSqsObject(any(), + anyString(), + eq(buffer), + anyInt(), + isNull()); + + when(sqsWorkerCommon.getSqsMessagesFailedCounter()).thenReturn(sqsMessagesFailedCounter); + SqsWorker worker = createObjectUnderTest(); + int messagesProcessed = worker.processSqsMessages(); assertThat(messagesProcessed, equalTo(1)); - verify(sqsEventProcessor, times(1)).addSqsObject(eq(message), + verify(sqsMessagesFailedCounter, times(1)).increment(); + verify(sqsWorkerCommon, atLeastOnce()).applyBackoff(); + verify(sqsWorkerCommon, atLeastOnce()).deleteSqsMessages( eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), - eq(buffer), - eq(mockBufferTimeoutMillis), - eq(mockAcknowledgementSet)); - verify(sqsClient, never()).changeMessageVisibility(any(ChangeMessageVisibilityRequest.class)); - verify(sqsVisibilityTimeoutChangeFailedCount, never()).increment(); + eq(sqsClient), + anyList()); } @Test - void deleteSqsMessages_incrementsFailedCounter_whenDeleteResponseHasFailedDeletes() throws IOException { - final Message message1 = Message.builder() - .messageId(UUID.randomUUID().toString()) - .receiptHandle(UUID.randomUUID().toString()) - .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World 1\"}]}") - .build(); - final Message message2 = Message.builder() - .messageId(UUID.randomUUID().toString()) - .receiptHandle(UUID.randomUUID().toString()) - .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World 2\"}]}") - .build(); - - final ReceiveMessageResponse receiveResponse = ReceiveMessageResponse.builder() - .messages(message1, message2) - .build(); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn(receiveResponse); - - DeleteMessageBatchResultEntry successfulDelete = DeleteMessageBatchResultEntry.builder() - .id(message1.messageId()) - .build(); + void processSqsMessages_should_not_delete_message_entry_when_exception_thrown_and_onErrorOption_is_RETAIN_MESSAGES() throws IOException { + when(queueConfig.getOnErrorOption()).thenReturn(OnErrorOption.RETAIN_MESSAGES); + doThrow(new RuntimeException("Processing error")) + .when(sqsEventProcessor).addSqsObject(any(), + anyString(), + eq(buffer), + anyInt(), + isNull()); + + when(sqsWorkerCommon.getSqsMessagesFailedCounter()).thenReturn(sqsMessagesFailedCounter); + SqsWorker worker = createObjectUnderTest(); + int messagesProcessed = worker.processSqsMessages(); + assertThat(messagesProcessed, equalTo(1)); + verify(sqsMessagesFailedCounter, times(1)).increment(); + verify(sqsWorkerCommon, atLeastOnce()).applyBackoff(); + verify(sqsWorkerCommon, times(0)).deleteSqsMessages( + eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), + eq(sqsClient), + anyList()); + } - 
BatchResultErrorEntry failedDelete = BatchResultErrorEntry.builder() - .id(message2.messageId()) - .code("ReceiptHandleIsInvalid") - .senderFault(true) - .message("Failed to delete message due to invalid receipt handle.") - .build(); + @Test + void stop_should_set_isStopped_and_call_stop_on_sqsWorkerCommon() { + SqsWorker worker = createObjectUnderTest(); + worker.stop(); + verify(sqsWorkerCommon, times(1)).stop(); + } - DeleteMessageBatchResponse deleteResponse = DeleteMessageBatchResponse.builder() - .successful(successfulDelete) - .failed(failedDelete) + @Test + void processSqsMessages_should_record_sqsMessageDelayTimer_when_approximateReceiveCount_less_than_or_equal_to_one() throws IOException { + final long sentTimestampMillis = Instant.now().minusSeconds(5).toEpochMilli(); + final Message message = Message.builder() + .messageId("msg-1") + .receiptHandle("rh-1") + .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") + .attributes(Map.of( + MessageSystemAttributeName.SENT_TIMESTAMP, String.valueOf(sentTimestampMillis), + MessageSystemAttributeName.APPROXIMATE_RECEIVE_COUNT, "1" + )) .build(); + when(sqsWorkerCommon.pollSqsMessages(anyString(), eq(sqsClient), anyInt(), any(), any())) + .thenReturn(Collections.singletonList(message)); - when(sqsClient.deleteMessageBatch(any(DeleteMessageBatchRequest.class))).thenReturn(deleteResponse); - SqsWorker sqsWorker = createObjectUnderTest(); - int messagesProcessed = sqsWorker.processSqsMessages(); - assertThat(messagesProcessed, equalTo(2)); - verify(sqsMessagesReceivedCounter).increment(2); - verify(sqsMessagesDeletedCounter).increment(1); - verify(sqsMessagesDeleteFailedCounter).increment(1); + SqsWorker worker = createObjectUnderTest(); + worker.processSqsMessages(); + ArgumentCaptor<Duration> durationCaptor = ArgumentCaptor.forClass(Duration.class); + verify(sqsMessageDelayTimer).record(durationCaptor.capture()); } + @Test - void processSqsMessages_handlesException_correctly_when_addSqsObject_throwsException() throws IOException { + void processSqsMessages_should_not_record_sqsMessageDelayTimer_when_approximateReceiveCount_greater_than_one() throws IOException { + final long sentTimestampMillis = Instant.now().minusSeconds(5).toEpochMilli(); final Message message = Message.builder() - .messageId(UUID.randomUUID().toString()) - .receiptHandle(UUID.randomUUID().toString()) + .messageId("msg-1") + .receiptHandle("rh-1") .body("{\"Records\":[{\"eventSource\":\"custom\",\"message\":\"Hello World\"}]}") + .attributes(Map.of( + MessageSystemAttributeName.SENT_TIMESTAMP, String.valueOf(sentTimestampMillis), + MessageSystemAttributeName.APPROXIMATE_RECEIVE_COUNT, "2" + )) .build(); - when(sqsClient.receiveMessage(any(ReceiveMessageRequest.class))).thenReturn( - ReceiveMessageResponse.builder().messages(message).build() - ); - doThrow(new RuntimeException("Processing failed")).when(sqsEventProcessor) - .addSqsObject(eq(message), eq("https://sqs.us-east-1.amazonaws.com/123456789012/MyQueue"), - any(), anyInt(), any()); - SqsWorker sqsWorker = createObjectUnderTest(); - int messagesProcessed = sqsWorker.processSqsMessages(); - assertThat(messagesProcessed, equalTo(1)); - verify(sqsMessagesReceivedCounter).increment(1); - verify(sqsMessagesFailedCounter).increment(); - verify(backoff).nextDelayMillis(anyInt()); - verify(sqsClient, never()).deleteMessageBatch(any(DeleteMessageBatchRequest.class)); - verify(sqsMessagesDeletedCounter, never()).increment(anyInt()); + when(sqsWorkerCommon.pollSqsMessages(anyString(), eq(sqsClient),
anyInt(), any(), any())) + .thenReturn(Collections.singletonList(message)); + + SqsWorker worker = createObjectUnderTest(); + worker.processSqsMessages(); + verify(sqsMessageDelayTimer, never()).record(any(Duration.class)); } } \ No newline at end of file diff --git a/release/docker/Dockerfile b/release/docker/Dockerfile index 4052d59909..04d040e353 100644 --- a/release/docker/Dockerfile +++ b/release/docker/Dockerfile @@ -5,7 +5,7 @@ ARG CONFIG_FILEPATH ARG ARCHIVE_FILE ARG ARCHIVE_FILE_UNPACKED -ENV DATA_PREPPER_PATH /usr/share/data-prepper +ENV DATA_PREPPER_PATH=/usr/share/data-prepper ENV ENV_CONFIG_FILEPATH=$CONFIG_FILEPATH ENV ENV_PIPELINE_FILEPATH=$PIPELINE_FILEPATH diff --git a/settings.gradle b/settings.gradle index d2aa09b52c..37a125aaa5 100644 --- a/settings.gradle +++ b/settings.gradle @@ -168,7 +168,7 @@ include 'data-prepper-plugins:aws-sqs-common' include 'data-prepper-plugins:buffer-common' include 'data-prepper-plugins:sqs-source' include 'data-prepper-plugins:sqs-common' -//include 'data-prepper-plugins:cloudwatch-logs' +include 'data-prepper-plugins:cloudwatch-logs' //include 'data-prepper-plugins:http-sink' //include 'data-prepper-plugins:sns-sink' //include 'data-prepper-plugins:prometheus-sink'