1 change: 1 addition & 0 deletions README.md
@@ -69,6 +69,7 @@ Full list of options in `config.json`:
| use_secondary | Boolean | No | Use a database replica for `INCREMENTAL` and `FULL_TABLE` replication (Default: False) |
| secondary_host | String | No | PostgreSQL Replica host (required if `use_secondary` is `True`) |
| secondary_port | Integer | No | PostgreSQL Replica port (required if `use_secondary` is `True`) |
| wal2json_message_format | Integer | No | Which `wal2json` message format to use (Default: 1) |
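
For illustration, a minimal `config.json` that opts into the newer message format might look like the following sketch (connection values are placeholders; when `wal2json_message_format` is omitted it falls back to 1):

```json
{
  "host": "localhost",
  "port": 5432,
  "user": "replication_user",
  "password": "secret",
  "dbname": "my_database",
  "wal2json_message_format": 2
}
```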


### Run the tap in Discovery Mode
1 change: 1 addition & 0 deletions tap_postgres/__init__.py
@@ -407,6 +407,7 @@ def main_impl():
        'break_at_end_lsn': args.config.get('break_at_end_lsn', True),
        'logical_poll_total_seconds': float(args.config.get('logical_poll_total_seconds', 0)),
        'use_secondary': args.config.get('use_secondary', False),
        'wal2json_message_format': args.config.get('wal2json_message_format', 1)
    }

    if conn_config['use_secondary']:
163 changes: 135 additions & 28 deletions tap_postgres/sync_strategies/logical_replication.py
@@ -377,18 +377,24 @@ def row_to_singer_message(stream, row, version, columns, time_extracted, md_map,
                                time_extracted=time_extracted)


def check_for_new_columns(columns, target_stream, conn_info):
    diff = set(columns).difference(target_stream['schema']['properties'].keys())

    if diff:
        LOGGER.info('Detected new columns "%s", refreshing schema of stream %s', diff, target_stream['stream'])
        # encountered a column that is not in the schema
        # refresh the stream schema and metadata by running discovery
        refresh_streams_schema(conn_info, [target_stream])

        # add the automatic properties back to the stream
        add_automatic_properties(target_stream, conn_info.get('debug_lsn', False))

        # publish new schema
        sync_common.send_schema_message(target_stream, ['lsn'])


# pylint: disable=too-many-locals
def consume_message_format_1(payload, conn_info, streams_lookup, state, time_extracted, lsn):
    tap_stream_id = post_db.compute_tap_stream_id(payload['schema'], payload['table'])
    if streams_lookup.get(tap_stream_id) is None:
        return state
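
    # For reference, a single format-1 change record (what `payload` holds at
    # this point) looks roughly like this; the field values are illustrative only:
    #   {"kind": "insert", "schema": "public", "table": "users",
    #    "columnnames": ["id", "name"], "columntypes": ["integer", "text"],
    #    "columnvalues": [1, "alice"]}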
@@ -400,22 +406,8 @@ def consume_message(streams, state, msg, time_extracted, conn_info):

    # Get the additional fields in payload that are not in schema properties:
    # only inserts and updates have the list of columns that can be used to detect any difference in columns
    if payload['kind'] in {'insert', 'update'}:
        check_for_new_columns(payload['columnnames'], target_stream, conn_info)

    stream_version = get_stream_version(target_stream['tap_stream_id'], state)
    stream_md_map = metadata.to_map(target_stream['metadata'])
@@ -476,6 +468,109 @@
    return state


def consume_message_format_2(payload, conn_info, streams_lookup, state, time_extracted, lsn):
    ## Action Types:
    # I = Insert
    # U = Update
    # D = Delete
    # B = Begin Transaction
    # C = Commit Transaction
    # M = Message
    # T = Truncate
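    #
    # With format-version 2, each change arrives as a single JSON document,
    # e.g. for an insert (illustrative values only):
    #   {"action": "I", "schema": "public", "table": "users",
    #    "columns": [{"name": "id", "type": "integer", "value": 1},
    #                {"name": "name", "type": "text", "value": "alice"}],
    #    "timestamp": "2023-01-01 00:00:00.000000+00"}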
    action = payload['action']
    if action not in {'U', 'I', 'D'}:
        raise UnsupportedPayloadKindError(f"unrecognized replication operation: {action}")

    tap_stream_id = post_db.compute_tap_stream_id(payload['schema'], payload['table'])
    if streams_lookup.get(tap_stream_id) is not None:
        target_stream = streams_lookup[tap_stream_id]

        # Get the additional fields in payload that are not in schema properties:
        # only inserts and updates have the list of columns that can be used to detect any difference in columns
        if payload['action'] in {'I', 'U'}:
            check_for_new_columns({column['name'] for column in payload['columns']}, target_stream, conn_info)

        stream_version = get_stream_version(target_stream['tap_stream_id'], state)
        stream_md_map = metadata.to_map(target_stream['metadata'])

        desired_columns = {
            col for col in target_stream['schema']['properties'].keys()
            if sync_common.should_sync_column(stream_md_map, col)
        }

        col_names = []
        col_vals = []
        if payload['action'] in ['I', 'U']:
            for column in payload['columns']:
                if column['name'] in desired_columns:
                    col_names.append(column['name'])
                    col_vals.append(column['value'])

            col_names = col_names + ['_sdc_deleted_at']
            col_vals = col_vals + [None]

            if conn_info.get('debug_lsn'):
                col_names = col_names + ['_sdc_lsn']
                col_vals = col_vals + [str(lsn)]

        elif payload['action'] == 'D':
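            # For deletes, a format-2 payload carries the replica-identity
            # columns instead of the full column list (illustrative values only):
            #   {"action": "D", "schema": "public", "table": "users",
            #    "identity": [{"name": "id", "type": "integer", "value": 1}],
            #    "timestamp": "2023-01-01 00:00:00.000000+00"}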
            for column in payload['identity']:
                if column['name'] in desired_columns:
                    col_names.append(column['name'])
                    col_vals.append(column['value'])

            col_names = col_names + ['_sdc_deleted_at']
            col_vals = col_vals + [singer.utils.strftime(singer.utils.strptime_to_utc(payload['timestamp']))]

            if conn_info.get('debug_lsn'):
                col_names = col_names + ['_sdc_lsn']
                col_vals = col_vals + [str(lsn)]

        # Write 1 record to match the API of V1
        record_message = row_to_singer_message(
            target_stream,
            col_vals,
            stream_version,
            col_names,
            time_extracted,
            stream_md_map,
            conn_info,
        )

        singer.write_message(record_message)
        state = singer.write_bookmark(state, target_stream['tap_stream_id'], 'lsn', lsn)

    return state


def consume_message(streams, state, msg, time_extracted, conn_info):
    # Strip leading comma generated by write-in-chunks and parse valid JSON
    try:
        payload = json.loads(msg.payload.lstrip(','))
    except Exception:
        return state

    lsn = msg.data_start

    streams_lookup = {s['tap_stream_id']: s for s in streams}

    message_format = conn_info['wal2json_message_format']
    if message_format == 1:
        state = consume_message_format_1(payload, conn_info, streams_lookup, state, time_extracted, lsn)
    elif message_format == 2:
        state = consume_message_format_2(payload, conn_info, streams_lookup, state, time_extracted, lsn)
    else:
        raise Exception(f"Unknown wal2json message format version: {message_format}")

    return state


def generate_replication_slot_name(dbname, tap_id=None, prefix='pipelinewise'):
    """Generate replication slot name with

@@ -591,14 +686,26 @@ def sync_tables(conn_info, logical_streams, state, end_lsn, state_file):
                            int_to_lsn(end_lsn),
                            slot)
                # psycopg2 2.8.4 will send a keep-alive message to postgres every status_interval
                options = {
                    'add-tables': streams_to_wal2json_tables(logical_streams),
                    'include-timestamp': True,
                    'include-types': False,
                }
                if conn_info['wal2json_message_format'] == 1:
                    options.update({'write-in-chunks': 1})
                else:
                    options.update(
                        {
                            'format-version': conn_info['wal2json_message_format'],
                            'include-transaction': False,
                            'actions': 'insert,update,delete',
                        }
                    )
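                # Only format 1 supports 'write-in-chunks'; format 2 already emits
                # one JSON document per change. Restricting 'actions' to
                # insert/update/delete means begin/commit/message/truncate entries
                # should never reach consume_message_format_2 (this assumes a
                # wal2json build that supports format-version 2, i.e. wal2json 2.x).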
                cur.start_replication(slot_name=slot,
                                      decode=True,
                                      start_lsn=start_lsn,
                                      status_interval=poll_interval,
                                      options=options)

            except psycopg2.ProgrammingError as ex:
                raise Exception(f"Unable to start replication with logical replication (slot {ex})") from ex
10 changes: 5 additions & 5 deletions tests/test_full_table_interruption.py
@@ -48,7 +48,7 @@ def do_not_dump_catalog(catalog):
tap_postgres.dump_catalog = do_not_dump_catalog
full_table.UPDATE_BOOKMARK_PERIOD = 1

@pytest.mark.parametrize('use_secondary,message_format', [(False, 1), (True, 2)])
@unittest.mock.patch('psycopg2.connect', wraps=psycopg2.connect)
class TestLogicalInterruption:
    maxDiff = None
@@ -67,11 +67,11 @@ def setup_method(self):
        global CAUGHT_MESSAGES
        CAUGHT_MESSAGES.clear()

    def test_catalog(self, mock_connect, use_secondary, message_format):
        singer.write_message = singer_write_message_no_cow
        pg_common.write_schema_message = singer_write_message_ok

        conn_config = get_test_connection_config(use_secondary=use_secondary, message_format=message_format)
        streams = tap_postgres.do_discovery(conn_config)

        # Assert that we connected to the correct database
@@ -115,7 +115,7 @@ def test_catalog(self, mock_connect, use_secondary):
        # the initial phase of cow's logical replication will be a full table.
        # it will sync the first record and then blow up on the 2nd record
        try:
            tap_postgres.do_sync(conn_config, {'streams': streams}, None, state)
        except Exception:
            blew_up_on_cow = True

@@ -171,7 +171,7 @@ def test_catalog(self, mock_connect, use_secondary):
        global COW_RECORD_COUNT
        COW_RECORD_COUNT = 0
        CAUGHT_MESSAGES.clear()
        tap_postgres.do_sync(conn_config, {'streams': streams}, None, old_state)

        mock_connect.assert_called_with(**expected_connection)
        mock_connect.reset_mock()