diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java index cb7f060976..4bab68b3f4 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/resync/CascadingActionDetector.java @@ -83,9 +83,9 @@ public Map getParentTableMap(StreamPartition streamPartitio /** * Detects if a binlog event contains cascading updates and if detected, creates resync partitions - * @param event event + * @param event binlog event * @param parentTableMap parent table map - * @param tableMetadata table meta data + * @param tableMetadata table metadata */ public void detectCascadingUpdates(Event event, Map parentTableMap, TableMetadata tableMetadata) { final UpdateRowsEventData data = event.getData(); @@ -143,9 +143,9 @@ public void detectCascadingUpdates(Event event, Map parentT /** * Detects if a binlog event contains cascading deletes and if detected, creates resync partitions - * @param event event + * @param event binlog event * @param parentTableMap parent table map - * @param tableMetadata table meta data + * @param tableMetadata table metadata */ public void detectCascadingDeletes(Event event, Map parentTableMap, TableMetadata tableMetadata) { if (parentTableMap.containsKey(tableMetadata.getFullTableName())) { diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java index 2bc21ca786..1612e94ec3 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/BinlogEventListener.java @@ -130,7 +130,8 @@ public BinlogEventListener(final StreamPartition streamPartition, this.dbTableMetadata = dbTableMetadata; this.streamCheckpointManager = new StreamCheckpointManager( streamCheckpointer, sourceConfig.isAcknowledgmentsEnabled(), - acknowledgementSetManager, this::stopClient, sourceConfig.getStreamAcknowledgmentTimeout()); + acknowledgementSetManager, this::stopClient, sourceConfig.getStreamAcknowledgmentTimeout(), + sourceConfig.getEngine(), pluginMetrics); streamCheckpointManager.start(); this.cascadeActionDetector = cascadeActionDetector; @@ -200,7 +201,7 @@ void handleRotateEvent(com.github.shyiko.mysql.binlog.event.Event event) { // Trigger a checkpoint update for this rotate when there're no row mutation events being processed if (streamCheckpointManager.getChangeEventStatuses().isEmpty()) { - ChangeEventStatus changeEventStatus = streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate); + ChangeEventStatus changeEventStatus = streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate, 0); if (isAcknowledgmentsEnabled) { changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.POSITIVE_ACK); } @@ -347,9 +348,10 @@ void handleRowChangeEvent(com.github.shyiko.mysql.binlog.event.Event event, LOG.debug("Current binlog coordinate after receiving a row change event: " + currentBinlogCoordinate); } 
+ final long recordCount = rows.size(); AcknowledgementSet acknowledgementSet = null; if (isAcknowledgmentsEnabled) { - acknowledgementSet = streamCheckpointManager.createAcknowledgmentSet(currentBinlogCoordinate); + acknowledgementSet = streamCheckpointManager.createAcknowledgmentSet(currentBinlogCoordinate, recordCount); } final long bytes = event.toString().getBytes().length; @@ -398,7 +400,7 @@ void handleRowChangeEvent(com.github.shyiko.mysql.binlog.event.Event event, if (isAcknowledgmentsEnabled) { acknowledgementSet.complete(); } else { - streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate); + streamCheckpointManager.saveChangeEventsStatus(currentBinlogCoordinate, recordCount); } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java index f2b70cbe7b..af6ef02362 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/ChangeEventStatus.java @@ -6,11 +6,14 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; public class ChangeEventStatus { private final BinlogCoordinate binlogCoordinate; + private final LogSequenceNumber logSequenceNumber; private final long timestamp; + private final long recordCount; private volatile AcknowledgmentStatus acknowledgmentStatus; public enum AcknowledgmentStatus { @@ -19,9 +22,19 @@ public enum AcknowledgmentStatus { NO_ACK } - public ChangeEventStatus(final BinlogCoordinate binlogCoordinate, final long timestamp) { + public ChangeEventStatus(final BinlogCoordinate binlogCoordinate, final long timestamp, final long recordCount) { this.binlogCoordinate = binlogCoordinate; + this.logSequenceNumber = null; this.timestamp = timestamp; + this.recordCount = recordCount; + acknowledgmentStatus = AcknowledgmentStatus.NO_ACK; + } + + public ChangeEventStatus(final LogSequenceNumber logSequenceNumber, final long timestamp, final long recordCount) { + this.binlogCoordinate = null; + this.logSequenceNumber = logSequenceNumber; + this.timestamp = timestamp; + this.recordCount = recordCount; acknowledgmentStatus = AcknowledgmentStatus.NO_ACK; } @@ -45,7 +58,15 @@ public BinlogCoordinate getBinlogCoordinate() { return binlogCoordinate; } + public LogSequenceNumber getLogSequenceNumber() { + return logSequenceNumber; + } + public long getTimestamp() { return timestamp; } + + public long getRecordCount() { + return recordCount; + } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java index 22935fc6e3..8eb3b9cde9 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClient.java @@ -47,6 +47,7 @@ public LogicalReplicationClient(final ConnectionManager connectionManager, @Override public void connect() { + 
LOG.debug("Start connecting logical replication stream. "); PGReplicationStream stream; try (Connection conn = connectionManager.getConnection()) { PGConnection pgConnection = conn.unwrap(PGConnection.class); @@ -62,6 +63,7 @@ public void connect() { logicalStreamBuilder.withStartPosition(startLsn); } stream = logicalStreamBuilder.start(); + LOG.debug("Logical replication stream started. "); if (eventProcessor != null) { while (!disconnectRequested) { @@ -88,7 +90,8 @@ public void connect() { } stream.close(); - LOG.info("Replication stream closed successfully."); + disconnectRequested = false; + LOG.debug("Replication stream closed successfully."); } catch (Exception e) { LOG.error("Exception while creating Postgres replication stream. ", e); } @@ -97,6 +100,7 @@ public void connect() { @Override public void disconnect() { disconnectRequested = true; + LOG.debug("Requested to disconnect logical replication stream."); } public void setEventProcessor(LogicalReplicationEventProcessor eventProcessor) { diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java index 3d5c1a04b1..a2a9aa1017 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessor.java @@ -10,7 +10,13 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.DistributionSummary; +import io.micrometer.core.instrument.Timer; import org.opensearch.dataprepper.buffer.common.BufferAccumulator; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.event.JacksonEvent; @@ -23,6 +29,7 @@ import org.opensearch.dataprepper.plugins.source.rds.datatype.postgres.ColumnType; import org.opensearch.dataprepper.plugins.source.rds.model.MessageType; import org.opensearch.dataprepper.plugins.source.rds.model.TableMetadata; +import org.postgresql.replication.LogSequenceNumber; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,6 +39,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; public class LogicalReplicationEventProcessor { enum TupleDataType { @@ -63,6 +71,14 @@ public static TupleDataType fromValue(char value) { static final Duration BUFFER_TIMEOUT = Duration.ofSeconds(60); static final int DEFAULT_BUFFER_BATCH_SIZE = 1_000; + static final int NUM_OF_RETRIES = 3; + static final int BACKOFF_IN_MILLIS = 500; + static final String CHANGE_EVENTS_PROCESSED_COUNT = "changeEventsProcessed"; + static final String CHANGE_EVENTS_PROCESSING_ERROR_COUNT = "changeEventsProcessingErrors"; + static final String BYTES_RECEIVED = "bytesReceived"; + static final String BYTES_PROCESSED = "bytesProcessed"; + static final String REPLICATION_LOG_EVENT_PROCESSING_TIME = "replicationLogEntryProcessingTime"; + static final String 
REPLICATION_LOG_PROCESSING_ERROR_COUNT = "replicationLogEntryProcessingErrors"; private final StreamPartition streamPartition; private final RdsSourceConfig sourceConfig; @@ -70,24 +86,57 @@ public static TupleDataType fromValue(char value) { private final Buffer> buffer; private final BufferAccumulator> bufferAccumulator; private final List pipelineEvents; + private final PluginMetrics pluginMetrics; + private final AcknowledgementSetManager acknowledgementSetManager; + private final LogicalReplicationClient logicalReplicationClient; + private final StreamCheckpointer streamCheckpointer; + private final StreamCheckpointManager streamCheckpointManager; + + private final Counter changeEventSuccessCounter; + private final Counter changeEventErrorCounter; + private final DistributionSummary bytesReceivedSummary; + private final DistributionSummary bytesProcessedSummary; + private final Timer eventProcessingTimer; + private final Counter eventProcessingErrorCounter; private long currentLsn; private long currentEventTimestamp; + private long bytesReceived; private Map tableMetadataMap; public LogicalReplicationEventProcessor(final StreamPartition streamPartition, final RdsSourceConfig sourceConfig, final Buffer> buffer, - final String s3Prefix) { + final String s3Prefix, + final PluginMetrics pluginMetrics, + final LogicalReplicationClient logicalReplicationClient, + final StreamCheckpointer streamCheckpointer, + final AcknowledgementSetManager acknowledgementSetManager) { this.streamPartition = streamPartition; this.sourceConfig = sourceConfig; recordConverter = new StreamRecordConverter(s3Prefix, sourceConfig.getPartitionCount()); this.buffer = buffer; bufferAccumulator = BufferAccumulator.create(buffer, DEFAULT_BUFFER_BATCH_SIZE, BUFFER_TIMEOUT); + this.pluginMetrics = pluginMetrics; + this.acknowledgementSetManager = acknowledgementSetManager; + this.logicalReplicationClient = logicalReplicationClient; + this.streamCheckpointer = streamCheckpointer; + streamCheckpointManager = new StreamCheckpointManager( + streamCheckpointer, sourceConfig.isAcknowledgmentsEnabled(), + acknowledgementSetManager, this::stopClient, sourceConfig.getStreamAcknowledgmentTimeout(), + sourceConfig.getEngine(), pluginMetrics); + streamCheckpointManager.start(); tableMetadataMap = new HashMap<>(); pipelineEvents = new ArrayList<>(); + + changeEventSuccessCounter = pluginMetrics.counter(CHANGE_EVENTS_PROCESSED_COUNT); + changeEventErrorCounter = pluginMetrics.counter(CHANGE_EVENTS_PROCESSING_ERROR_COUNT); + bytesReceivedSummary = pluginMetrics.summary(BYTES_RECEIVED); + bytesProcessedSummary = pluginMetrics.summary(BYTES_PROCESSED); + eventProcessingTimer = pluginMetrics.timer(REPLICATION_LOG_EVENT_PROCESSING_TIME); + eventProcessingErrorCounter = pluginMetrics.counter(REPLICATION_LOG_PROCESSING_ERROR_COUNT); } public void process(ByteBuffer msg) { @@ -97,20 +146,36 @@ public void process(ByteBuffer msg) { // If it's INSERT/UPDATE/DELETE, prepare events // If it's a COMMIT, convert all prepared events and send to buffer MessageType messageType = MessageType.from((char) msg.get()); - if (messageType == MessageType.BEGIN) { - processBeginMessage(msg); - } else if (messageType == MessageType.RELATION) { - processRelationMessage(msg); - } else if (messageType == MessageType.INSERT) { - processInsertMessage(msg); - } else if (messageType == MessageType.UPDATE) { - processUpdateMessage(msg); - } else if (messageType == MessageType.DELETE) { - processDeleteMessage(msg); - } else if (messageType == MessageType.COMMIT) { - 
processCommitMessage(msg); - } else { - throw new IllegalArgumentException("Replication message type [" + messageType + "] is not supported. "); + switch (messageType) { + case BEGIN: + handleMessageWithRetries(msg, this::processBeginMessage, messageType); + break; + case RELATION: + handleMessageWithRetries(msg, this::processRelationMessage, messageType); + break; + case INSERT: + handleMessageWithRetries(msg, this::processInsertMessage, messageType); + break; + case UPDATE: + handleMessageWithRetries(msg, this::processUpdateMessage, messageType); + break; + case DELETE: + handleMessageWithRetries(msg, this::processDeleteMessage, messageType); + break; + case COMMIT: + handleMessageWithRetries(msg, this::processCommitMessage, messageType); + break; + default: + throw new IllegalArgumentException("Replication message type [" + messageType + "] is not supported."); + } + } + + public void stopClient() { + try { + logicalReplicationClient.disconnect(); + LOG.info("Logical replication client disconnected."); + } catch (Exception e) { + LOG.error("Logical replication client failed to disconnect.", e); } } @@ -169,15 +234,28 @@ void processCommitMessage(ByteBuffer msg) { throw new RuntimeException("Commit LSN does not match current LSN, skipping"); } - writeToBuffer(bufferAccumulator); + final long recordCount = pipelineEvents.size(); + AcknowledgementSet acknowledgementSet = null; + if (sourceConfig.isAcknowledgmentsEnabled()) { + acknowledgementSet = streamCheckpointManager.createAcknowledgmentSet(LogSequenceNumber.valueOf(currentLsn), recordCount); + } + + writeToBuffer(bufferAccumulator, acknowledgementSet); + bytesProcessedSummary.record(bytesReceived); LOG.debug("Processed a COMMIT message with Flag: {} CommitLsn: {} EndLsn: {} Timestamp: {}", flag, commitLsn, endLsn, epochMicro); + + if (sourceConfig.isAcknowledgmentsEnabled()) { + acknowledgementSet.complete(); + } else { + streamCheckpointManager.saveChangeEventsStatus(LogSequenceNumber.valueOf(currentLsn), recordCount); + } } void processInsertMessage(ByteBuffer msg) { int tableId = msg.getInt(); char n_char = (char) msg.get(); // Skip the 'N' character - final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final TableMetadata tableMetadata = tableMetadataMap.get((long) tableId); final List columnNames = tableMetadata.getColumnNames(); final List primaryKeys = tableMetadata.getPrimaryKeys(); final long eventTimestampMillis = currentEventTimestamp; @@ -189,7 +267,7 @@ void processInsertMessage(ByteBuffer msg) { void processUpdateMessage(ByteBuffer msg) { final int tableId = msg.getInt(); - final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final TableMetadata tableMetadata = tableMetadataMap.get((long) tableId); final List columnNames = tableMetadata.getColumnNames(); final List primaryKeys = tableMetadata.getPrimaryKeys(); final long eventTimestampMillis = currentEventTimestamp; @@ -231,7 +309,7 @@ void processDeleteMessage(ByteBuffer msg) { int tableId = msg.getInt(); char n_char = (char) msg.get(); // Skip the 'N' character - final TableMetadata tableMetadata = tableMetadataMap.get((long)tableId); + final TableMetadata tableMetadata = tableMetadataMap.get((long) tableId); final List columnNames = tableMetadata.getColumnNames(); final List primaryKeys = tableMetadata.getPrimaryKeys(); final long eventTimestampMillis = currentEventTimestamp; @@ -242,6 +320,8 @@ void processDeleteMessage(ByteBuffer msg) { private void doProcess(ByteBuffer msg, List columnNames, TableMetadata tableMetadata, List primaryKeys,
long eventTimestampMillis, OpenSearchBulkActions bulkAction) { + bytesReceived = msg.capacity(); + bytesReceivedSummary.record(bytesReceived); Map rowDataMap = getRowDataMap(msg, columnNames); createPipelineEvent(rowDataMap, tableMetadata, primaryKeys, eventTimestampMillis, bulkAction); @@ -284,9 +364,12 @@ private void createPipelineEvent(Map rowDataMap, TableMetadata t pipelineEvents.add(pipelineEvent); } - private void writeToBuffer(BufferAccumulator> bufferAccumulator) { + private void writeToBuffer(BufferAccumulator> bufferAccumulator, AcknowledgementSet acknowledgementSet) { for (Event pipelineEvent : pipelineEvents) { addToBufferAccumulator(bufferAccumulator, new Record<>(pipelineEvent)); + if (acknowledgementSet != null) { + acknowledgementSet.add(pipelineEvent); + } } flushBufferAccumulator(bufferAccumulator, pipelineEvents.size()); @@ -304,10 +387,12 @@ private void addToBufferAccumulator(final BufferAccumulator> buffe private void flushBufferAccumulator(BufferAccumulator> bufferAccumulator, int eventCount) { try { bufferAccumulator.flush(); + changeEventSuccessCounter.increment(eventCount); } catch (Exception e) { // this will only happen if writing to buffer gets interrupted from shutdown, // otherwise bufferAccumulator will keep retrying with backoff LOG.error("Failed to flush buffer", e); + changeEventErrorCounter.increment(eventCount); } } @@ -333,4 +418,28 @@ private List getPrimaryKeys(String schemaName, String tableName) { return progressState.getPrimaryKeyMap().get(databaseName + "." + schemaName + "." + tableName); } + + private void handleMessageWithRetries(ByteBuffer message, Consumer function, MessageType messageType) { + int retry = 0; + while (retry <= NUM_OF_RETRIES) { + try { + eventProcessingTimer.record(() -> function.accept(message)); + return; + } catch (Exception e) { + LOG.warn("Error when processing change event of type {}, will retry", messageType, e); + applyBackoff(); + } + retry++; + } + LOG.error("Failed to process change event of type {} after {} retries", messageType, NUM_OF_RETRIES); + eventProcessingErrorCounter.increment(); + } + + private void applyBackoff() { + try { + Thread.sleep(BACKOFF_IN_MILLIS); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + } } diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java index 3827f2b822..3880707e21 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManager.java @@ -5,9 +5,13 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; +import io.micrometer.core.instrument.Counter; +import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSet; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -21,6 +25,11 @@ public class StreamCheckpointManager { private static final Logger LOG 
= LoggerFactory.getLogger(StreamCheckpointManager.class); static final int REGULAR_CHECKPOINT_INTERVAL_MILLIS = 60_000; static final int CHANGE_EVENT_COUNT_PER_CHECKPOINT_BATCH = 1000; + static final String POSITIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME = "positiveAcknowledgementSets"; + static final String NEGATIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME = "negativeAcknowledgementSets"; + static final String CHECKPOINT_COUNT = "checkpointCount"; + static final String NO_DATA_EXTEND_LEASE_COUNT = "noDataExtendLeaseCount"; + static final String GIVE_UP_PARTITION_COUNT = "giveupPartitionCount"; private final ConcurrentLinkedQueue changeEventStatuses = new ConcurrentLinkedQueue<>(); private final StreamCheckpointer streamCheckpointer; @@ -29,18 +38,35 @@ public class StreamCheckpointManager { private final boolean isAcknowledgmentEnabled; private final AcknowledgementSetManager acknowledgementSetManager; private final Duration acknowledgmentTimeout; + private final EngineType engineType; + private final PluginMetrics pluginMetrics; + private final Counter positiveAcknowledgementSets; + private final Counter negativeAcknowledgementSets; + private final Counter checkpointCount; + private final Counter noDataExtendLeaseCount; + private final Counter giveupPartitionCount; public StreamCheckpointManager(final StreamCheckpointer streamCheckpointer, final boolean isAcknowledgmentEnabled, final AcknowledgementSetManager acknowledgementSetManager, final Runnable stopStreamRunnable, - final Duration acknowledgmentTimeout) { + final Duration acknowledgmentTimeout, + final EngineType engineType, + final PluginMetrics pluginMetrics) { this.acknowledgementSetManager = acknowledgementSetManager; this.streamCheckpointer = streamCheckpointer; this.isAcknowledgmentEnabled = isAcknowledgmentEnabled; this.stopStreamRunnable = stopStreamRunnable; this.acknowledgmentTimeout = acknowledgmentTimeout; + this.engineType = engineType; + this.pluginMetrics = pluginMetrics; executorService = Executors.newSingleThreadExecutor(); + + this.positiveAcknowledgementSets = pluginMetrics.counter(POSITIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME); + this.negativeAcknowledgementSets = pluginMetrics.counter(NEGATIVE_ACKNOWLEDGEMENT_SET_METRIC_NAME); + this.checkpointCount = pluginMetrics.counter(CHECKPOINT_COUNT); + this.noDataExtendLeaseCount = pluginMetrics.counter(NO_DATA_EXTEND_LEASE_COUNT); + this.giveupPartitionCount = pluginMetrics.counter(GIVE_UP_PARTITION_COUNT); } public void start() { @@ -54,6 +80,7 @@ void runCheckpointing() { try { if (changeEventStatuses.isEmpty()) { LOG.debug("No records processed. Extend the lease on stream partition."); + noDataExtendLeaseCount.increment(); streamCheckpointer.extendLease(); } else { if (isAcknowledgmentEnabled) { @@ -65,13 +92,14 @@ void runCheckpointing() { } if (lastChangeEventStatus != null) { - streamCheckpointer.checkpoint(lastChangeEventStatus.getBinlogCoordinate()); + checkpoint(engineType, lastChangeEventStatus); } // If negative ack is seen, give up partition and exit loop to stop processing stream if (currentChangeEventStatus != null && currentChangeEventStatus.isNegativeAcknowledgment()) { LOG.info("Received negative acknowledgement for change event at {}. 
Will restart from most recent checkpoint", currentChangeEventStatus.getBinlogCoordinate()); streamCheckpointer.giveUpPartition(); + giveupPartitionCount.increment(); break; } } else { @@ -81,10 +109,10 @@ changeEventCount++; // In case the queue is populated faster than it is polled, checkpoint when reaching a certain count if (changeEventCount % CHANGE_EVENT_COUNT_PER_CHECKPOINT_BATCH == 0) { - streamCheckpointer.checkpoint(currentChangeEventStatus.getBinlogCoordinate()); + checkpoint(engineType, currentChangeEventStatus); } } while (!changeEventStatuses.isEmpty()); - streamCheckpointer.checkpoint(currentChangeEventStatus.getBinlogCoordinate()); + checkpoint(engineType, currentChangeEventStatus); } } } catch (Exception e) { @@ -107,25 +135,52 @@ public void stop() { executorService.shutdownNow(); } - public ChangeEventStatus saveChangeEventsStatus(BinlogCoordinate binlogCoordinate) { - final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli()); + public ChangeEventStatus saveChangeEventsStatus(BinlogCoordinate binlogCoordinate, long recordCount) { + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli(), recordCount); changeEventStatuses.add(changeEventStatus); return changeEventStatus; } - public AcknowledgementSet createAcknowledgmentSet(BinlogCoordinate binlogCoordinate) { + public ChangeEventStatus saveChangeEventsStatus(LogSequenceNumber logSequenceNumber, long recordCount) { + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(logSequenceNumber, Instant.now().toEpochMilli(), recordCount); + changeEventStatuses.add(changeEventStatus); + return changeEventStatus; + } + + public AcknowledgementSet createAcknowledgmentSet(BinlogCoordinate binlogCoordinate, long recordCount) { LOG.debug("Create acknowledgment set for events received prior to {}", binlogCoordinate); - final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli()); + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(binlogCoordinate, Instant.now().toEpochMilli(), recordCount); changeEventStatuses.add(changeEventStatus); + return getAcknowledgementSet(changeEventStatus); + } + + public AcknowledgementSet createAcknowledgmentSet(LogSequenceNumber logSequenceNumber, long recordCount) { + LOG.debug("Create acknowledgment set for events received prior to {}", logSequenceNumber); + final ChangeEventStatus changeEventStatus = new ChangeEventStatus(logSequenceNumber, Instant.now().toEpochMilli(), recordCount); + changeEventStatuses.add(changeEventStatus); + return getAcknowledgementSet(changeEventStatus); + } + + private AcknowledgementSet getAcknowledgementSet(ChangeEventStatus changeEventStatus) { return acknowledgementSetManager.create((result) -> { - if (result) { - changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.POSITIVE_ACK); - } else { - changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.NEGATIVE_ACK); - } + if (result) { + positiveAcknowledgementSets.increment(); + changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.POSITIVE_ACK); + } else { + negativeAcknowledgementSets.increment(); + changeEventStatus.setAcknowledgmentStatus(ChangeEventStatus.AcknowledgmentStatus.NEGATIVE_ACK); + } }, acknowledgmentTimeout); } + private void checkpoint(final EngineType engineType, final ChangeEventStatus changeEventStatus) { +
LOG.debug("Checkpoint at {} with record count {}. ", engineType == EngineType.MYSQL ? + changeEventStatus.getBinlogCoordinate() : changeEventStatus.getLogSequenceNumber(), + changeEventStatus.getRecordCount()); + streamCheckpointer.checkpoint(engineType, changeEventStatus); + checkpointCount.increment(); + } + //VisibleForTesting ConcurrentLinkedQueue getChangeEventStatuses() { return changeEventStatuses; diff --git a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java index 1f60f9715f..2875bf5544 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointer.java @@ -8,9 +8,11 @@ import io.micrometer.core.instrument.Counter; import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,7 +40,17 @@ public StreamCheckpointer(final EnhancedSourceCoordinator sourceCoordinator, checkpointCounter = pluginMetrics.counter(CHECKPOINT_COUNT); } - public void checkpoint(final BinlogCoordinate binlogCoordinate) { + public void checkpoint(final EngineType engineType, final ChangeEventStatus changeEventStatus) { + if (engineType == EngineType.MYSQL) { + checkpoint(changeEventStatus.getBinlogCoordinate()); + } else if (engineType == EngineType.POSTGRES) { + checkpoint(changeEventStatus.getLogSequenceNumber()); + } else { + throw new IllegalArgumentException("Unsupported engine type " + engineType); + } + } + + private void checkpoint(final BinlogCoordinate binlogCoordinate) { LOG.debug("Checkpointing stream partition {} with binlog coordinate {}", streamPartition.getPartitionKey(), binlogCoordinate); Optional progressState = streamPartition.getProgressState(); progressState.get().getMySqlStreamState().setCurrentPosition(binlogCoordinate); @@ -46,6 +58,14 @@ public void checkpoint(final BinlogCoordinate binlogCoordinate) { checkpointCounter.increment(); } + private void checkpoint(final LogSequenceNumber logSequenceNumber) { + LOG.debug("Checkpointing stream partition {} with log sequence number {}", streamPartition.getPartitionKey(), logSequenceNumber); + Optional progressState = streamPartition.getProgressState(); + progressState.get().getPostgresStreamState().setCurrentLsn(logSequenceNumber.asString()); + sourceCoordinator.saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); + checkpointCounter.increment(); + } + public void extendLease() { LOG.debug("Extending lease of stream partition {}", streamPartition.getPartitionKey()); sourceCoordinator.saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); diff --git 
a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java index acd8d0535f..7d89855365 100644 --- a/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java +++ b/data-prepper-plugins/rds-source/src/main/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamWorkerTaskRefresher.java @@ -130,8 +130,8 @@ private void refreshTask(RdsSourceConfig sourceConfig) { } else { final LogicalReplicationClient logicalReplicationClient = (LogicalReplicationClient) replicationLogClient; logicalReplicationClient.setEventProcessor(new LogicalReplicationEventProcessor( - streamPartition, sourceConfig, buffer, s3Prefix - )); + streamPartition, sourceConfig, buffer, s3Prefix, pluginMetrics, logicalReplicationClient, + streamCheckpointer, acknowledgementSetManager)); } final StreamWorker streamWorker = StreamWorker.create(sourceCoordinator, replicationLogClient, pluginMetrics); executorService.submit(() -> streamWorker.processStream(streamPartition)); @@ -150,4 +150,3 @@ private DbTableMetadata getDBTableMetadata(final StreamPartition streamPartition return DbTableMetadata.fromMap(globalState.getProgressState().get()); } } - diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java index 6eeedfcd0f..efc831acfd 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/export/DataFileLoaderTest.java @@ -140,7 +140,7 @@ void test_run_success() throws Exception { ParquetReader parquetReader = mock(ParquetReader.class); BufferAccumulator> bufferAccumulator = mock(BufferAccumulator.class); when(builder.build()).thenReturn(parquetReader); - when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), null); + when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), (GenericRecord) null); try (MockedStatic readerMockedStatic = mockStatic(AvroParquetReader.class); MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { @@ -191,7 +191,7 @@ void test_flush_failure_then_error_metric_updated() throws Exception { BufferAccumulator> bufferAccumulator = mock(BufferAccumulator.class); doThrow(new RuntimeException("testing")).when(bufferAccumulator).flush(); when(builder.build()).thenReturn(parquetReader); - when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), null); + when(parquetReader.read()).thenReturn(mock(GenericRecord.class, RETURNS_DEEP_STUBS), (GenericRecord) null); try (MockedStatic readerMockedStatic = mockStatic(AvroParquetReader.class); MockedStatic bufferAccumulatorMockedStatic = mockStatic(BufferAccumulator.class)) { readerMockedStatic.when(() -> AvroParquetReader.builder(any(InputFile.class), any())).thenReturn(builder); diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java 
b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java index 9cd410ee44..45897335b5 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationClientTest.java @@ -33,7 +33,9 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.RETURNS_DEEP_STUBS; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -87,6 +89,92 @@ void test_connect() throws SQLException, InterruptedException { verify(stream).setFlushedLSN(lsn); } + @Test + void test_disconnect() throws SQLException, InterruptedException { + final Connection connection = mock(Connection.class); + final PGConnection pgConnection = mock(PGConnection.class, RETURNS_DEEP_STUBS); + final ChainedLogicalStreamBuilder logicalStreamBuilder = mock(ChainedLogicalStreamBuilder.class); + final PGReplicationStream stream = mock(PGReplicationStream.class); + final ByteBuffer message = ByteBuffer.allocate(0); + final LogSequenceNumber lsn = mock(LogSequenceNumber.class); + + when(connectionManager.getConnection()).thenReturn(connection); + when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection); + when(pgConnection.getReplicationAPI().replicationStream().logical()).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotName(anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotOption(anyString(), anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.start()).thenReturn(stream); + when(stream.readPending()).thenReturn(message).thenReturn(null); + when(stream.getLastReceiveLSN()).thenReturn(lsn); + + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + executorService.submit(() -> logicalReplicationClient.connect()); + + await().atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> verify(eventProcessor).process(message)); + Thread.sleep(20); + verify(stream).setAppliedLSN(lsn); + verify(stream).setFlushedLSN(lsn); + + logicalReplicationClient.disconnect(); + Thread.sleep(20); + verify(stream).close(); + verifyNoMoreInteractions(stream, eventProcessor); + + executorService.shutdownNow(); + } + + @Test + void test_connect_disconnect_cycles() throws SQLException, InterruptedException { + final Connection connection = mock(Connection.class); + final PGConnection pgConnection = mock(PGConnection.class, RETURNS_DEEP_STUBS); + final ChainedLogicalStreamBuilder logicalStreamBuilder = mock(ChainedLogicalStreamBuilder.class); + final PGReplicationStream stream = mock(PGReplicationStream.class); + final ByteBuffer message = ByteBuffer.allocate(0); + final LogSequenceNumber lsn = mock(LogSequenceNumber.class); + + when(connectionManager.getConnection()).thenReturn(connection); + when(connection.unwrap(PGConnection.class)).thenReturn(pgConnection); + when(pgConnection.getReplicationAPI().replicationStream().logical()).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotName(anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.withSlotOption(anyString(), 
anyString())).thenReturn(logicalStreamBuilder); + when(logicalStreamBuilder.start()).thenReturn(stream); + when(stream.readPending()).thenReturn(message).thenReturn(null); + when(stream.getLastReceiveLSN()).thenReturn(lsn); + + // First connect + final ExecutorService executorService = Executors.newSingleThreadExecutor(); + executorService.submit(() -> logicalReplicationClient.connect()); + await().atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> verify(eventProcessor, times(1)).process(message)); + Thread.sleep(20); + verify(stream).setAppliedLSN(lsn); + verify(stream).setFlushedLSN(lsn); + + // First disconnect + logicalReplicationClient.disconnect(); + Thread.sleep(20); + verify(stream).close(); + verifyNoMoreInteractions(stream, eventProcessor); + + // Second connect + when(stream.readPending()).thenReturn(message).thenReturn(null); + executorService.submit(() -> logicalReplicationClient.connect()); + await().atMost(Duration.ofSeconds(1)) + .untilAsserted(() -> verify(eventProcessor, times(2)).process(message)); + Thread.sleep(20); + verify(stream, times(2)).setAppliedLSN(lsn); + verify(stream, times(2)).setFlushedLSN(lsn); + + // Second disconnect + logicalReplicationClient.disconnect(); + Thread.sleep(20); + verify(stream, times(2)).close(); + verifyNoMoreInteractions(stream, eventProcessor); + + executorService.shutdownNow(); + } + private LogicalReplicationClient createObjectUnderTest() { return new LogicalReplicationClient(connectionManager, replicationSlotName, publicationName); } diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java index 31ec9618a2..90e8149319 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/LogicalReplicationEventProcessorTest.java @@ -10,12 +10,15 @@ package org.opensearch.dataprepper.plugins.source.rds.stream; +import io.micrometer.core.instrument.Metrics; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Answers; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.metrics.PluginMetrics; +import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; import org.opensearch.dataprepper.model.buffer.Buffer; import org.opensearch.dataprepper.model.event.Event; import org.opensearch.dataprepper.model.record.Record; @@ -28,9 +31,11 @@ import java.util.UUID; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) class LogicalReplicationEventProcessorTest { @@ -44,6 +49,18 @@ class LogicalReplicationEventProcessorTest { @Mock private Buffer> buffer; + @Mock + private PluginMetrics pluginMetrics; + + @Mock + private LogicalReplicationClient logicalReplicationClient; + + @Mock + private StreamCheckpointer streamCheckpointer; + + @Mock + private AcknowledgementSetManager 
acknowledgementSetManager; + private ByteBuffer message; private String s3Prefix; @@ -56,6 +73,8 @@ class LogicalReplicationEventProcessorTest { void setUp() { s3Prefix = UUID.randomUUID().toString(); random = new Random(); + when(pluginMetrics.timer(anyString())).thenReturn(Metrics.timer("test-timer")); + when(pluginMetrics.counter(anyString())).thenReturn(Metrics.counter("test-counter")); objectUnderTest = spy(createObjectUnderTest()); } @@ -129,8 +148,15 @@ void test_unsupported_message_type_throws_exception() { assertThrows(IllegalArgumentException.class, () -> objectUnderTest.process(message)); } + @Test + void test_stopClient() { + objectUnderTest.stopClient(); + verify(logicalReplicationClient).disconnect(); + } + private LogicalReplicationEventProcessor createObjectUnderTest() { - return new LogicalReplicationEventProcessor(streamPartition, sourceConfig, buffer, s3Prefix); + return new LogicalReplicationEventProcessor(streamPartition, sourceConfig, buffer, s3Prefix, pluginMetrics, + logicalReplicationClient, streamCheckpointer, acknowledgementSetManager); } private void setMessageType(MessageType messageType) { diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java index deddb45e32..1b32639daf 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointManagerTest.java @@ -11,10 +11,14 @@ import org.mockito.Mock; import org.mockito.MockedStatic; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.dataprepper.metrics.PluginMetrics; import org.opensearch.dataprepper.model.acknowledgements.AcknowledgementSetManager; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import java.time.Duration; +import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.Consumer; @@ -42,11 +46,16 @@ class StreamCheckpointManagerTest { @Mock private Runnable stopStreamRunnable; + @Mock + private PluginMetrics pluginMetrics; + private boolean isAcknowledgmentEnabled = false; + private EngineType engineType = EngineType.MYSQL; + private Random random; @BeforeEach void setUp() { - + random = new Random(); } @Test @@ -76,29 +85,65 @@ void test_shutdown() { } @Test - void test_saveChangeEventsStatus() { + void test_saveChangeEventsStatus_mysql() { final BinlogCoordinate binlogCoordinate = mock(BinlogCoordinate.class); + final long recordCount = random.nextLong(); + final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); + + streamCheckpointManager.saveChangeEventsStatus(binlogCoordinate, recordCount); + + assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); + final ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); + assertThat(changeEventStatus.getBinlogCoordinate(), is(binlogCoordinate)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); + } + + @Test + void test_saveChangeEventsStatus_postgres() { + final LogSequenceNumber 
logSequenceNumber = mock(LogSequenceNumber.class); + engineType = EngineType.POSTGRES; + final long recordCount = random.nextLong(); final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); - streamCheckpointManager.saveChangeEventsStatus(binlogCoordinate); + + streamCheckpointManager.saveChangeEventsStatus(logSequenceNumber, recordCount); assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); - assertThat(streamCheckpointManager.getChangeEventStatuses().peek().getBinlogCoordinate(), is(binlogCoordinate)); + final ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); + assertThat(changeEventStatus.getLogSequenceNumber(), is(logSequenceNumber)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); } @Test - void test_createAcknowledgmentSet() { + void test_createAcknowledgmentSet_mysql() { final BinlogCoordinate binlogCoordinate = mock(BinlogCoordinate.class); + final long recordCount = random.nextLong(); final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); - streamCheckpointManager.createAcknowledgmentSet(binlogCoordinate); + streamCheckpointManager.createAcknowledgmentSet(binlogCoordinate, recordCount); assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); assertThat(changeEventStatus.getBinlogCoordinate(), is(binlogCoordinate)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); + verify(acknowledgementSetManager).create(any(Consumer.class), eq(ACK_TIMEOUT)); + } + + @Test + void test_createAcknowledgmentSet_postgres() { + final LogSequenceNumber logSequenceNumber = mock(LogSequenceNumber.class); + engineType = EngineType.POSTGRES; + final long recordCount = random.nextLong(); + final StreamCheckpointManager streamCheckpointManager = createObjectUnderTest(); + streamCheckpointManager.createAcknowledgmentSet(logSequenceNumber, recordCount); + + assertThat(streamCheckpointManager.getChangeEventStatuses().size(), is(1)); + ChangeEventStatus changeEventStatus = streamCheckpointManager.getChangeEventStatuses().peek(); + assertThat(changeEventStatus.getLogSequenceNumber(), is(logSequenceNumber)); + assertThat(changeEventStatus.getRecordCount(), is(recordCount)); verify(acknowledgementSetManager).create(any(Consumer.class), eq(ACK_TIMEOUT)); } private StreamCheckpointManager createObjectUnderTest() { return new StreamCheckpointManager( - streamCheckpointer, isAcknowledgmentEnabled, acknowledgementSetManager, stopStreamRunnable, ACK_TIMEOUT); + streamCheckpointer, isAcknowledgmentEnabled, acknowledgementSetManager, stopStreamRunnable, ACK_TIMEOUT, engineType, pluginMetrics); } -} \ No newline at end of file +} diff --git a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java index 3327e847f5..75f16ac9fc 100644 --- a/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java +++ b/data-prepper-plugins/rds-source/src/test/java/org/opensearch/dataprepper/plugins/source/rds/stream/StreamCheckpointerTest.java @@ -13,10 +13,13 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.dataprepper.metrics.PluginMetrics; import 
org.opensearch.dataprepper.model.source.coordinator.enhanced.EnhancedSourceCoordinator; +import org.opensearch.dataprepper.plugins.source.rds.configuration.EngineType; import org.opensearch.dataprepper.plugins.source.rds.coordination.partition.StreamPartition; import org.opensearch.dataprepper.plugins.source.rds.coordination.state.MySqlStreamState; +import org.opensearch.dataprepper.plugins.source.rds.coordination.state.PostgresStreamState; import org.opensearch.dataprepper.plugins.source.rds.coordination.state.StreamProgressState; import org.opensearch.dataprepper.plugins.source.rds.model.BinlogCoordinate; +import org.postgresql.replication.LogSequenceNumber; import java.util.Optional; @@ -39,12 +42,18 @@ class StreamCheckpointerTest { @Mock private MySqlStreamState mySqlStreamState; + @Mock + private PostgresStreamState postgresStreamState; + @Mock private PluginMetrics pluginMetrics; @Mock private Counter checkpointCounter; + @Mock + private ChangeEventStatus changeEventStatus; + private StreamCheckpointer streamCheckpointer; @@ -55,19 +64,35 @@ void setUp() { } @Test - void test_checkpoint() { + void test_checkpoint_mysql() { final BinlogCoordinate binlogCoordinate = mock(BinlogCoordinate.class); final StreamProgressState streamProgressState = mock(StreamProgressState.class); when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState)); when(streamProgressState.getMySqlStreamState()).thenReturn(mySqlStreamState); + when(changeEventStatus.getBinlogCoordinate()).thenReturn(binlogCoordinate); - streamCheckpointer.checkpoint(binlogCoordinate); + streamCheckpointer.checkpoint(EngineType.MYSQL, changeEventStatus); verify(mySqlStreamState).setCurrentPosition(binlogCoordinate); verify(sourceCoordinator).saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); verify(checkpointCounter).increment(); } + @Test + void test_checkpoint_postgres() { + final LogSequenceNumber logSequenceNumber = mock(LogSequenceNumber.class); + final StreamProgressState streamProgressState = mock(StreamProgressState.class); + when(streamPartition.getProgressState()).thenReturn(Optional.of(streamProgressState)); + when(streamProgressState.getPostgresStreamState()).thenReturn(postgresStreamState); + when(changeEventStatus.getLogSequenceNumber()).thenReturn(logSequenceNumber); + + streamCheckpointer.checkpoint(EngineType.POSTGRES, changeEventStatus); + + verify(postgresStreamState).setCurrentLsn(logSequenceNumber.asString()); + verify(sourceCoordinator).saveProgressStateForPartition(streamPartition, CHECKPOINT_OWNERSHIP_TIMEOUT_INCREASE); + verify(checkpointCounter).increment(); + } + @Test void test_extendLease() { streamCheckpointer.extendLease();
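
Note on the retry handling added in LogicalReplicationEventProcessor above: process() now routes every replication message through handleMessageWithRetries, a bounded-retry-with-fixed-backoff wrapper. The following is a minimal, self-contained Java sketch of that pattern only; the class and method names are illustrative, and the production method additionally times each attempt with eventProcessingTimer and increments eventProcessingErrorCounter once all attempts are exhausted.

import java.nio.ByteBuffer;
import java.util.function.Consumer;

class BoundedRetrySketch {
    static final int NUM_OF_RETRIES = 3;       // same bound as in the diff
    static final int BACKOFF_IN_MILLIS = 500;  // same fixed backoff as in the diff

    static void handleWithRetries(final ByteBuffer message, final Consumer<ByteBuffer> handler) {
        int retry = 0;
        // Initial attempt plus NUM_OF_RETRIES retries: at most 4 attempts in total.
        while (retry <= NUM_OF_RETRIES) {
            try {
                handler.accept(message);
                return;  // success: stop retrying
            } catch (final Exception e) {
                applyBackoff();  // fixed delay before the next attempt
            }
            retry++;
        }
        // All attempts failed; the production code logs the failure and
        // increments an error counter at this point.
    }

    private static void applyBackoff() {
        try {
            Thread.sleep(BACKOFF_IN_MILLIS);
        } catch (final InterruptedException e) {
            Thread.currentThread().interrupt();  // preserve interrupt status, as the diff does
        }
    }
}

A consequence of the retry <= NUM_OF_RETRIES bound is that each message is attempted up to four times before it is counted as a processing error.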