diff --git a/connector_manager/manager/__main__.py b/connector_manager/manager/__main__.py
index 38fe1a5..0ebaa8e 100644
--- a/connector_manager/manager/__main__.py
+++ b/connector_manager/manager/__main__.py
@@ -7,7 +7,8 @@ from pathlib import Path
 from loguru import logger
 import sys
-import json
+import multiprocessing
+from collections import defaultdict
 
 # Import the Podigee connector functionality
 from manager.podigee_connector import handle_podigee_refresh
 
@@ -73,106 +74,65 @@ def ensure_db_connection():
     logger.info("Interactive mode enabled")
     interactiveMode = True
 
-print("Fetching all podcast tasks from database...")
-sql = """
-    SELECT account_id, source_name, source_podcast_id, source_access_keys_encrypted, pod_name
-    FROM podcastSources JOIN openpodcast.podcasts USING (account_id)
-"""
-
-successful = 0
-failed = 0
-
-with db.cursor() as cursor:
-    cursor.execute(sql)
-    results = cursor.fetchall()
-
-for (
-    account_id,
-    source_name,
-    source_podcast_id,
-    source_access_keys_encrypted,
-    pod_name,
-) in results:
-    if interactiveMode:
-        print(
-            f"Fetch podcast {pod_name} {account_id} for {source_name} using podcast_id {source_podcast_id}? [y/n]"
+# Import worker functions and types from separate module for multiprocessing
+from manager.worker import process_source_jobs, PodcastJob
+
+
+if __name__ == '__main__':
+
+    print("Fetching all podcast tasks from database...")
+    sql = """
+        SELECT account_id, source_name, source_podcast_id, source_access_keys_encrypted, pod_name
+        FROM podcastSources JOIN openpodcast.podcasts USING (account_id)
+    """
+
+    with db.cursor() as cursor:
+        cursor.execute(sql)
+        results = cursor.fetchall()
+
+    # Handle interactive mode by filtering jobs upfront
+    jobs_to_process = []
+    for row in results:
+        job = PodcastJob(
+            account_id=row[0],
+            source_name=row[1],
+            source_podcast_id=row[2],
+            source_access_keys_encrypted=row[3],
+            pod_name=row[4]
         )
-        if input() != "y":
-            continue
-
-    # all keys that are needed to access the source
-    print(f"Decrypting keys for {pod_name} {account_id} for {source_name}")
-    source_access_keys = decrypt_json(
-        source_access_keys_encrypted, OPENPODCAST_ENCRYPTION_KEY
-    )
-
-    # Handle Podigee token refresh if this is a Podigee source
-    if source_name == "podigee":
-        # check if all relevant variables are set, otherwise skip this source
-        if not PODIGEE_CLIENT_ID or not PODIGEE_CLIENT_SECRET:
-            logger.error(
-                f"Missing Podigee credentials for {pod_name} {account_id}. Skipping this source."
+
+        if interactiveMode:
+            print(
+                f"Fetch podcast {job.pod_name} {job.account_id} for {job.source_name} using podcast_id {job.source_podcast_id}? [y/n]"
             )
-            continue
-
-        # Ensure database connection is valid before token refresh
-        try:
-            db = ensure_db_connection()
-        except mysql.connector.Error:
-            logger.error(f"Cannot establish database connection for Podigee token refresh of {pod_name} {account_id}. Skipping this source.")
-            continue
-
-        # Handle the token refresh and database update
-        source_access_keys = handle_podigee_refresh(
-            db_connection=db,
-            account_id=account_id,
-            source_name=source_name,
-            source_access_keys=source_access_keys,
-            pod_name=pod_name,
-            encryption_key=OPENPODCAST_ENCRYPTION_KEY,
-            client_id=PODIGEE_CLIENT_ID,
-            client_secret=PODIGEE_CLIENT_SECRET,
-            redirect_uri=PODIGEE_REDIRECT_URI
-        )
+            if input() != "y":
+                continue
 
-        if (not source_access_keys) or ("PODIGEE_ACCESS_TOKEN" not in source_access_keys):
-            logger.error(f"Failed to refresh Podigee token for {pod_name} {account_id}. Skipping this source.")
-            continue
+        jobs_to_process.append(job)
 
-    logger.info(
-        f"Starting fetcher for {pod_name} {account_id} for {source_name} using podcast_id {source_podcast_id}"
-    )
+    # Group jobs by source to avoid running multiple jobs for the same source in parallel
+    # This prevents rate limiting and credential issues with Apple, Spotify, etc.
+    jobs_by_source = defaultdict(list)
+    for job in jobs_to_process:
+        jobs_by_source[job.source_name].append(job)
 
-    # parent path of fetcher/connector
-    cwd = Path(CONNECTORS_PATH) / source_name
-    try:
-        # Ensure that environment variables are proper strings
-        source_access_keys = {k: str(v) for k, v in source_access_keys.items()}
-
-        # run an external process, switch to right fetcher depending on
-        # source_name, and set env variables from source_access_keys
-        result = subprocess.run(
-            ["python", "-m", "job"],
-            cwd=cwd,
-            env={
-                **os.environ,
-                **source_access_keys,
-                "PODCAST_ID": source_podcast_id,
-                "PODCAST_NAME": pod_name,
-            },
-            text=True,
-            timeout=7200,  # 120 minute timeout to prevent hanging of subprocesses
-        )
-        if result.returncode == 0:
-            successful += 1
-        else:
-            failed += 1
-            logger.error(f"Fetching of {pod_name} not successful. Subprocess error output: {result.stderr}")
-    except subprocess.TimeoutExpired:
-        failed += 1
-        logger.error(f"Error: Timeout while fetching {pod_name} (exceeded 120 minutes)")
-    except Exception as e:
-        failed += 1
-        logger.error(f"Exception while fetching {pod_name}: {e}")
-
-logger.info(f"Completed. Successful: {successful}, Failed: {failed}")
+    # Process jobs: run different sources in parallel, but same-source jobs sequentially
+    if jobs_to_process:
+        logger.info(f"Processing {len(jobs_to_process)} jobs across {len(jobs_by_source)} sources...")
+
+        all_results = []
+
+        # Use multiprocessing to process different sources in parallel
+        with multiprocessing.Pool() as pool:
+            results_by_source = pool.map(process_source_jobs, jobs_by_source.values())
+
+        # Flatten results
+        for source_results in results_by_source:
+            all_results.extend(source_results)
+
+        successful = sum(1 for r in all_results if r)
+        failed = sum(1 for r in all_results if not r)
+
+        logger.info(f"Completed. Successful: {successful}, Failed: {failed}")
+    else:
+        logger.info("No jobs to process")
diff --git a/connector_manager/manager/worker.py b/connector_manager/manager/worker.py
new file mode 100644
index 0000000..7b0dff5
--- /dev/null
+++ b/connector_manager/manager/worker.py
@@ -0,0 +1,183 @@
+"""
+Worker functions for multiprocessing.
+
+These functions must be in a separate module (not __main__.py)
+to be picklable for multiprocessing.Pool.
+""" + +from dataclasses import dataclass +from manager.load_env import load_env, load_file_or_env +import mysql.connector +from manager.cryptography import decrypt_json +import os +import subprocess +from pathlib import Path +from loguru import logger + + +@dataclass +class PodcastJob: + account_id: str + source_name: str + source_podcast_id: str + source_access_keys_encrypted: str + pod_name: str + +# Load environment variables +CONNECTORS_PATH = load_env("CONNECTORS_PATH", ".") +MYSQL_HOST = load_env("MYSQL_HOST", "localhost") +MYSQL_PORT = load_env("MYSQL_PORT", 3306) +MYSQL_USER = load_env("MYSQL_USER", "root") +MYSQL_PASSWORD = load_file_or_env("MYSQL_PASSWORD") +MYSQL_DATABASE = load_env("MYSQL_DATABASE", "openpodcast_auth") +OPENPODCAST_ENCRYPTION_KEY = load_file_or_env("OPENPODCAST_ENCRYPTION_KEY") + +PODIGEE_CLIENT_ID = load_env("PODIGEE_CLIENT_ID") +PODIGEE_CLIENT_SECRET = load_file_or_env("PODIGEE_CLIENT_SECRET") +PODIGEE_REDIRECT_URI = load_env("PODIGEE_REDIRECT_URI", + "https://connect.openpodcast.app/auth/v1/podigee/callback") + + +def ensure_db_connection(): + """ + Ensure database connection is valid, reconnect if necessary. + Returns the database connection. + """ + global db + try: + if db is None: + logger.info("Establishing database connection...") + db = mysql.connector.connect( + host=MYSQL_HOST, + port=MYSQL_PORT, + user=MYSQL_USER, + passwd=MYSQL_PASSWORD, + database=MYSQL_DATABASE, + autocommit=True, + ) + logger.info("Database connection established") + elif not db.is_connected(): + logger.info("Database connection lost, reconnecting...") + db.close() + db = mysql.connector.connect( + host=MYSQL_HOST, + port=MYSQL_PORT, + user=MYSQL_USER, + passwd=MYSQL_PASSWORD, + database=MYSQL_DATABASE, + autocommit=True, + ) + logger.info("Database connection re-established") + return db + except mysql.connector.Error as e: + logger.error(f"Error connecting to mysql: {e}") + raise + + +# Initialize as None for worker processes +db = None + + +def process_podcast_job(job): + """ + Worker function to process a single podcast job. + Each worker process will have its own database connection. + """ + from manager.podigee_connector import handle_podigee_refresh + + # Each worker process needs its own database connection + global db + db = None + + try: + # all keys that are needed to access the source + print(f"Decrypting keys for {job.pod_name} {job.account_id} for {job.source_name}") + source_access_keys = decrypt_json( + job.source_access_keys_encrypted, OPENPODCAST_ENCRYPTION_KEY + ) + + # Handle Podigee token refresh if this is a Podigee source + if job.source_name == "podigee": + # check if all relevant variables are set, otherwise skip this source + if not PODIGEE_CLIENT_ID or not PODIGEE_CLIENT_SECRET: + logger.error( + f"Missing Podigee credentials for {job.pod_name} {job.account_id}. Skipping this source." + ) + return False + + # Ensure database connection is valid before token refresh + try: + db = ensure_db_connection() + except mysql.connector.Error: + logger.error(f"Cannot establish database connection for Podigee token refresh of {job.pod_name} {job.account_id}. 
Skipping this source.") + return False + + # Handle the token refresh and database update + source_access_keys = handle_podigee_refresh( + db_connection=db, + account_id=job.account_id, + source_name=job.source_name, + source_access_keys=source_access_keys, + pod_name=job.pod_name, + encryption_key=OPENPODCAST_ENCRYPTION_KEY, + client_id=PODIGEE_CLIENT_ID, + client_secret=PODIGEE_CLIENT_SECRET, + redirect_uri=PODIGEE_REDIRECT_URI + ) + + if (not source_access_keys) or ("PODIGEE_ACCESS_TOKEN" not in source_access_keys): + logger.error(f"Failed to refresh Podigee token for {job.pod_name} {job.account_id}. Skipping this source.") + return False + + logger.info( + f"Starting fetcher for {job.pod_name} {job.account_id} for {job.source_name} using podcast_id {job.source_podcast_id}" + ) + + # parent path of fetcher/connector + cwd = Path(CONNECTORS_PATH) / job.source_name + + # Ensure that environment variables are proper strings + source_access_keys = {k: str(v) for k, v in source_access_keys.items()} + + # run an external process, switch to right fetcher depending on + # source_name, and set env variables from source_access_keys + result = subprocess.run( + ["python", "-m", "job"], + cwd=cwd, + env={ + **os.environ, + **source_access_keys, + "PODCAST_ID": job.source_podcast_id, + "PODCAST_NAME": job.pod_name, + }, + text=True, + timeout=7200, # 120 minute timeout to prevent hanging of subprocesses + ) + + if result.returncode == 0: + return True + else: + logger.error(f"Fetching of {job.pod_name} not successful. Subprocess error output: {result.stderr}") + return False + + except subprocess.TimeoutExpired: + logger.error(f"Error: Timeout while fetching {job.pod_name} (exceeded 120 minutes)") + return False + except Exception as e: + logger.error(f"Exception while fetching {job.pod_name}: {e}") + return False + finally: + # Clean up database connection for this worker + if db is not None and db.is_connected(): + db.close() + + +def process_source_jobs(source_jobs): + """ + Process all jobs for a single source sequentially. + """ + source_results = [] + for job in source_jobs: + result = process_podcast_job(job) + source_results.append(result) + return source_results
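
Review note: the standalone sketch below reproduces the scheduling pattern this PR introduces (group jobs by source, one pool task per source, same-source jobs sequential within that task), in case the behaviour is easier to evaluate outside the manager. DemoJob and run_source_group are hypothetical stand-ins for PodcastJob and process_source_jobs, and the sleep stands in for the fetcher subprocess.

# demo_scheduling.py -- minimal sketch of the per-source scheduling pattern.
# DemoJob / run_source_group are illustrative names, not part of this PR.
import multiprocessing
import time
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class DemoJob:
    source_name: str
    pod_name: str


def run_source_group(jobs):
    # All jobs for one source run sequentially inside a single worker
    # process, mirroring process_source_jobs() in manager/worker.py.
    results = []
    for job in jobs:
        time.sleep(0.1)  # stand-in for the fetcher subprocess
        results.append(True)
    return results


if __name__ == "__main__":
    jobs = [
        DemoJob("spotify", "pod-a"),
        DemoJob("spotify", "pod-b"),
        DemoJob("apple", "pod-a"),
    ]

    # Group by source so that e.g. the two Spotify jobs never run concurrently.
    jobs_by_source = defaultdict(list)
    for job in jobs:
        jobs_by_source[job.source_name].append(job)

    # One pool task per source: different sources run in parallel.
    with multiprocessing.Pool() as pool:
        nested = pool.map(run_source_group, jobs_by_source.values())

    flat = [r for group in nested for r in group]
    print(f"Successful: {sum(flat)}, Failed: {len(flat) - sum(flat)}")

Because DemoJob and run_source_group are defined at module level they are picklable, which is the same constraint that moved the real worker code out of __main__.py; pool.map also preserves input order, so results can be matched back to their source groups.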