diff --git a/python/migrate.py b/python/migrate.py
index d44990ca6..580ff4d18 100644
--- a/python/migrate.py
+++ b/python/migrate.py
@@ -1,11 +1,11 @@
 import base64
 import csv
 import datetime
+import hashlib
 import io
 import json
 import os
 import re
-import threading
 from decimal import Decimal
 from typing import Any, Dict, List
 
@@ -41,6 +41,32 @@ class Constants:
     USERNAME = "username"
 
 
+def _get_query_hash(
+    query: str, config: mgp.Map, params: mgp.Nullable[mgp.Any] = None
+) -> str:
+    """
+    Create a hash from query, config, and params to use as a cache key.
+
+    :param query: The query string (or table name, endpoint, file path, etc.)
+    :param config: Configuration map
+    :param params: Optional query parameters
+    """
+    config_dict = dict(config)
+    config_str = json.dumps(config_dict, sort_keys=True, default=str)
+
+    params_str = ""
+    if params is not None:
+        if isinstance(params, dict):
+            params_str = json.dumps(params, sort_keys=True, default=str)
+        elif isinstance(params, (list, tuple)):
+            params_str = json.dumps(list(params), sort_keys=False, default=str)
+        else:
+            params_str = str(params)
+
+    hash_input = f"{query}|{config_str}|{params_str}"
+    return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()
+
+
 # MYSQL
 mysql_dict = {}
 
@@ -62,23 +88,25 @@ def init_migrate_mysql(
     if _query_is_table(table_or_sql):
         table_or_sql = f"SELECT * FROM {table_or_sql};"
 
-    thread_id = threading.get_native_id()
-    if thread_id not in mysql_dict:
-        mysql_dict[thread_id] = {}
+    query_hash = _get_query_hash(table_or_sql, config, params)
+
+    # check if query is already running
+    if query_hash in mysql_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
 
-    if Constants.CURSOR not in mysql_dict[thread_id]:
-        mysql_dict[thread_id][Constants.CURSOR] = None
+    mysql_dict[query_hash] = {}
 
-    if mysql_dict[thread_id][Constants.CURSOR] is None:
-        connection = mysql_connector.connect(**config)
-        cursor = connection.cursor()
-        cursor.execute(table_or_sql, params=params)
+    connection = mysql_connector.connect(**config)
+    cursor = connection.cursor()
+    cursor.execute(table_or_sql, params=params)
 
-        mysql_dict[thread_id][Constants.CONNECTION] = connection
-        mysql_dict[thread_id][Constants.CURSOR] = cursor
-        mysql_dict[thread_id][Constants.COLUMN_NAMES] = [
-            column[Constants.I_COLUMN_NAME] for column in cursor.description
-        ]
+    mysql_dict[query_hash][Constants.CONNECTION] = connection
+    mysql_dict[query_hash][Constants.CURSOR] = cursor
+    mysql_dict[query_hash][Constants.COLUMN_NAMES] = [
+        column[Constants.I_COLUMN_NAME] for column in cursor.description
+    ]
 
 
 def mysql(
@@ -105,24 +133,43 @@ def mysql(
     """
     global mysql_dict
 
-    thread_id = threading.get_native_id()
-    cursor = mysql_dict[thread_id][Constants.CURSOR]
-    column_names = mysql_dict[thread_id][Constants.COLUMN_NAMES]
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    if _query_is_table(table_or_sql):
+        table_or_sql = f"SELECT * FROM {table_or_sql};"
+
+    query_hash = _get_query_hash(table_or_sql, config, params)
+    cursor = mysql_dict[query_hash][Constants.CURSOR]
+    column_names = mysql_dict[query_hash][Constants.COLUMN_NAMES]
 
     rows = cursor.fetchmany(Constants.BATCH_SIZE)
-    return [mgp.Record(row=_name_row_cells_mysql(row, column_names)) for row in rows]
+    result = [mgp.Record(row=_name_row_cells_mysql(row, column_names)) for row in rows]
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not result:
+        _cleanup_mysql_by_hash(query_hash)
 
-def cleanup_migrate_mysql():
+    return result
+
+
+def _cleanup_mysql_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global mysql_dict
-    thread_id = threading.get_native_id()
-    mysql_dict[thread_id][Constants.CURSOR] = None
-    mysql_dict[thread_id][Constants.CONNECTION].commit()
-    mysql_dict[thread_id][Constants.CONNECTION].close()
-    mysql_dict[thread_id][Constants.CONNECTION] = None
-    mysql_dict[thread_id][Constants.COLUMN_NAMES] = None
+    if query_hash in mysql_dict:
+        mysql_dict[query_hash][Constants.CURSOR] = None
+        mysql_dict[query_hash][Constants.CONNECTION].commit()
+        mysql_dict[query_hash][Constants.CONNECTION].close()
+        mysql_dict[query_hash][Constants.CONNECTION] = None
+        mysql_dict[query_hash][Constants.COLUMN_NAMES] = None
+        mysql_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_mysql():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(mysql, init_migrate_mysql, cleanup_migrate_mysql)
@@ -151,23 +198,25 @@ def init_migrate_sql_server(
     if _query_is_table(table_or_sql):
         table_or_sql = f"SELECT * FROM {table_or_sql};"
 
-    thread_id = threading.get_native_id()
-    if thread_id not in sql_server_dict:
-        sql_server_dict[thread_id] = {}
+    query_hash = _get_query_hash(table_or_sql, config, params)
 
-    if Constants.CURSOR not in sql_server_dict[thread_id]:
-        sql_server_dict[thread_id][Constants.CURSOR] = None
+    # check if query is already running
+    if query_hash in sql_server_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
 
-    if sql_server_dict[thread_id][Constants.CURSOR] is None:
-        connection = pyodbc.connect(**config)
-        cursor = connection.cursor()
-        cursor.execute(table_or_sql, *params)
+    sql_server_dict[query_hash] = {}
 
-        sql_server_dict[thread_id][Constants.CONNECTION] = connection
-        sql_server_dict[thread_id][Constants.CURSOR] = cursor
-        sql_server_dict[thread_id][Constants.COLUMN_NAMES] = [
-            column[Constants.I_COLUMN_NAME] for column in cursor.description
-        ]
+    connection = pyodbc.connect(**config)
+    cursor = connection.cursor()
+    cursor.execute(table_or_sql, *params)
+
+    sql_server_dict[query_hash][Constants.CONNECTION] = connection
+    sql_server_dict[query_hash][Constants.CURSOR] = cursor
+    sql_server_dict[query_hash][Constants.COLUMN_NAMES] = [
+        column[Constants.I_COLUMN_NAME] for column in cursor.description
+    ]
 
 
 def sql_server(
@@ -193,23 +242,45 @@ def sql_server(
     """
     global sql_server_dict
 
-    thread_id = threading.get_native_id()
-    cursor = sql_server_dict[thread_id][Constants.CURSOR]
-    column_names = sql_server_dict[thread_id][Constants.COLUMN_NAMES]
+    if not params:
+        params = []
+
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    if _query_is_table(table_or_sql):
+        table_or_sql = f"SELECT * FROM {table_or_sql};"
+
+    query_hash = _get_query_hash(table_or_sql, config, params)
+    cursor = sql_server_dict[query_hash][Constants.CURSOR]
+    column_names = sql_server_dict[query_hash][Constants.COLUMN_NAMES]
 
     rows = cursor.fetchmany(Constants.BATCH_SIZE)
-    return [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
+    result = [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not result:
+        _cleanup_sql_server_by_hash(query_hash)
 
-def cleanup_migrate_sql_server():
+    return result
+
+
+def _cleanup_sql_server_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global sql_server_dict
-    thread_id = threading.get_native_id()
-    sql_server_dict[thread_id][Constants.CURSOR] = None
-    sql_server_dict[thread_id][Constants.CONNECTION].commit()
-    sql_server_dict[thread_id][Constants.CONNECTION].close()
-    sql_server_dict[thread_id][Constants.CONNECTION] = None
-    sql_server_dict[thread_id][Constants.COLUMN_NAMES] = None
+    if query_hash in sql_server_dict:
+        sql_server_dict[query_hash][Constants.CURSOR] = None
+        sql_server_dict[query_hash][Constants.CONNECTION].commit()
+        sql_server_dict[query_hash][Constants.CONNECTION].close()
+        sql_server_dict[query_hash][Constants.CONNECTION] = None
+        sql_server_dict[query_hash][Constants.COLUMN_NAMES] = None
+        sql_server_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_sql_server():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(sql_server, init_migrate_sql_server, cleanup_migrate_sql_server)
@@ -242,29 +313,31 @@ def init_migrate_oracle_db(
     # To prevent query execution from hanging
    config["disable_oob"] = True
 
-    thread_id = threading.get_native_id()
-    if thread_id not in oracle_db_dict:
-        oracle_db_dict[thread_id] = {}
+    query_hash = _get_query_hash(table_or_sql, config, params)
 
-    if Constants.CURSOR not in oracle_db_dict[thread_id]:
-        oracle_db_dict[thread_id][Constants.CURSOR] = None
+    # check if query is already running
+    if query_hash in oracle_db_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
 
-    if oracle_db_dict[thread_id][Constants.CURSOR] is None:
-        connection = oracledb.connect(**config)
-        cursor = connection.cursor()
+    oracle_db_dict[query_hash] = {}
 
-        if not params:
-            cursor.execute(table_or_sql)
-        elif isinstance(params, (list, tuple)):
-            cursor.execute(table_or_sql, params)
-        else:
-            cursor.execute(table_or_sql, **params)
+    connection = oracledb.connect(**config)
+    cursor = connection.cursor()
 
-        oracle_db_dict[thread_id][Constants.CONNECTION] = connection
-        oracle_db_dict[thread_id][Constants.CURSOR] = cursor
-        oracle_db_dict[thread_id][Constants.COLUMN_NAMES] = [
-            column[Constants.I_COLUMN_NAME] for column in cursor.description
-        ]
+    if not params:
+        cursor.execute(table_or_sql)
+    elif isinstance(params, (list, tuple)):
+        cursor.execute(table_or_sql, params)
+    else:
+        cursor.execute(table_or_sql, **params)
+
+    oracle_db_dict[query_hash][Constants.CONNECTION] = connection
+    oracle_db_dict[query_hash][Constants.CURSOR] = cursor
+    oracle_db_dict[query_hash][Constants.COLUMN_NAMES] = [
+        column[Constants.I_COLUMN_NAME] for column in cursor.description
+    ]
 
 
 def oracle_db(
@@ -291,23 +364,46 @@ def oracle_db(
     global oracle_db_dict
 
-    thread_id = threading.get_native_id()
-    cursor = oracle_db_dict[thread_id][Constants.CURSOR]
-    column_names = oracle_db_dict[thread_id][Constants.COLUMN_NAMES]
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    if _query_is_table(table_or_sql):
+        table_or_sql = f"SELECT * FROM {table_or_sql}"
+
+    if not config:
+        config = {}
+    config["disable_oob"] = True
+
+    query_hash = _get_query_hash(table_or_sql, config, params)
+    cursor = oracle_db_dict[query_hash][Constants.CURSOR]
+    column_names = oracle_db_dict[query_hash][Constants.COLUMN_NAMES]
 
     rows = cursor.fetchmany(Constants.BATCH_SIZE)
-    return [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
+    result = [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not result:
+        _cleanup_oracle_db_by_hash(query_hash)
 
-def cleanup_migrate_oracle_db():
+    return result
+
+
+def _cleanup_oracle_db_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global oracle_db_dict
-    thread_id = threading.get_native_id()
-    oracle_db_dict[thread_id][Constants.CURSOR] = None
-    oracle_db_dict[thread_id][Constants.CONNECTION].commit()
-    oracle_db_dict[thread_id][Constants.CONNECTION].close()
-    oracle_db_dict[thread_id][Constants.CONNECTION] = None
-    oracle_db_dict[thread_id][Constants.COLUMN_NAMES] = None
+    if query_hash in oracle_db_dict:
+        oracle_db_dict[query_hash][Constants.CURSOR] = None
+        oracle_db_dict[query_hash][Constants.CONNECTION].commit()
+        oracle_db_dict[query_hash][Constants.CONNECTION].close()
+        oracle_db_dict[query_hash][Constants.CONNECTION] = None
+        oracle_db_dict[query_hash][Constants.COLUMN_NAMES] = None
+        oracle_db_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_oracle_db():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(oracle_db, init_migrate_oracle_db, cleanup_migrate_oracle_db)
@@ -336,23 +432,25 @@ def init_migrate_postgresql(
     if _query_is_table(table_or_sql):
         table_or_sql = f"SELECT * FROM {table_or_sql};"
 
-    thread_id = threading.get_native_id()
-    if thread_id not in postgres_dict:
-        postgres_dict[thread_id] = {}
+    query_hash = _get_query_hash(table_or_sql, config, params)
 
-    if Constants.CURSOR not in postgres_dict[thread_id]:
-        postgres_dict[thread_id][Constants.CURSOR] = None
+    # check if query is already running
+    if query_hash in postgres_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
 
-    if postgres_dict[thread_id][Constants.CURSOR] is None:
-        connection = psycopg2.connect(**config)
-        cursor = connection.cursor()
-        cursor.execute(table_or_sql, params)
+    postgres_dict[query_hash] = {}
 
-        postgres_dict[thread_id][Constants.CONNECTION] = connection
-        postgres_dict[thread_id][Constants.CURSOR] = cursor
-        postgres_dict[thread_id][Constants.COLUMN_NAMES] = [
-            column.name for column in cursor.description
-        ]
+    connection = psycopg2.connect(**config)
+    cursor = connection.cursor()
+    cursor.execute(table_or_sql, params)
+
+    postgres_dict[query_hash][Constants.CONNECTION] = connection
+    postgres_dict[query_hash][Constants.CURSOR] = cursor
+    postgres_dict[query_hash][Constants.COLUMN_NAMES] = [
+        column.name for column in cursor.description
+    ]
 
 
 def postgresql(
@@ -378,24 +476,46 @@ def postgresql(
     """
     global postgres_dict
 
-    thread_id = threading.get_native_id()
-    cursor = postgres_dict[thread_id][Constants.CURSOR]
-    column_names = postgres_dict[thread_id][Constants.COLUMN_NAMES]
+    if not params:
+        params = []
+
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    if _query_is_table(table_or_sql):
+        table_or_sql = f"SELECT * FROM {table_or_sql};"
+
+    query_hash = _get_query_hash(table_or_sql, config, params)
+    cursor = postgres_dict[query_hash][Constants.CURSOR]
+    column_names = postgres_dict[query_hash][Constants.COLUMN_NAMES]
 
     rows = cursor.fetchmany(Constants.BATCH_SIZE)
-    return [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
+    result = [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not result:
+        _cleanup_postgresql_by_hash(query_hash)
 
-def cleanup_migrate_postgresql():
+    return result
+
+
+def _cleanup_postgresql_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
    global postgres_dict
-    thread_id = threading.get_native_id()
-    postgres_dict[thread_id][Constants.CURSOR] = None
-    postgres_dict[thread_id][Constants.CONNECTION].commit()
-    postgres_dict[thread_id][Constants.CONNECTION].close()
-    postgres_dict[thread_id][Constants.CONNECTION] = None
-    postgres_dict[thread_id][Constants.COLUMN_NAMES] = None
+    if query_hash in postgres_dict:
+        postgres_dict[query_hash][Constants.CURSOR] = None
+        postgres_dict[query_hash][Constants.CONNECTION].commit()
+        postgres_dict[query_hash][Constants.CONNECTION].close()
+        postgres_dict[query_hash][Constants.CONNECTION] = None
+        postgres_dict[query_hash][Constants.COLUMN_NAMES] = None
+        postgres_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_postgresql():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(postgresql, init_migrate_postgresql, cleanup_migrate_postgresql)
@@ -432,6 +552,14 @@ def init_migrate_s3(
     bucket_name, *key_parts = file_path_no_protocol.split("/")
     s3_key = "/".join(key_parts)
 
+    query_hash = _get_query_hash(file_path, config)
+
+    # check if query is already running
+    if query_hash in s3_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
+
     # Initialize S3 client
     s3_client = boto3.client(
         "s3",
@@ -456,12 +584,9 @@ def init_migrate_s3(
     csv_reader = csv.reader(text_stream)
     column_names = next(csv_reader)  # First row contains column names
 
-    thread_id = threading.get_native_id()
-    if thread_id not in s3_dict:
-        s3_dict[thread_id] = {}
-
-    s3_dict[thread_id][Constants.CURSOR] = csv_reader
-    s3_dict[thread_id][Constants.COLUMN_NAMES] = column_names
+    s3_dict[query_hash] = {}
+    s3_dict[query_hash][Constants.CURSOR] = csv_reader
+    s3_dict[query_hash][Constants.COLUMN_NAMES] = column_names
 
 
 def s3(
@@ -480,9 +605,12 @@ def s3(
     """
     global s3_dict
 
-    thread_id = threading.get_native_id()
-    csv_reader = s3_dict[thread_id][Constants.CURSOR]
-    column_names = s3_dict[thread_id][Constants.COLUMN_NAMES]
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    query_hash = _get_query_hash(file_path, config)
+    csv_reader = s3_dict[query_hash][Constants.CURSOR]
+    column_names = s3_dict[query_hash][Constants.COLUMN_NAMES]
 
     batch_rows = []
     for _ in range(Constants.BATCH_SIZE):
@@ -492,17 +620,24 @@ def s3(
         except StopIteration:
             break
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not batch_rows:
+        _cleanup_s3_by_hash(query_hash)
+
     return batch_rows
 
 
-def cleanup_migrate_s3():
-    """
-    Clean up S3 dictionary references per-thread.
-    """
+def _cleanup_s3_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global s3_dict
-    thread_id = threading.get_native_id()
-    s3_dict.pop(thread_id, None)
+    if query_hash in s3_dict:
+        s3_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_s3():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(s3, init_migrate_s3, cleanup_migrate_s3)
@@ -519,13 +654,18 @@ def init_migrate_neo4j(
 ):
     global neo4j_dict
 
-    thread_id = threading.get_native_id()
-    if thread_id not in neo4j_dict:
-        neo4j_dict[thread_id] = {}
-
     if len(config_path) > 0:
         config = _combine_config(config=config, config_path=config_path)
 
+    query = _formulate_cypher_query(label_or_rel_or_query)
+    query_hash = _get_query_hash(query, config, params)
+
+    # check if query is already running
+    if query_hash in neo4j_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
+
     uri = _build_neo4j_uri(config)
     username = config.get(Constants.USERNAME, "neo4j")
     password = config.get(Constants.PASSWORD, "password")
@@ -539,14 +679,14 @@ def init_migrate_neo4j(
     else:
         session = driver.session()
 
-    query = _formulate_cypher_query(label_or_rel_or_query)
     # Neo4j expects params to be a dict or None
     cypher_params = params if params is not None else {}
     result = session.run(query, parameters=cypher_params)
 
-    neo4j_dict[thread_id][Constants.DRIVER] = driver
-    neo4j_dict[thread_id][Constants.SESSION] = session
-    neo4j_dict[thread_id][Constants.RESULT] = result
+    neo4j_dict[query_hash] = {}
+    neo4j_dict[query_hash][Constants.DRIVER] = driver
+    neo4j_dict[query_hash][Constants.SESSION] = session
+    neo4j_dict[query_hash][Constants.RESULT] = result
 
 
 def neo4j(
@@ -566,8 +706,12 @@ def neo4j(
     """
     global neo4j_dict
 
-    thread_id = threading.get_native_id()
-    result = neo4j_dict[thread_id][Constants.RESULT]
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    query = _formulate_cypher_query(label_or_rel_or_query)
+    query_hash = _get_query_hash(query, config, params)
+    result = neo4j_dict[query_hash][Constants.RESULT]
 
     # Fetch up to BATCH_SIZE records
     batch = []
@@ -579,20 +723,30 @@ def neo4j(
         if len(batch) >= Constants.BATCH_SIZE:
             break
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not batch:
+        _cleanup_neo4j_by_hash(query_hash)
+
     return batch
 
 
-def cleanup_migrate_neo4j():
+def _cleanup_neo4j_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global neo4j_dict
-    thread_id = threading.get_native_id()
-    session = neo4j_dict[thread_id].get(Constants.SESSION)
-    driver = neo4j_dict[thread_id].get(Constants.DRIVER)
-    if session:
-        session.close()
-    if driver:
-        driver.close()
-    neo4j_dict.pop(thread_id, None)
+    if query_hash in neo4j_dict:
+        session = neo4j_dict[query_hash].get(Constants.SESSION)
+        driver = neo4j_dict[query_hash].get(Constants.DRIVER)
+        if session:
+            session.close()
+        if driver:
+            driver.close()
+        neo4j_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_neo4j():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(neo4j, init_migrate_neo4j, cleanup_migrate_neo4j)
@@ -612,6 +766,14 @@ def init_migrate_arrow_flight(
     if len(config_path) > 0:
         config = _combine_config(config=config, config_path=config_path)
 
+    query_hash = _get_query_hash(query, config)
+
+    # check if query is already running
+    if query_hash in flight_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
+
     host = config.get(Constants.HOST, None)
     port = config.get(Constants.PORT, None)
     username = config.get(Constants.USERNAME, "")
@@ -633,13 +795,9 @@ def init_migrate_arrow_flight(
         flight.FlightDescriptor.for_command(query), options
     )
 
-    # Store connection per thread
-    thread_id = threading.get_native_id()
-    if thread_id not in flight_dict:
-        flight_dict[thread_id] = {}
-
-    flight_dict[thread_id][Constants.CONNECTION] = client
-    flight_dict[thread_id][Constants.CURSOR] = iter(
+    flight_dict[query_hash] = {}
+    flight_dict[query_hash][Constants.CONNECTION] = client
+    flight_dict[query_hash][Constants.CURSOR] = iter(
         _fetch_flight_data(client, flight_info, options)
     )
 
@@ -671,8 +829,11 @@ def arrow_flight(
     """
     global flight_dict
 
-    thread_id = threading.get_native_id()
-    cursor = flight_dict[thread_id][Constants.CURSOR]
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    query_hash = _get_query_hash(query, config)
+    cursor = flight_dict[query_hash][Constants.CURSOR]
 
     batch = []
     for _ in range(Constants.BATCH_SIZE):
        try:
@@ -681,18 +842,24 @@ def arrow_flight(
         except StopIteration:
             break
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not batch:
+        _cleanup_arrow_flight_by_hash(query_hash)
+
     return batch
 
 
-def cleanup_migrate_arrow_flight():
-    """
-    Close the Flight connection per-thread.
-    """
+def _cleanup_arrow_flight_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global flight_dict
-    thread_id = threading.get_native_id()
-    if thread_id in flight_dict:
-        flight_dict.pop(thread_id, None)
+    if query_hash in flight_dict:
+        flight_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_arrow_flight():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(
@@ -709,16 +876,25 @@ def init_migrate_duckdb(query: str, setup_queries: mgp.Nullable[List[str]] = Non
     Initialize an in-memory DuckDB connection and execute the query.
 
     :param query: SQL query to execute
-    :param config: Unused but kept for consistency with other migration functions
-    :param config_path: Unused but kept for consistency with other migration functions
+    :param setup_queries: Optional list of setup queries to execute before the main query
     """
     global duckdb_dict
 
-    thread_id = threading.get_native_id()
-    if thread_id not in duckdb_dict:
-        duckdb_dict[thread_id] = {}
+    # Create hash from query and setup_queries
+    setup_queries_str = (
+        json.dumps(setup_queries, sort_keys=False) if setup_queries else ""
+    )
+    query_hash = hashlib.sha256(
+        f"{query}|{setup_queries_str}".encode("utf-8")
+    ).hexdigest()
+
+    # check if query is already running
+    if query_hash in duckdb_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
 
-    # Ensure a fresh in-memory DuckDB instance for each thread
+    # Ensure a fresh in-memory DuckDB instance for each query
     connection = duckDB.connect()
     cursor = connection.cursor()
     if setup_queries is not None:
@@ -727,9 +903,10 @@ def init_migrate_duckdb(query: str, setup_queries: mgp.Nullable[List[str]] = Non
     cursor.execute(query)
 
-    duckdb_dict[thread_id][Constants.CONNECTION] = connection
-    duckdb_dict[thread_id][Constants.CURSOR] = cursor
-    duckdb_dict[thread_id][Constants.COLUMN_NAMES] = [
+    duckdb_dict[query_hash] = {}
+    duckdb_dict[query_hash][Constants.CONNECTION] = connection
+    duckdb_dict[query_hash][Constants.CURSOR] = cursor
+    duckdb_dict[query_hash][Constants.COLUMN_NAMES] = [
         desc[0] for desc in cursor.description
     ]
 
@@ -741,31 +918,43 @@ def duckdb(
     Fetch rows from DuckDB in batches.
 
     :param query: SQL query to execute
-    :param config: Unused but kept for consistency with other migration functions
-    :param config_path: Unused but kept for consistency with other migration functions
+    :param setup_queries: Optional list of setup queries to execute before the main query
     :return: The result table as a stream of rows
     """
     global duckdb_dict
 
-    thread_id = threading.get_native_id()
-    cursor = duckdb_dict[thread_id][Constants.CURSOR]
-    column_names = duckdb_dict[thread_id][Constants.COLUMN_NAMES]
+    setup_queries_str = (
+        json.dumps(setup_queries, sort_keys=False) if setup_queries else ""
+    )
+    query_hash = hashlib.sha256(
+        f"{query}|{setup_queries_str}".encode("utf-8")
+    ).hexdigest()
+    cursor = duckdb_dict[query_hash][Constants.CURSOR]
+    column_names = duckdb_dict[query_hash][Constants.COLUMN_NAMES]
 
     rows = cursor.fetchmany(Constants.BATCH_SIZE)
-    return [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
+    result = [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not result:
+        _cleanup_duckdb_by_hash(query_hash)
 
-def cleanup_migrate_duckdb():
-    """
-    Clean up DuckDB dictionary references per-thread.
-    """
+    return result
+
+
+def _cleanup_duckdb_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global duckdb_dict
-    thread_id = threading.get_native_id()
-    if thread_id in duckdb_dict:
-        if Constants.CONNECTION in duckdb_dict[thread_id]:
-            duckdb_dict[thread_id][Constants.CONNECTION].close()
-        duckdb_dict.pop(thread_id, None)
+    if query_hash in duckdb_dict:
+        if Constants.CONNECTION in duckdb_dict[query_hash]:
+            duckdb_dict[query_hash][Constants.CONNECTION].close()
+        duckdb_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_duckdb():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(duckdb, init_migrate_duckdb, cleanup_migrate_duckdb)
@@ -782,19 +971,24 @@ def init_migrate_memgraph(
 ):
     global memgraph_dict
 
-    thread_id = threading.get_native_id()
-    if thread_id not in memgraph_dict:
-        memgraph_dict[thread_id] = {}
-
     if len(config_path) > 0:
         config = _combine_config(config=config, config_path=config_path)
 
-    memgraph_db = Memgraph(**config)
     query = _formulate_cypher_query(label_or_rel_or_query)
+    query_hash = _get_query_hash(query, config, params)
+
+    # check if query is already running
+    if query_hash in memgraph_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
+
+    memgraph_db = Memgraph(**config)
     cursor = memgraph_db.execute_and_fetch(query, params)
 
-    memgraph_dict[thread_id][Constants.CONNECTION] = memgraph_db
-    memgraph_dict[thread_id][Constants.CURSOR] = cursor
+    memgraph_dict[query_hash] = {}
+    memgraph_dict[query_hash][Constants.CONNECTION] = memgraph_db
+    memgraph_dict[query_hash][Constants.CURSOR] = cursor
 
 
 def memgraph(
@@ -814,23 +1008,39 @@ def memgraph(
     """
     global memgraph_dict
 
-    thread_id = threading.get_native_id()
-    cursor = memgraph_dict[thread_id][Constants.CURSOR]
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    query = _formulate_cypher_query(label_or_rel_or_query)
+    query_hash = _get_query_hash(query, config, params)
+    cursor = memgraph_dict[query_hash][Constants.CURSOR]
 
-    return [
+    result = [
         mgp.Record(row=row)
         for row in (next(cursor, None) for _ in range(Constants.BATCH_SIZE))
         if row is not None
     ]
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not result:
+        _cleanup_memgraph_by_hash(query_hash)
 
-def cleanup_migrate_memgraph():
+    return result
+
+
+def _cleanup_memgraph_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global memgraph_dict
-    thread_id = threading.get_native_id()
-    if Constants.CONNECTION in memgraph_dict[thread_id]:
-        memgraph_dict[thread_id][Constants.CONNECTION].close()
-    memgraph_dict.pop(thread_id, None)
+    if query_hash in memgraph_dict:
+        if Constants.CONNECTION in memgraph_dict[query_hash]:
+            memgraph_dict[query_hash][Constants.CONNECTION].close()
+        memgraph_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_memgraph():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(memgraph, init_migrate_memgraph, cleanup_migrate_memgraph)
@@ -858,6 +1068,14 @@ def init_migrate_servicenow(
     if len(config_path) > 0:
         config = _combine_config(config=config, config_path=config_path)
 
+    query_hash = _get_query_hash(endpoint, config, params)
+
+    # check if query is already running
+    if query_hash in servicenow_dict:
+        raise RuntimeError(
+            "Migrate module with these parameters is already running. Please wait for it to finish before starting a new one."
+        )
+
     auth = (config.get(Constants.USERNAME), config.get(Constants.PASSWORD))
     headers = {"Accept": "application/json"}
 
@@ -868,11 +1086,8 @@ def init_migrate_servicenow(
     if not data:
         raise ValueError("No data found in ServiceNow response")
 
-    thread_id = threading.get_native_id()
-    if thread_id not in servicenow_dict:
-        servicenow_dict[thread_id] = {}
-
-    servicenow_dict[thread_id][Constants.CURSOR] = iter(data)
+    servicenow_dict[query_hash] = {}
+    servicenow_dict[query_hash][Constants.CURSOR] = iter(data)
 
 
 def servicenow(
@@ -892,8 +1107,11 @@ def servicenow(
     """
     global servicenow_dict
 
-    thread_id = threading.get_native_id()
-    data_iter = servicenow_dict[thread_id][Constants.CURSOR]
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    query_hash = _get_query_hash(endpoint, config, params)
+    data_iter = servicenow_dict[query_hash][Constants.CURSOR]
 
     batch_rows = []
     for _ in range(Constants.BATCH_SIZE):
@@ -903,17 +1121,24 @@ def servicenow(
         except StopIteration:
             break
 
+    # if results are empty, clean up the query state since the cleanup callback doesn't accept any parameters
+    if not batch_rows:
+        _cleanup_servicenow_by_hash(query_hash)
+
     return batch_rows
 
 
-def cleanup_migrate_servicenow():
-    """
-    Clean up ServiceNow dictionary references per-thread.
-    """
+def _cleanup_servicenow_by_hash(query_hash: str):
+    """Internal cleanup function that takes a query hash."""
     global servicenow_dict
-    thread_id = threading.get_native_id()
-    servicenow_dict.pop(thread_id, None)
+    if query_hash in servicenow_dict:
+        servicenow_dict.pop(query_hash, None)
+
+
+def cleanup_migrate_servicenow():
+    """Cleanup function called by the mgp framework (no parameters)."""
+    pass
 
 
 mgp.add_batch_read_proc(servicenow, init_migrate_servicenow, cleanup_migrate_servicenow)