Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQS source plugin implementation #5274

Merged
merged 15 commits into from
Jan 9, 2025
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ private Arn getArn() {
try {
return Arn.fromString(awsStsRoleArn);
} catch (final Exception e) {
throw new IllegalArgumentException(String.format("Invalid ARN format for awsStsRoleArn. Check the format of %s", awsStsRoleArn));
}
throw new IllegalArgumentException(String.format("The value provided for sts_role_arn is not a valid AWS ARN. Provided value: %s", awsStsRoleArn)); }
}

public String getAwsStsRoleArn() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is missing the copyright header. Please check other files as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've checked all of them. Do we need to add the header to the tests as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the tests need these headers as well.


import com.fasterxml.jackson.annotation.JsonProperty;
Expand All @@ -14,10 +19,10 @@
public class QueueConfig {

private static final Integer DEFAULT_MAXIMUM_MESSAGES = null;
private static final Boolean DEFAULT_VISIBILITY_DUPLICATE_PROTECTION = false;
private static final boolean DEFAULT_VISIBILITY_DUPLICATE_PROTECTION = false;
private static final Duration DEFAULT_VISIBILITY_TIMEOUT_SECONDS = null;
private static final Duration DEFAULT_VISIBILITY_DUPLICATE_PROTECTION_TIMEOUT = Duration.ofHours(2);
private static final Duration DEFAULT_WAIT_TIME_SECONDS = Duration.ofSeconds(20);
private static final Duration DEFAULT_WAIT_TIME_SECONDS = null;
private static final Duration DEFAULT_POLL_DELAY_SECONDS = Duration.ofSeconds(0);
static final int DEFAULT_NUMBER_OF_WORKERS = 1;

Expand Down Expand Up @@ -45,7 +50,7 @@ public class QueueConfig {

@JsonProperty("visibility_duplication_protection")
@NotNull
private Boolean visibilityDuplicateProtection = DEFAULT_VISIBILITY_DUPLICATE_PROTECTION;
private boolean visibilityDuplicateProtection = DEFAULT_VISIBILITY_DUPLICATE_PROTECTION;

@JsonProperty("visibility_duplicate_protection_timeout")
@DurationMin(seconds = 30)
Expand Down Expand Up @@ -73,7 +78,7 @@ public Duration getVisibilityTimeout() {
return visibilityTimeout;
}

public Boolean getVisibilityDuplicateProtection() {
public boolean getVisibilityDuplicateProtection() {
return visibilityDuplicateProtection;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,76 +1,51 @@
package org.opensearch.dataprepper.plugins.source.sqs;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file appears to be missing the copyright header.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch. I thought I had already looked over every file — my bad.


/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventMetadata;
import org.opensearch.dataprepper.model.event.JacksonEvent;
import org.opensearch.dataprepper.model.record.Record;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.sqs.model.Message;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.time.Instant;
import java.util.Objects;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeName;

import java.util.Collections;
import java.util.Map;

/**
* Implements the SqsMessageHandler to read and parse SQS messages generically and push to buffer.
*/
public class RawSqsMessageHandler implements SqsMessageHandler {

private static final Logger LOG = LoggerFactory.getLogger(RawSqsMessageHandler.class);
private static final ObjectMapper objectMapper = new ObjectMapper();

/**
* Processes the SQS message, attempting to parse it as JSON, and adds it to the buffer.
*
* @param message - the SQS message for processing
* @param url - the SQS queue url
* @param bufferAccumulator - the buffer accumulator
* @param acknowledgementSet - the acknowledgement set for end-to-end acknowledgements
*/
@Override
public void handleMessage(final Message message,
final String url,
final BufferAccumulator<Record<Event>> bufferAccumulator,
final String url,
final Buffer<Record<Event>> buffer,
final int bufferTimeoutMillis,
final AcknowledgementSet acknowledgementSet) {
try {
ObjectNode dataNode = objectMapper.createObjectNode();
dataNode.set("message", parseMessageBody(message.body()));
dataNode.put("queueUrl", url);

Instant now = Instant.now();
int unixTimestamp = (int) now.getEpochSecond();
dataNode.put("sentTimestamp", unixTimestamp);

final Record<Event> event = new Record<Event>(JacksonEvent.builder()
.withEventType("sqs-event")
.withData(dataNode)
.build());

if (Objects.nonNull(acknowledgementSet)) {
acknowledgementSet.add(event.getData());
final Map<MessageSystemAttributeName, String> systemAttributes = message.attributes();
final Map<String, MessageAttributeValue> customAttributes = message.messageAttributes();
final Event event = JacksonEvent.builder()
.withEventType("DOCUMENT")
.withData(Collections.singletonMap("message", message.body()))
.build();
final EventMetadata eventMetadata = event.getMetadata();
eventMetadata.setAttribute("url", url);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use metadata keys that match the terminology from SQS. So keep this as queueUrl as you had previously.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea

final String sentTimestamp = systemAttributes.get(MessageSystemAttributeName.SENT_TIMESTAMP);
eventMetadata.setAttribute("SentTimestamp", sentTimestamp);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency with other Data Prepper metadata, we should probably use lowerCamelCase here - sentTimestamp.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missed that, thanks

for (Map.Entry<String, MessageAttributeValue> entry : customAttributes.entrySet()) {
eventMetadata.setAttribute(entry.getKey(), entry.getValue().stringValue());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use this for SentTimestamp as well?

I think it may make sense to lowercase the first letter of the key for consistency with other Data Prepper conventions. That way we'd have keys like senderId and messageDeduplicationId.

@kkondaka , Any thoughts on this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dlvenable I agree. Let's make it consistent with lowercase "s"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean lowercasing the first letter of all custom keys? I think this may confuse customers, since they specify their own custom keys when sending the SQS message. They'll probably expect the exact same key name when referencing it later in the pipeline.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is what we mean. I think this may help with consistency.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, changed

}

bufferAccumulator.add(new Record<>(event.getData()));

if (acknowledgementSet != null) {
acknowledgementSet.add(event);
}
buffer.write(new Record<>(event), bufferTimeoutMillis);
} catch (Exception e) {
LOG.error("Error processing SQS message: {}", e.getMessage(), e);
throw new RuntimeException(e);
}
}

JsonNode parseMessageBody(String messageBody) {
try {
return objectMapper.readTree(messageBody);
} catch (Exception e) {
return objectMapper.getNodeFactory().textNode(messageBody);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@

package org.opensearch.dataprepper.plugins.source.sqs;

import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;

import software.amazon.awssdk.services.sqs.model.Message;

import java.io.IOException;

public class SqsEventProcessor {
private final SqsMessageHandler sqsMessageHandler;
SqsEventProcessor(final SqsMessageHandler sqsMessageHandler) {
Expand All @@ -22,9 +20,10 @@ public class SqsEventProcessor {

void addSqsObject(final Message message,
final String url,
final BufferAccumulator<Record<Event>> bufferAccumulator,
final Buffer<Record<Event>> buffer,
final int bufferTimeoutmillis,
final AcknowledgementSet acknowledgementSet) throws IOException {
sqsMessageHandler.handleMessage(message, url, bufferAccumulator, acknowledgementSet);
sqsMessageHandler.handleMessage(message, url, buffer, bufferTimeoutmillis, acknowledgementSet);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
*/
package org.opensearch.dataprepper.plugins.source.sqs;

import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet;
import org.opensearch.dataprepper.model.buffer.Buffer;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.record.Record;
import software.amazon.awssdk.services.sqs.model.Message;
Expand All @@ -14,6 +14,7 @@
public interface SqsMessageHandler {
void handleMessage(final Message message,
final String url,
final BufferAccumulator<Record<Event>> bufferAccumulator,
final Buffer<Record<Event>> buffer,
final int bufferTimeoutMillis,
final AcknowledgementSet acknowledgementSet) throws IOException ;
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,11 @@ public void start() {

sqsSourceConfig.getQueues().forEach(queueConfig -> {
String queueUrl = queueConfig.getUrl();
String queueName = queueUrl.substring(queueUrl.lastIndexOf('/') + 1);

int numWorkers = queueConfig.getNumWorkers();
ExecutorService executorService = Executors.newFixedThreadPool(
numWorkers, BackgroundThreadFactory.defaultExecutorThreadFactory("sqs-source-new-" + queueUrl));
numWorkers, BackgroundThreadFactory.defaultExecutorThreadFactory("sqs-source" + queueName));
allSqsUrlExecutorServices.add(executorService);
List<SqsWorker> workers = IntStream.range(0, numWorkers)
.mapToObj(i -> new SqsWorker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.dataprepper.plugins.source.sqs;

import com.fasterxml.jackson.annotation.JsonProperty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.Message;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse;
import software.amazon.awssdk.services.sqs.model.SqsException;
import software.amazon.awssdk.services.sts.model.StsException;
import org.opensearch.dataprepper.buffer.common.BufferAccumulator;
import org.opensearch.dataprepper.model.buffer.Buffer;

import java.time.Duration;
Expand Down Expand Up @@ -61,7 +61,8 @@ public class SqsWorker implements Runnable {
private final boolean endToEndAcknowledgementsEnabled;
private final AcknowledgementSetManager acknowledgementSetManager;
private volatile boolean isStopped = false;
private final BufferAccumulator<Record<Event>> bufferAccumulator;
private final Buffer<Record<Event>> buffer;
private final int bufferTimeoutMillis;
private Map<Message, Integer> messageVisibilityTimesMap;

public SqsWorker(final Buffer<Record<Event>> buffer,
Expand All @@ -79,12 +80,11 @@ public SqsWorker(final Buffer<Record<Event>> buffer,
this.acknowledgementSetManager = acknowledgementSetManager;
this.standardBackoff = backoff;
this.endToEndAcknowledgementsEnabled = sqsSourceConfig.getAcknowledgements();
this.bufferAccumulator = BufferAccumulator.create(buffer, sqsSourceConfig.getNumberOfRecordsToAccumulate(), sqsSourceConfig.getBufferTimeout());
this.buffer = buffer;
this.bufferTimeoutMillis = (int) sqsSourceConfig.getBufferTimeout().toMillis();

messageVisibilityTimesMap = new HashMap<>();

failedAttemptCount = 0;

sqsMessagesReceivedCounter = pluginMetrics.counter(SQS_MESSAGES_RECEIVED_METRIC_NAME);
sqsMessagesDeletedCounter = pluginMetrics.counter(SQS_MESSAGES_DELETED_METRIC_NAME);
sqsMessagesFailedCounter = pluginMetrics.counter(SQS_MESSAGES_FAILED_METRIC_NAME);
Expand Down Expand Up @@ -131,9 +131,11 @@ int processSqsMessages() {
private List<Message> getMessagesFromSqs() {
try {
final ReceiveMessageRequest request = createReceiveMessageRequest();
final List<Message> messages = sqsClient.receiveMessage(request).messages();
final ReceiveMessageResponse response = sqsClient.receiveMessage(request);
List<Message> messages = response.messages();
failedAttemptCount = 0;
return messages;

} catch (final SqsException | StsException e) {
LOG.error("Error reading from SQS: {}. Retrying with exponential backoff.", e.getMessage());
applyBackoff();
Expand All @@ -160,16 +162,18 @@ private void applyBackoff() {
private ReceiveMessageRequest createReceiveMessageRequest() {
ReceiveMessageRequest.Builder requestBuilder = ReceiveMessageRequest.builder()
.queueUrl(queueConfig.getUrl())
.waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds());
.attributeNamesWithStrings("All")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make this configurable because it probably affects performance. But, I think that would be appropriate for a follow-on PR. Using All as the default makes sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can let the customer specify in the config which metadata attributes they want to ingest. That definitely has performance and cost implications.

.messageAttributeNames("All");

if (queueConfig.getWaitTime() != null) {
requestBuilder.waitTimeSeconds((int) queueConfig.getWaitTime().getSeconds());
}
if (queueConfig.getMaximumMessages() != null) {
requestBuilder.maxNumberOfMessages(queueConfig.getMaximumMessages());
}

if (queueConfig.getVisibilityTimeout() != null) {
requestBuilder.visibilityTimeout((int) queueConfig.getVisibilityTimeout().getSeconds());
}

return requestBuilder.build();
}

Expand Down Expand Up @@ -244,14 +248,6 @@ private List<DeleteMessageBatchRequestEntry> processSqsEvents(final List<Message
}
}

if (!messages.isEmpty()) {
try {
bufferAccumulator.flush();
} catch (final Exception e) {
throw new RuntimeException(e);
}
}

return deleteMessageBatchRequestEntryCollection;
}

Expand All @@ -260,7 +256,7 @@ private Optional<DeleteMessageBatchRequestEntry> processSqsObject(
final Message message,
final AcknowledgementSet acknowledgementSet) {
try {
sqsEventProcessor.addSqsObject(message, queueConfig.getUrl(), bufferAccumulator, acknowledgementSet);
sqsEventProcessor.addSqsObject(message, queueConfig.getUrl(), buffer, bufferTimeoutMillis, acknowledgementSet);
return Optional.of(buildDeleteMessageBatchRequestEntry(message));
} catch (final Exception e) {
sqsMessagesFailedCounter.increment();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ void validateStsRoleArn_with_invalid_format_throws_exception() throws NoSuchFiel

try (final MockedStatic<Arn> arnMockedStatic = mockStatic(Arn.class)) {
arnMockedStatic.when(() -> Arn.fromString(invalidFormatArn))
.thenThrow(new IllegalArgumentException("Invalid ARN format for awsStsRoleArn. Check the format of " + invalidFormatArn));
.thenThrow(new IllegalArgumentException("The value provided for sts_role_arn is not a valid AWS ARN. Provided value: " + invalidFormatArn));

IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> {
awsAuthenticationOptions.validateStsRoleArn();
});
assertThat(exception.getMessage(), equalTo("Invalid ARN format for awsStsRoleArn. Check the format of " + invalidFormatArn));
assertThat(exception.getMessage(), equalTo("The value provided for sts_role_arn is not a valid AWS ARN. Provided value: " + invalidFormatArn));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ void testDefaultValues() {
assertFalse(queueConfig.getVisibilityDuplicateProtection(), "Visibility duplicate protection should default to false");
assertEquals(Duration.ofHours(2), queueConfig.getVisibilityDuplicateProtectionTimeout(),
"Visibility duplicate protection timeout should default to 2 hours");
assertEquals(Duration.ofSeconds(20), queueConfig.getWaitTime(), "Wait time should default to 20 seconds");
assertNull(queueConfig.getWaitTime(), "Wait time should default to null");
}
}
Loading
Loading