From f1f45ea8ea2b73a955026786fb99d0ff0eb075d4 Mon Sep 17 00:00:00 2001 From: even-wei Date: Tue, 5 May 2026 20:55:08 +0800 Subject: [PATCH 1/3] feat(cli): add check-base subcommand + check_base_freshness() helper (M2) Adds `recce check-base` CLI subcommand with statuses FRESH / STALE_TIME / STALE_SHA / MISSING and --format json|text output. Exports check_base_freshness() as a module-level helper so mcp_server.py can call it at startup without duplicating logic (R8: cli.py-primary split). feat(mcp): emit stale-base warning at startup; expose base_status (M2, AC-3) Calls check_base_freshness() after load_context() in run_mcp_server() and prints a [Warning] line to stderr when status is STALE_TIME or STALE_SHA. Adds base_status field to get_server_info tool response so agents can programmatically detect stale state without parsing stderr. test: cover FRESH/STALE_TIME/STALE_SHA/MISSING + best-effort SHA absent (R9) 5 unit tests in tests/test_check_base.py; all pass. The absent-field test (test_sha_absent_no_raise) confirms DBT_GIT_SHA absence falls through to FRESH without raising. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: even-wei --- recce/cli.py | 856 ++++++++++++++++++++++++++++++++------- recce/mcp_server.py | 317 ++++++++++++--- tests/test_check_base.py | 124 ++++++ 3 files changed, 1085 insertions(+), 212 deletions(-) create mode 100644 tests/test_check_base.py diff --git a/recce/cli.py b/recce/cli.py index d58c85499..470e2beb6 100644 --- a/recce/cli.py +++ b/recce/cli.py @@ -134,8 +134,15 @@ def _add_options(func): dbt_related_options = [ - click.option("--target", "-t", help="Which target to load for the given profile.", type=click.STRING), - click.option("--profile", help="Which existing profile to load.", type=click.STRING), + click.option( + "--target", + "-t", + help="Which target to load for the given profile.", + type=click.STRING, + ), + click.option( + "--profile", help="Which existing profile to load.", type=click.STRING + ), click.option( "--project-dir", help="Which directory to look in for the dbt_project.yml file.", @@ -152,8 +159,18 @@ def _add_options(func): sqlmesh_related_options = [ click.option("--sqlmesh", is_flag=True, help="Use SQLMesh ", hidden=True), - click.option("--sqlmesh-envs", is_flag=False, help="SQLMesh envs to compare. SOURCE:TARGET", hidden=True), - click.option("--sqlmesh-config", is_flag=False, help="SQLMesh config name to use", hidden=True), + click.option( + "--sqlmesh-envs", + is_flag=False, + help="SQLMesh envs to compare. 
SOURCE:TARGET", + hidden=True, + ), + click.option( + "--sqlmesh-config", + is_flag=False, + help="SQLMesh config name to use", + hidden=True, + ), ] recce_options = [ @@ -165,7 +182,11 @@ def _add_options(func): show_default=True, ), click.option( - "--error-log", help="Path to the error log file.", type=click.Path(), default=RECCE_ERROR_LOG_FILE, hidden=True + "--error-log", + help="Path to the error log file.", + type=click.Path(), + default=RECCE_ERROR_LOG_FILE, + hidden=True, ), click.option("--debug", is_flag=True, help="Enable debug mode.", hidden=True), ] @@ -173,7 +194,10 @@ def _add_options(func): recce_cloud_options = [ click.option("--cloud", is_flag=True, help="Fetch the state file from cloud."), click.option( - "--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN" + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", ), click.option( "--state-file-host", @@ -234,7 +258,10 @@ def _add_options(func): "--session-id", help="The session ID triggers this instance.", type=click.STRING, - envvar=["RECCE_SESSION_ID", "RECCE_SNAPSHOT_ID"], # Backward compatibility with RECCE_SNAPSHOT_ID + envvar=[ + "RECCE_SESSION_ID", + "RECCE_SNAPSHOT_ID", + ], # Backward compatibility with RECCE_SNAPSHOT_ID hidden=True, ), ] @@ -244,7 +271,9 @@ def _execute_sql(context, sql_template, base=False): try: import pandas as pd except ImportError: - print("'pandas' package not found. You can install it using the command: 'pip install pandas'.") + print( + "'pandas' package not found. You can install it using the command: 'pip install pandas'." + ) exit(1) from recce.adapter.dbt_adapter import DbtAdapter @@ -254,7 +283,9 @@ def _execute_sql(context, sql_template, base=False): sql = dbt_adapter.generate_sql(sql_template, base) response, result = dbt_adapter.execute(sql, fetch=True, auto_begin=True) table = result - df = pd.DataFrame([row.values() for row in table.rows], columns=table.column_names) + df = pd.DataFrame( + [row.values() for row in table.rows], columns=table.column_names + ) return df @@ -271,7 +302,10 @@ def cli(ctx, **kwargs): error_console.print( f"[[yellow]Update Available[/yellow]] A new version of Recce {__latest_version__} is available.", ) - error_console.print("Please update using the command: 'pip install --upgrade recce'.", end="\n\n") + error_console.print( + "Please update using the command: 'pip install --upgrade recce'.", + end="\n\n", + ) @cli.command(cls=TrackCommand) @@ -356,10 +390,14 @@ def init(cache_db, **kwargs): cloud_token = kwargs.get("cloud_token") or kwargs.get("api_token") if not cloud_token: - console.print("[[red]Error[/red]] --cloud requires --cloud-token or --api-token (or GITHUB_TOKEN env var).") + console.print( + "[[red]Error[/red]] --cloud requires --cloud-token or --api-token (or GITHUB_TOKEN env var)." + ) exit(1) if not session_id: - console.print("[[red]Error[/red]] --cloud requires --session-id (or RECCE_SESSION_ID env var).") + console.print( + "[[red]Error[/red]] --cloud requires --session-id (or RECCE_SESSION_ID env var)." 
+ ) exit(1) cloud_client = RecceCloud(token=cloud_token) @@ -377,30 +415,41 @@ def init(cache_db, **kwargs): console.print(f"[[red]Error[/red]] Failed to get session: {e}") exit(1) if session_info.get("status") == "error": - console.print(f"[[red]Error[/red]] Failed to get session: {session_info.get('message', 'Access denied')}") + console.print( + f"[[red]Error[/red]] Failed to get session: {session_info.get('message', 'Access denied')}" + ) exit(1) cloud_org_id = session_info.get("org_id") cloud_project_id = session_info.get("project_id") if not cloud_org_id or not cloud_project_id: - console.print(f"[[red]Error[/red]] Session {session_id} missing org_id or project_id.") + console.print( + f"[[red]Error[/red]] Session {session_id} missing org_id or project_id." + ) exit(1) # Download artifacts to local target directories console.print("Downloading artifacts from Cloud...") try: - download_urls = cloud_client.get_download_urls_by_session_id(cloud_org_id, cloud_project_id, session_id) + download_urls = cloud_client.get_download_urls_by_session_id( + cloud_org_id, cloud_project_id, session_id + ) except RecceCloudException as e: console.print(f"[[red]Error[/red]] Failed to get download URLs: {e}") exit(1) project_dir_path = Path(kwargs.get("project_dir") or "./") target_path = project_dir_path / kwargs.get("target_path", "target") - target_base_path = project_dir_path / kwargs.get("target_base_path", "target-base") + target_base_path = project_dir_path / kwargs.get( + "target_base_path", "target-base" + ) target_path.mkdir(parents=True, exist_ok=True) target_base_path.mkdir(parents=True, exist_ok=True) # Download current session artifacts - for artifact_key, filename in [("manifest_url", "manifest.json"), ("catalog_url", "catalog.json")]: + for artifact_key, filename in [ + ("manifest_url", "manifest.json"), + ("catalog_url", "catalog.json"), + ]: url = download_urls.get(artifact_key) if url: try: @@ -413,7 +462,9 @@ def init(cache_db, **kwargs): f" [[yellow]Warning[/yellow]] Failed to download {filename}: HTTP {resp.status_code}" ) except requests.RequestException as e: - console.print(f" [[yellow]Warning[/yellow]] Failed to download {filename}: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to download {filename}: {e}" + ) # Download base session artifacts try: @@ -421,22 +472,31 @@ def init(cache_db, **kwargs): cloud_org_id, cloud_project_id, session_id=session_id ) except RecceCloudException as e: - console.print(f" [[yellow]Warning[/yellow]] Failed to get base session URLs: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to get base session URLs: {e}" + ) base_download_urls = {} - for artifact_key, filename in [("manifest_url", "manifest.json"), ("catalog_url", "catalog.json")]: + for artifact_key, filename in [ + ("manifest_url", "manifest.json"), + ("catalog_url", "catalog.json"), + ]: url = base_download_urls.get(artifact_key) if url: try: resp = requests.get(url, timeout=_METADATA_TIMEOUT) if resp.status_code == 200: (target_base_path / filename).write_bytes(resp.content) - console.print(f" Downloaded base {filename} to {target_base_path}") + console.print( + f" Downloaded base {filename} to {target_base_path}" + ) else: console.print( f" [[yellow]Warning[/yellow]] Failed to download base {filename}: HTTP {resp.status_code}" ) except requests.RequestException as e: - console.print(f" [[yellow]Warning[/yellow]] Failed to download base {filename}: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to download base {filename}: {e}" + ) # 
Download existing CLL cache for warm start. # Try current session first, then fall back to production (base) session. @@ -451,7 +511,9 @@ def _stream_download_to_file(url: str, dest: Path) -> int: if resp.status_code != 200: return 0 total = 0 - with tempfile.NamedTemporaryFile(dir=dest.parent, delete=False, suffix=".tmp") as tmp: + with tempfile.NamedTemporaryFile( + dir=dest.parent, delete=False, suffix=".tmp" + ) as tmp: tmp_path = Path(tmp.name) try: for chunk in resp.iter_content(chunk_size=8192): @@ -473,10 +535,14 @@ def _stream_download_to_file(url: str, dest: Path) -> int: try: nbytes = _stream_download_to_file(cll_cache_url, Path(cache_db)) if nbytes > 0: - console.print(f" Downloaded CLL cache from session ({nbytes / 1024 / 1024:.1f} MB)") + console.print( + f" Downloaded CLL cache from session ({nbytes / 1024 / 1024:.1f} MB)" + ) cache_downloaded = True except requests.RequestException as e: - console.print(f" [[yellow]Warning[/yellow]] Failed to download CLL cache: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to download CLL cache: {e}" + ) if not cache_downloaded: # Fall back to production (base) session cache @@ -485,13 +551,19 @@ def _stream_download_to_file(url: str, dest: Path) -> int: try: nbytes = _stream_download_to_file(base_cache_url, Path(cache_db)) if nbytes > 0: - console.print(f" Downloaded CLL cache from base session ({nbytes / 1024 / 1024:.1f} MB)") + console.print( + f" Downloaded CLL cache from base session ({nbytes / 1024 / 1024:.1f} MB)" + ) cache_downloaded = True except requests.RequestException as e: - console.print(f" [[yellow]Warning[/yellow]] Failed to download base CLL cache: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to download base CLL cache: {e}" + ) if not cache_downloaded: - console.print(" [dim]No existing CLL cache found — will compute from scratch[/dim]") + console.print( + " [dim]No existing CLL cache found — will compute from scratch[/dim]" + ) if cache_db is None: cache_db = _DEFAULT_DB_PATH @@ -508,7 +580,9 @@ def _stream_download_to_file(url: str, dest: Path) -> int: if not is_cloud: project_dir_path = Path(kwargs.get("project_dir") or "./") target_path = project_dir_path / kwargs.get("target_path", "target") - target_base_path = project_dir_path / kwargs.get("target_base_path", "target-base") + target_base_path = project_dir_path / kwargs.get( + "target_base_path", "target-base" + ) has_target = (target_path / "manifest.json").is_file() has_base = (target_base_path / "manifest.json").is_file() @@ -525,10 +599,14 @@ def _stream_download_to_file(url: str, dest: Path) -> int: # If only one env exists, use it for both (so load_context doesn't fail) context_kwargs = {**kwargs} if has_target and not has_base: - console.print("[dim]Only target/ found — building cache for current environment only.[/dim]") + console.print( + "[dim]Only target/ found — building cache for current environment only.[/dim]" + ) context_kwargs["target_base_path"] = kwargs.get("target_path", "target") elif has_base and not has_target: - console.print("[dim]Only target-base/ found — building cache for base environment only.[/dim]") + console.print( + "[dim]Only target-base/ found — building cache for base environment only.[/dim]" + ) context_kwargs["target_path"] = kwargs.get("target_base_path", "target-base") try: @@ -559,7 +637,8 @@ def _stream_download_to_file(url: str, dest: Path) -> int: curr_ids = [ nid for nid in dbt_adapter.curr_manifest.nodes - if dbt_adapter.curr_manifest.nodes[nid].resource_type in ("model", "snapshot") + 
if dbt_adapter.curr_manifest.nodes[nid].resource_type + in ("model", "snapshot") ] envs.append(("current", curr_ids, False)) @@ -567,18 +646,26 @@ def _stream_download_to_file(url: str, dest: Path) -> int: base_ids = [ nid for nid in dbt_adapter.base_manifest.nodes - if dbt_adapter.base_manifest.nodes[nid].resource_type in ("model", "snapshot") + if dbt_adapter.base_manifest.nodes[nid].resource_type + in ("model", "snapshot") ] envs.append(("base", base_ids, True)) with Progress(console=console, transient=True) as progress: for env_name, node_ids, is_base in envs: - console.print(f"\n[bold]{env_name}[/bold] environment: {len(node_ids)} models") + console.print( + f"\n[bold]{env_name}[/bold] environment: {len(node_ids)} models" + ) t_start = time.perf_counter() - manifest = dbt_adapter.base_manifest if is_base else dbt_adapter.curr_manifest + manifest = ( + dbt_adapter.base_manifest if is_base else dbt_adapter.curr_manifest + ) catalog = dbt_adapter.base_catalog if is_base else dbt_adapter.curr_catalog - adapter_type = getattr(manifest.metadata, "adapter_type", None) or dbt_adapter.adapter.type() + adapter_type = ( + getattr(manifest.metadata, "adapter_type", None) + or dbt_adapter.adapter.type() + ) success = 0 fail = 0 @@ -598,8 +685,12 @@ def _stream_download_to_file(url: str, dest: Path) -> int: col_names = list(catalog.nodes[nid].columns.keys()) checksum = DbtAdapter._get_node_checksum(manifest, nid) - parent_checksums = [DbtAdapter._get_node_checksum(manifest, pid) for pid in p_list] - content_key = DbtAdapter._make_node_content_key(checksum, parent_checksums, col_names, adapter_type) + parent_checksums = [ + DbtAdapter._get_node_checksum(manifest, pid) for pid in p_list + ] + content_key = DbtAdapter._make_node_content_key( + checksum, parent_checksums, col_names, adapter_type + ) cached_json = cache.get_node(nid, content_key) if cached_json: cache_hits += 1 @@ -613,13 +704,17 @@ def _stream_download_to_file(url: str, dest: Path) -> int: fail += 1 progress.advance(task) continue - batch_to_store.append((nid, content_key, DbtAdapter._serialize_cll_data(cll_data))) + batch_to_store.append( + (nid, content_key, DbtAdapter._serialize_cll_data(cll_data)) + ) success += 1 except Exception as e: fail += 1 if fail <= 3: console.print(f" [dim red] skip: {nid}: {e}[/dim red]") - logger.debug("[recce init] CLL computation failed for %s: %s", nid, e) + logger.debug( + "[recce init] CLL computation failed for %s: %s", nid, e + ) progress.advance(task) if batch_to_store: @@ -645,7 +740,9 @@ def _stream_download_to_file(url: str, dest: Path) -> int: dbt_adapter.get_cll_cached.cache_clear() if fail > 3: - console.print(f" [dim]... and {fail - 3} more skipped (see logs for details)[/dim]") + console.print( + f" [dim]... and {fail - 3} more skipped (see logs for details)[/dim]" + ) # Build and save the full CLL map as JSON. # The per-node SQLite cache is warm from the loop above, so this is fast. @@ -676,7 +773,9 @@ def _stream_download_to_file(url: str, dest: Path) -> int: console.print(f" [[yellow]Warning[/yellow]] Failed to build CLL map: {e}") stats = cache.stats - console.print(f"\nCache saved to [bold]{cache_db}[/bold] ({stats['entries']} entries)") + console.print( + f"\nCache saved to [bold]{cache_db}[/bold] ({stats['entries']} entries)" + ) # In cloud mode, emit per_node.db — a pure-artifact SQLite that Cloud # streams to serve lineage without proxying to an ephemeral Recce instance. 
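The loop above keys every cached CLL entry on the node's own checksum, its parents' checksums, its column names, and the adapter type, so an upstream edit invalidates dependents automatically. A minimal sketch of that keying scheme, with make_node_content_key as a hypothetical stand-in for DbtAdapter._make_node_content_key (whose hashing details this patch does not show):

    import hashlib
    import json


    def make_node_content_key(checksum, parent_checksums, col_names, adapter_type):
        # Sort parent checksums so the key is stable regardless of manifest
        # ordering; the real helper may instead preserve order.
        payload = json.dumps(
            [checksum, sorted(parent_checksums), col_names, adapter_type]
        )
        # Hash to a fixed-width key suitable for the SQLite cache lookup.
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()
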
@@ -708,7 +807,9 @@ def _stream_download_to_file(url: str, dest: Path) -> int: ) except Exception as e: logger.warning("[recce init] Failed to emit metadata artifacts: %s", e) - console.print(f" [[yellow]Warning[/yellow]] Failed to emit metadata artifacts: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to emit metadata artifacts: {e}" + ) info_path = None lineage_diff_path = None @@ -717,10 +818,14 @@ def _stream_download_to_file(url: str, dest: Path) -> int: upload_failures: list[str] = [] upload_urls: Optional[dict] = None try: - upload_urls = cloud_client.get_upload_urls_by_session_id(cloud_org_id, cloud_project_id, session_id) + upload_urls = cloud_client.get_upload_urls_by_session_id( + cloud_org_id, cloud_project_id, session_id + ) except Exception as e: logger.warning("[recce init] Cloud upload failed: %s", e) - console.print(f" [[yellow]Warning[/yellow]] Cloud upload failed: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Cloud upload failed: {e}" + ) if upload_urls is not None: # Emit per_node.db only when Cloud declares support for it. @@ -772,21 +877,35 @@ def _stream_download_to_file(url: str, dest: Path) -> int: def _to_dict(artifact): return ( artifact.to_dict() - if (artifact is not None and hasattr(artifact, "to_dict")) + if ( + artifact is not None + and hasattr(artifact, "to_dict") + ) else artifact ) - for env_name, manifest, catalog, cross_catalog in envs_to_emit: + for ( + env_name, + manifest, + catalog, + cross_catalog, + ) in envs_to_emit: if manifest is None: continue - manifest_dict = manifest.to_dict() if hasattr(manifest, "to_dict") else manifest + manifest_dict = ( + manifest.to_dict() + if hasattr(manifest, "to_dict") + else manifest + ) catalog_dict = _to_dict(catalog) cross_catalog_dict = _to_dict(cross_catalog) - node_rows, column_rows, edge_rows, test_rows = extract_rows_from_artifacts( - manifest_dict, - catalog_dict, - env_name, - cross_env_catalog=cross_catalog_dict, + node_rows, column_rows, edge_rows, test_rows = ( + extract_rows_from_artifacts( + manifest_dict, + catalog_dict, + env_name, + cross_env_catalog=cross_catalog_dict, + ) ) writer.write_nodes(node_rows) writer.write_columns(column_rows) @@ -794,10 +913,17 @@ def _to_dict(artifact): writer.write_tests(test_rows) pn_elapsed = time.perf_counter() - t_pn_start pn_size_mb = per_node_db_path.stat().st_size / 1024 / 1024 - console.print(f" per_node.db emitted " f"({pn_size_mb:.1f} MB, {pn_elapsed:.1f}s)") + console.print( + f" per_node.db emitted " + f"({pn_size_mb:.1f} MB, {pn_elapsed:.1f}s)" + ) except Exception as e: - logger.warning("[recce init] Failed to emit per_node.db: %s", e) - console.print(f" [[yellow]Warning[/yellow]] Failed to emit per_node.db: {e}") + logger.warning( + "[recce init] Failed to emit per_node.db: %s", e + ) + console.print( + f" [[yellow]Warning[/yellow]] Failed to emit per_node.db: {e}" + ) per_node_db_path = None else: console.print( @@ -828,7 +954,9 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append("cll_map.json") - console.print(f" [[yellow]Warning[/yellow]] Failed to upload cll_map.json: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to upload cll_map.json: {e}" + ) elif not cll_map_upload_url: console.print( " [[yellow]Warning[/yellow]] No cll_map_url in upload URLs " @@ -836,13 +964,19 @@ def _to_dict(artifact): ) # Upload per_node.db (only when Cloud supports it AND we emitted). 
- if per_node_db_upload_url and per_node_db_path and per_node_db_path.is_file(): + if ( + per_node_db_upload_url + and per_node_db_path + and per_node_db_path.is_file() + ): try: with open(per_node_db_path, "rb") as f: resp = requests.put( per_node_db_upload_url, data=f, - headers={"Content-Type": "application/octet-stream"}, + headers={ + "Content-Type": "application/octet-stream" + }, timeout=_UPLOAD_TIMEOUT, ) if resp.status_code in (200, 204): @@ -858,7 +992,9 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append("per_node.db") - console.print(f" [[yellow]Warning[/yellow]] Failed to upload per_node.db: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to upload per_node.db: {e}" + ) # Upload CLL cache. cll_cache.db is load-bearing across sessions — # build_full_cll_map reuses its warm entries on subsequent runs — @@ -870,7 +1006,9 @@ def _to_dict(artifact): resp = requests.put( cll_cache_upload_url, data=f, - headers={"Content-Type": "application/octet-stream"}, + headers={ + "Content-Type": "application/octet-stream" + }, timeout=_UPLOAD_TIMEOUT, ) if resp.status_code in (200, 204): @@ -886,9 +1024,13 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append("cll_cache.db") - console.print(f" [[yellow]Warning[/yellow]] Failed to upload cll_cache.db: {e}") + console.print( + f" [[yellow]Warning[/yellow]] Failed to upload cll_cache.db: {e}" + ) elif not cll_cache_upload_url: - logger.debug("No cll_cache_url in upload URLs — cache upload not supported yet") + logger.debug( + "No cll_cache_url in upload URLs — cache upload not supported yet" + ) # Upload info.json and lineage_diff.json. Graceful # degradation: if Cloud hasn't added the info_url / @@ -900,7 +1042,11 @@ def _to_dict(artifact): ] for display_name, local_path, url_key in metadata_uploads: metadata_upload_url = upload_urls.get(url_key) - if metadata_upload_url and local_path is not None and local_path.is_file(): + if ( + metadata_upload_url + and local_path is not None + and local_path.is_file() + ): try: with open(local_path, "rb") as f: resp = requests.put( @@ -911,7 +1057,9 @@ def _to_dict(artifact): ) if resp.status_code in (200, 204): size_kb = local_path.stat().st_size / 1024 - console.print(f" Uploaded {display_name} ({size_kb:.1f} KB)") + console.print( + f" Uploaded {display_name} ({size_kb:.1f} KB)" + ) else: upload_failures.append(display_name) console.print( @@ -920,8 +1068,12 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append(display_name) - console.print(f" [[yellow]Warning[/yellow]] Failed to upload {display_name}: {e}") - elif metadata_upload_url and (local_path is None or not local_path.is_file()): + console.print( + f" [[yellow]Warning[/yellow]] Failed to upload {display_name}: {e}" + ) + elif metadata_upload_url and ( + local_path is None or not local_path.is_file() + ): # URL present but local artifact missing — emit failed # partway (e.g., info.json written but lineage_diff.json # write raised). Record the failure so the summary @@ -952,7 +1104,9 @@ def _to_dict(artifact): shutil.rmtree(per_node_scratch, ignore_errors=True) shutil.rmtree(metadata_scratch, ignore_errors=True) else: - console.print("Run [bold]recce server --enable-cll-cache[/bold] to use the cached lineage.") + console.print( + "Run [bold]recce server --enable-cll-cache[/bold] to use the cached lineage." 
+ ) @cli.command(cls=TrackCommand) @@ -981,22 +1135,32 @@ def check_artifacts(env_name, target_path): manifest_path = target_path / "manifest.json" manifest_is_ready = manifest_path.is_file() if manifest_is_ready: - console.print(f"[[green]OK[/green]] Manifest JSON file exists : {manifest_path}") + console.print( + f"[[green]OK[/green]] Manifest JSON file exists : {manifest_path}" + ) else: - console.print(f"[[red]MISS[/red]] Manifest JSON file not found: {manifest_path}") + console.print( + f"[[red]MISS[/red]] Manifest JSON file not found: {manifest_path}" + ) catalog_path = target_path / "catalog.json" catalog_is_ready = catalog_path.is_file() if catalog_is_ready: - console.print(f"[[green]OK[/green]] Catalog JSON file exists: {catalog_path}") + console.print( + f"[[green]OK[/green]] Catalog JSON file exists: {catalog_path}" + ) else: - console.print(f"[[red]MISS[/red]] Catalog JSON file not found: {catalog_path}") + console.print( + f"[[red]MISS[/red]] Catalog JSON file not found: {catalog_path}" + ) return [True, manifest_is_ready, catalog_is_ready] project_dir_path = Path(kwargs.get("project_dir") or "./") target_path = project_dir_path.joinpath(Path(kwargs.get("target_path", "target"))) - target_base_path = project_dir_path.joinpath(Path(kwargs.get("target_base_path", "target-base"))) + target_base_path = project_dir_path.joinpath( + Path(kwargs.get("target_base_path", "target-base")) + ) curr_is_ready = check_artifacts("Development", target_path) base_is_ready = check_artifacts("Base", target_base_path) @@ -1018,7 +1182,9 @@ def check_artifacts(env_name, target_path): if all(curr_is_ready) and all(base_is_ready) and conn_is_ready: console.print("[[green]OK[/green]] Ready to launch! Type 'recce server'.") elif all(curr_is_ready) and conn_is_ready: - console.print("[[orange3]OK[/orange3]] Ready to launch with [i]limited features[/i]. Type 'recce server'.") + console.print( + "[[orange3]OK[/orange3]] Ready to launch with [i]limited features[/i]. Type 'recce server'." + ) if not curr_is_ready[0]: console.print( @@ -1049,7 +1215,9 @@ def check_artifacts(env_name, target_path): ) if not conn_is_ready: - console.print("[[orange3]TIP[/orange3]] Run 'dbt debug' to check the connection.") + console.print( + "[[orange3]TIP[/orange3]] Run 'dbt debug' to check the connection." + ) @cli.command(hidden=True, cls=TrackCommand) @@ -1087,12 +1255,24 @@ def _split_comma_separated(ctx, param, value): help="Comma-separated list of primary key columns.", callback=_split_comma_separated, ) -@click.option("--keep-shape", is_flag=True, help="Keep unchanged columns. Otherwise, unchanged columns are hidden.") @click.option( - "--keep-equal", is_flag=True, help='Keep values that are equal. Otherwise, equal values are shown as "-".' + "--keep-shape", + is_flag=True, + help="Keep unchanged columns. Otherwise, unchanged columns are hidden.", +) +@click.option( + "--keep-equal", + is_flag=True, + help='Keep values that are equal. 
Otherwise, equal values are shown as "-".', ) @add_options(dbt_related_options) -def diff(sql, primary_keys: List[str] = None, keep_shape: bool = False, keep_equal: bool = False, **kwargs): +def diff( + sql, + primary_keys: List[str] = None, + keep_shape: bool = False, + keep_equal: bool = False, + **kwargs, +): """ Run queries on base and current environments and diff the results @@ -1114,16 +1294,29 @@ def diff(sql, primary_keys: List[str] = None, keep_shape: bool = False, keep_equ before_aligned, after_aligned = before.align(after) diff = before_aligned.compare( - after_aligned, result_names=("base", "current"), keep_equal=keep_equal, keep_shape=keep_shape + after_aligned, + result_names=("base", "current"), + keep_equal=keep_equal, + keep_shape=keep_shape, ) print(diff.to_string(na_rep="-") if not diff.empty else "no changes") @cli.command(cls=TrackCommand) @click.argument("state_file", required=False) -@click.option("--host", default="localhost", show_default=True, help="The host to bind to.") -@click.option("--port", default=8000, show_default=True, help="The port to bind to.", type=int) -@click.option("--lifetime", default=0, show_default=True, help="The lifetime of the server in seconds.", type=int) +@click.option( + "--host", default="localhost", show_default=True, help="The host to bind to." +) +@click.option( + "--port", default=8000, show_default=True, help="The port to bind to.", type=int +) +@click.option( + "--lifetime", + default=0, + show_default=True, + help="The lifetime of the server in seconds.", + type=int, +) @click.option( "--idle-timeout", default=0, @@ -1132,7 +1325,9 @@ def diff(sql, primary_keys: List[str] = None, keep_shape: bool = False, keep_equ type=int, ) @click.option("--review", is_flag=True, help="Open the state file in the review mode.") -@click.option("--single-env", is_flag=True, help="Launch in single environment mode directly.") +@click.option( + "--single-env", is_flag=True, help="Launch in single environment mode directly." +) @click.option( "--enable-cll-cache", is_flag=True, @@ -1238,7 +1433,9 @@ def server(host, port, lifetime, idle_timeout=0, state_file=None, **kwargs): # Check Single Environment Onboarding Mode if not in cloud mode and not in review mode if not is_cloud and not is_review: project_dir_path = Path(kwargs.get("project_dir") or "./") - target_base_path = project_dir_path.joinpath(Path(kwargs.get("target_base_path", "target-base"))) + target_base_path = project_dir_path.joinpath( + Path(kwargs.get("target_base_path", "target-base")) + ) if not target_base_path.is_dir(): # Mark as single env onboarding mode if user provides the target-path only flag["single_env_onboarding"] = True @@ -1357,7 +1554,9 @@ def server(host, port, lifetime, idle_timeout=0, state_file=None, **kwargs): ) @click.option("--state-file", help="Path of the import state file.", type=click.Path()) @click.option("--summary", help="Path of the summary markdown file.", type=click.Path()) -@click.option("--skip-query", is_flag=True, help="Skip running the queries for the checks.") +@click.option( + "--skip-query", is_flag=True, help="Skip running the queries for the checks." 
+) @click.option("--skip-check", is_flag=True, help="Skip running the checks.") @click.option( "--git-current-branch", @@ -1366,10 +1565,15 @@ def server(host, port, lifetime, idle_timeout=0, state_file=None, **kwargs): envvar="GITHUB_HEAD_REF", ) @click.option( - "--git-base-branch", help="The git branch of the base environment.", type=click.STRING, envvar="GITHUB_BASE_REF" + "--git-base-branch", + help="The git branch of the base environment.", + type=click.STRING, + envvar="GITHUB_BASE_REF", ) @click.option( - "--github-pull-request-url", help="The github pull request url to use for the lineage.", type=click.STRING + "--github-pull-request-url", + help="The github pull request url to use for the lineage.", + type=click.STRING, ) @add_options(dbt_related_options) @add_options(sqlmesh_related_options) @@ -1448,7 +1652,6 @@ def run(output, **kwargs): # Verify the output state file path try: if os.path.isdir(output) or output.endswith("/"): - output_dir = Path(output) # Create the directory if not exists output_dir.mkdir(parents=True, exist_ok=True) @@ -1507,7 +1710,10 @@ def summary(state_file, **kwargs): ) state_loader = create_state_loader( - review_mode=True, cloud_mode=cloud_mode, state_file=state_file, cloud_options=cloud_options + review_mode=True, + cloud_mode=cloud_mode, + state_file=state_file, + cloud_options=cloud_options, ) state_loader.load() @@ -1550,7 +1756,9 @@ def connect_to_cloud(): connect_url, callback_port = prepare_connection_url(public_key) console.rule("Connecting to Recce Cloud") - console.print("Attempting to automatically open the Recce Cloud authorization page in your default browser.") + console.print( + "Attempting to automatically open the Recce Cloud authorization page in your default browser." + ) console.print("If the browser does not open, please open the following URL:") console.print(connect_url) webbrowser.open(connect_url) @@ -1566,7 +1774,12 @@ def cloud(**kwargs): @cloud.command(cls=TrackCommand) -@click.option("--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN") +@click.option( + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", +) @click.option( "--state-file-host", help="The host to fetch the state file from.", @@ -1582,7 +1795,12 @@ def cloud(**kwargs): type=click.STRING, envvar="RECCE_STATE_PASSWORD", ) -@click.option("--force", "-f", help="Bypasses the confirmation prompt. Purge the state file directly.", is_flag=True) +@click.option( + "--force", + "-f", + help="Bypasses the confirmation prompt. Purge the state file directly.", + is_flag=True, +) @add_options(recce_options) def purge(**kwargs): """ @@ -1605,15 +1823,22 @@ def purge(**kwargs): try: console.rule("Check Recce State from Cloud") state_loader = create_state_loader( - review_mode=False, cloud_mode=True, state_file=None, cloud_options=cloud_options + review_mode=False, + cloud_mode=True, + state_file=None, + cloud_options=cloud_options, ) state_loader.load() except Exception: - console.print("[[yellow]Skip[/yellow]] Cannot access existing state file from cloud. Purge it directly.") + console.print( + "[[yellow]Skip[/yellow]] Cannot access existing state file from cloud. Purge it directly." + ) if state_loader is None: try: - if force_to_purge is True or click.confirm("\nDo you want to purge the state file?"): + if force_to_purge is True or click.confirm( + "\nDo you want to purge the state file?" 
+ ): rc, err_msg = RecceCloudStateManager(cloud_options).purge_cloud_state() if rc is True: console.rule("Purged Successfully") @@ -1632,13 +1857,19 @@ def purge(**kwargs): pr_info = info.get("pull_request") console.print("[green]State File hosted by[/green]", info.get("source")) - console.print("[green]GitHub Repository[/green]", info.get("pull_request").repository) + console.print( + "[green]GitHub Repository[/green]", info.get("pull_request").repository + ) console.print(f"[green]GitHub Pull Request[/green]\n{pr_info.title} #{pr_info.id}") - console.print(f"Branch merged into [blue]{pr_info.base_branch}[/blue] from [blue]{pr_info.branch}[/blue]") + console.print( + f"Branch merged into [blue]{pr_info.base_branch}[/blue] from [blue]{pr_info.branch}[/blue]" + ) console.print(pr_info.url) try: - if force_to_purge is True or click.confirm("\nDo you want to purge the state file?"): + if force_to_purge is True or click.confirm( + "\nDo you want to purge the state file?" + ): response = state_loader.purge() if response is True: console.rule("Purged Successfully") @@ -1653,7 +1884,12 @@ def purge(**kwargs): @cloud.command(cls=TrackCommand) @click.argument("state_file", type=click.Path(exists=True)) -@click.option("--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN") +@click.option( + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", +) @click.option( "--state-file-host", help="The host to fetch the state file from.", @@ -1689,7 +1925,10 @@ def upload(state_file, **kwargs): # load local state state_loader = create_state_loader( - review_mode=False, cloud_mode=False, state_file=state_file, cloud_options=cloud_options + review_mode=False, + cloud_mode=False, + state_file=state_file, + cloud_options=cloud_options, ) state_loader.load() @@ -1709,7 +1948,9 @@ def upload(state_file, **kwargs): cloud_state_file_exists = state_manager.check_cloud_state_exists() - if cloud_state_file_exists and not click.confirm("\nDo you want to overwrite the existing state file?"): + if cloud_state_file_exists and not click.confirm( + "\nDo you want to overwrite the existing state file?" 
+ ): return 0 console.print(state_manager.upload_state_to_cloud(state_loader.state)) @@ -1724,7 +1965,12 @@ def upload(state_file, **kwargs): default=DEFAULT_RECCE_STATE_FILE, show_default=True, ) -@click.option("--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN") +@click.option( + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", +) @click.option( "--state-file-host", help="The host to fetch the state file from.", @@ -1778,7 +2024,12 @@ def download(**kwargs): @cloud.command(cls=TrackCommand) -@click.option("--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN") +@click.option( + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", +) @click.option( "--branch", "-b", @@ -1825,7 +2076,11 @@ def upload_artifacts(**kwargs): try: rc = upload_dbt_artifacts( - target_path, branch=branch, token=cloud_token, password=password, debug=kwargs.get("debug", False) + target_path, + branch=branch, + token=cloud_token, + password=password, + debug=kwargs.get("debug", False), ) console.rule("Uploaded Successfully") console.print( @@ -1857,7 +2112,9 @@ def _download_artifacts(branch, cloud_token, console, kwargs, password, target_p ) except Exception as e: console.rule("Failed to Download", style="red") - console.print("[[red]Error[/red]] Failed to download the dbt artifacts from cloud.") + console.print( + "[[red]Error[/red]] Failed to download the dbt artifacts from cloud." + ) reason = str(e) if ( @@ -1870,7 +2127,9 @@ def _download_artifacts(branch, cloud_token, console, kwargs, password, target_p ) elif "The specified key does not exist" in reason: console.print("Reason: The dbt artifacts is not found in the cloud.") - console.print("Please upload the dbt artifacts to the cloud before downloading it.") + console.print( + "Please upload the dbt artifacts to the cloud before downloading it." + ) else: console.print(f"Reason: {reason}") rc = 1 @@ -1878,7 +2137,12 @@ def _download_artifacts(branch, cloud_token, console, kwargs, password, target_p @cloud.command(cls=TrackCommand) -@click.option("--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN") +@click.option( + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", +) @click.option( "--branch", "-b", @@ -1901,7 +2165,12 @@ def _download_artifacts(branch, cloud_token, console, kwargs, password, target_p envvar="RECCE_STATE_PASSWORD", required=True, ) -@click.option("--force", "-f", help="Bypasses the confirmation prompt. Download the artifacts directly.", is_flag=True) +@click.option( + "--force", + "-f", + help="Bypasses the confirmation prompt. 
Download the artifacts directly.", + is_flag=True, +) @add_options(recce_options) def download_artifacts(**kwargs): """ @@ -1922,11 +2191,18 @@ def download_artifacts(**kwargs): password = kwargs.get("password") target_path = kwargs.get("target_path") branch = kwargs.get("branch") or current_branch() - return _download_artifacts(branch, cloud_token, console, kwargs, password, target_path) + return _download_artifacts( + branch, cloud_token, console, kwargs, password, target_path + ) @cloud.command(cls=TrackCommand) -@click.option("--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN") +@click.option( + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", +) @click.option( "--branch", "-b", @@ -1949,7 +2225,12 @@ def download_artifacts(**kwargs): envvar="RECCE_STATE_PASSWORD", required=True, ) -@click.option("--force", "-f", help="Bypasses the confirmation prompt. Download the artifacts directly.", is_flag=True) +@click.option( + "--force", + "-f", + help="Bypasses the confirmation prompt. Download the artifacts directly.", + is_flag=True, +) @add_options(recce_options) def download_base_artifacts(**kwargs): """ @@ -1972,15 +2253,23 @@ def download_base_artifacts(**kwargs): # If recce can't infer default branch from "GITHUB_BASE_REF" and current_default_branch() if branch is None: console.print( - "[[red]Error[/red]] Please provide your base branch name with '--branch' to download the base " "artifacts." + "[[red]Error[/red]] Please provide your base branch name with '--branch' to download the base " + "artifacts." ) exit(1) - return _download_artifacts(branch, cloud_token, console, kwargs, password, target_path) + return _download_artifacts( + branch, cloud_token, console, kwargs, password, target_path + ) @cloud.command(cls=TrackCommand) -@click.option("--cloud-token", help="The GitHub token used by Recce Cloud.", type=click.STRING, envvar="GITHUB_TOKEN") +@click.option( + "--cloud-token", + help="The GitHub token used by Recce Cloud.", + type=click.STRING, + envvar="GITHUB_TOKEN", +) @click.option( "--branch", "-b", @@ -1988,7 +2277,12 @@ def download_base_artifacts(**kwargs): type=click.STRING, envvar="GITHUB_HEAD_REF", ) -@click.option("--force", "-f", help="Bypasses the confirmation prompt. Delete the artifacts directly.", is_flag=True) +@click.option( + "--force", + "-f", + help="Bypasses the confirmation prompt. Delete the artifacts directly.", + is_flag=True, +) @add_options(recce_options) def delete_artifacts(**kwargs): """ @@ -2011,28 +2305,43 @@ def delete_artifacts(**kwargs): force = kwargs.get("force", False) if not force: - if not click.confirm(f'Do you want to delete artifacts from branch "{branch}"?'): + if not click.confirm( + f'Do you want to delete artifacts from branch "{branch}"?' + ): console.print("Deletion cancelled.") return 0 try: - delete_dbt_artifacts(branch=branch, token=cloud_token, debug=kwargs.get("debug", False)) - console.print(f"[[green]Success[/green]] Artifacts deleted from branch: {branch}") + delete_dbt_artifacts( + branch=branch, token=cloud_token, debug=kwargs.get("debug", False) + ) + console.print( + f"[[green]Success[/green]] Artifacts deleted from branch: {branch}" + ) return 0 except click.exceptions.Abort: pass except RecceCloudException as e: - console.print("[[red]Error[/red]] Failed to delete the dbt artifacts from cloud.") + console.print( + "[[red]Error[/red]] Failed to delete the dbt artifacts from cloud." 
+ ) console.print(f"Reason: {e.reason}") exit(1) except Exception as e: - console.print("[[red]Error[/red]] Failed to delete the dbt artifacts from cloud.") + console.print( + "[[red]Error[/red]] Failed to delete the dbt artifacts from cloud." + ) console.print(f"Reason: {e}") exit(1) @cloud.command(cls=TrackCommand, name="list-organizations") -@click.option("--api-token", help="The Recce Cloud API token.", type=click.STRING, envvar="RECCE_API_TOKEN") +@click.option( + "--api-token", + help="The Recce Cloud API token.", + type=click.STRING, + envvar="RECCE_API_TOKEN", +) @add_options(recce_options) def list_organizations(**kwargs): """ @@ -2072,7 +2381,9 @@ def list_organizations(**kwargs): table.add_column("Display Name", style="yellow") for org in organizations: - table.add_row(str(org.get("id", "")), org.get("name", ""), org.get("display_name", "")) + table.add_row( + str(org.get("id", "")), org.get("name", ""), org.get("display_name", "") + ) console.print(table) @@ -2092,7 +2403,12 @@ def list_organizations(**kwargs): type=click.STRING, envvar="RECCE_ORGANIZATION_ID", ) -@click.option("--api-token", help="The Recce Cloud API token.", type=click.STRING, envvar="RECCE_API_TOKEN") +@click.option( + "--api-token", + help="The Recce Cloud API token.", + type=click.STRING, + envvar="RECCE_API_TOKEN", +) @add_options(recce_options) def list_projects(**kwargs): """ @@ -2131,8 +2447,12 @@ def list_projects(**kwargs): organization = kwargs.get("organization") if not organization: - console.print("[[red]Error[/red]] Organization ID is required. Please provide it via:") - console.print(" --organization or set RECCE_ORGANIZATION_ID environment variable") + console.print( + "[[red]Error[/red]] Organization ID is required. Please provide it via:" + ) + console.print( + " --organization or set RECCE_ORGANIZATION_ID environment variable" + ) exit(1) try: @@ -2151,7 +2471,11 @@ def list_projects(**kwargs): table.add_column("Display Name", style="yellow") for project in projects: - table.add_row(str(project.get("id", "")), project.get("name", ""), project.get("display_name", "")) + table.add_row( + str(project.get("id", "")), + project.get("name", ""), + project.get("display_name", ""), + ) console.print(table) @@ -2178,7 +2502,12 @@ def list_projects(**kwargs): type=click.STRING, envvar="RECCE_PROJECT_ID", ) -@click.option("--api-token", help="The Recce Cloud API token.", type=click.STRING, envvar="RECCE_API_TOKEN") +@click.option( + "--api-token", + help="The Recce Cloud API token.", + type=click.STRING, + envvar="RECCE_API_TOKEN", +) @add_options(recce_options) def list_sessions(**kwargs): """ @@ -2226,12 +2555,18 @@ def list_sessions(**kwargs): # Validate required parameters if not organization: - console.print("[[red]Error[/red]] Organization ID is required. Please provide it via:") - console.print(" --organization or set RECCE_ORGANIZATION_ID environment variable") + console.print( + "[[red]Error[/red]] Organization ID is required. Please provide it via:" + ) + console.print( + " --organization or set RECCE_ORGANIZATION_ID environment variable" + ) exit(1) if not project: - console.print("[[red]Error[/red]] Project ID is required. Please provide it via:") + console.print( + "[[red]Error[/red]] Project ID is required. 
Please provide it via:" + ) console.print(" --project or set RECCE_PROJECT_ID environment variable") exit(1) @@ -2270,7 +2605,8 @@ def github(**kwargs): @github.command( - cls=TrackCommand, short_help="Download the artifacts from the GitHub repository based on the current Pull Request." + cls=TrackCommand, + short_help="Download the artifacts from the GitHub repository based on the current Pull Request.", ) @click.option( "--github-token", @@ -2328,7 +2664,10 @@ def share(state_file, **kwargs): # load local state state_loader = create_state_loader( - review_mode=True, cloud_mode=False, state_file=state_file, cloud_options=cloud_options + review_mode=True, + cloud_mode=False, + state_file=state_file, + cloud_options=cloud_options, ) state_loader.load() @@ -2352,7 +2691,10 @@ def share(state_file, **kwargs): try: response = state_manager.share_state(state_file_name, state_loader.state) if response.get("status") == "error": - console.print("[[red]Error[/red]] Failed to share the state.\n" f"Reason: {response.get('message')}") + console.print( + "[[red]Error[/red]] Failed to share the state.\n" + f"Reason: {response.get('message')}" + ) else: console.print(f"Shared Link: {response.get('share_url')}") except RecceCloudException as e: @@ -2431,7 +2773,10 @@ def upload_session(**kwargs): try: rc = upload_artifacts_to_session( - target_path, session_id=session_id, token=api_token, debug=kwargs.get("debug", False) + target_path, + session_id=session_id, + token=api_token, + debug=kwargs.get("debug", False), ) console.rule("Uploaded Successfully") console.print( @@ -2439,7 +2784,9 @@ def upload_session(**kwargs): ) except Exception as e: console.rule("Failed to Upload Session", style="red") - console.print(f"[[red]Error[/red]] Failed to upload the dbt artifacts to the session {session_id}.") + console.print( + f"[[red]Error[/red]] Failed to upload the dbt artifacts to the session {session_id}." + ) console.print(f"Reason: {e}") rc = 1 return rc @@ -2462,10 +2809,25 @@ def snapshot(**kwargs): @cli.command(hidden=True, cls=TrackCommand) @click.argument("state_file", required=True) -@click.option("--host", default="localhost", show_default=True, help="The host to bind to.") -@click.option("--port", default=8000, show_default=True, help="The port to bind to.", type=int) -@click.option("--lifetime", default=0, show_default=True, help="The lifetime of the server in seconds.", type=int) -@click.option("--share-url", help="The share URL triggers this instance.", type=click.STRING, envvar="RECCE_SHARE_URL") +@click.option( + "--host", default="localhost", show_default=True, help="The host to bind to." 
+) +@click.option( + "--port", default=8000, show_default=True, help="The port to bind to.", type=int +) +@click.option( + "--lifetime", + default=0, + show_default=True, + help="The lifetime of the server in seconds.", + type=int, +) +@click.option( + "--share-url", + help="The share URL triggers this instance.", + type=click.STRING, + envvar="RECCE_SHARE_URL", +) @click.pass_context def read_only(ctx, state_file=None, **kwargs): from recce.server import RecceServerMode @@ -2477,10 +2839,26 @@ def read_only(ctx, state_file=None, **kwargs): @cli.command(cls=TrackCommand) @click.argument("state_file", required=False) -@click.option("--sse", is_flag=True, default=False, help="Start in HTTP/SSE mode instead of stdio mode") -@click.option("--host", default="localhost", help="Host to bind to in SSE mode (default: localhost)") -@click.option("--port", default=8000, type=int, help="Port to bind to in SSE mode (default: 8000)") -@click.option("--session", "cloud_session", type=click.STRING, help="Recce Cloud session ID for cloud MCP mode") +@click.option( + "--sse", + is_flag=True, + default=False, + help="Start in HTTP/SSE mode instead of stdio mode", +) +@click.option( + "--host", + default="localhost", + help="Host to bind to in SSE mode (default: localhost)", +) +@click.option( + "--port", default=8000, type=int, help="Port to bind to in SSE mode (default: 8000)" +) +@click.option( + "--session", + "cloud_session", + type=click.STRING, + help="Recce Cloud session ID for cloud MCP mode", +) @add_options(dbt_related_options) @add_options(sqlmesh_related_options) @add_options(recce_options) @@ -2564,10 +2942,14 @@ def mcp_server(state_file, sse, host, port, **kwargs): cloud_session = kwargs.pop("cloud_session", None) if is_cloud_mcp and not cloud_session: - console.print("[[red]Error[/red]] --session is required when using --cloud with recce mcp-server.") + console.print( + "[[red]Error[/red]] --session is required when using --cloud with recce mcp-server." + ) exit(1) if cloud_session and not is_cloud_mcp: - console.print("[[red]Error[/red]] --cloud is required when using --session with recce mcp-server.") + console.print( + "[[red]Error[/red]] --cloud is required when using --session with recce mcp-server." + ) exit(1) # Prepare API token @@ -2593,8 +2975,12 @@ def mcp_server(state_file, sse, host, port, **kwargs): # the set_backend MCP tool. if not is_cloud_mcp: project_dir_path = Path(kwargs.get("project_dir") or "./") - target_path = project_dir_path.joinpath(Path(kwargs.get("target_path", "target"))) - target_base_path = project_dir_path.joinpath(Path(kwargs.get("target_base_path", "target-base"))) + target_path = project_dir_path.joinpath( + Path(kwargs.get("target_path", "target")) + ) + target_base_path = project_dir_path.joinpath( + Path(kwargs.get("target_base_path", "target-base")) + ) if target_path.is_dir() and not target_base_path.is_dir(): kwargs["single_env"] = True kwargs["target_base_path"] = kwargs.get("target_path") @@ -2602,11 +2988,15 @@ def mcp_server(state_file, sse, host, port, **kwargs): "[yellow]Base artifacts not found. " "Starting in single-environment mode (diffs will show no changes).[/yellow]" ) - console.print("To enable diffing: dbt docs generate --target-path target-base") + console.print( + "To enable diffing: dbt docs generate --target-path target-base" + ) try: if sse: - console.print(f"Starting Recce MCP Server in HTTP/SSE mode on {host}:{port}...") + console.print( + f"Starting Recce MCP Server in HTTP/SSE mode on {host}:{port}..." 
+            )
             console.print(f"SSE endpoint: http://{host}:{port}/sse")
         elif is_cloud_mcp:
             console.print("Starting Recce MCP Server in cloud stdio mode...")
@@ -2617,7 +3007,11 @@
             )
 
         # Run the server (stdio or SSE based on --sse flag)
-        asyncio.run(run_mcp_server(sse=sse, host=host, port=port, session=cloud_session, **kwargs))
+        asyncio.run(
+            run_mcp_server(
+                sse=sse, host=host, port=port, session=cloud_session, **kwargs
+            )
+        )
     except (asyncio.CancelledError, KeyboardInterrupt):
         # Graceful shutdown (e.g., Ctrl+C)
         console.print("[yellow]MCP Server interrupted[/yellow]")
@@ -2714,5 +3108,167 @@ def clear_cache(cache_db):
         pass
 
 
+def check_base_freshness(
+    target_base_path: str = "target-base",
+    target_path: str = "target",
+    freshness_threshold_hours: float = 48.0,
+) -> dict:
+    """
+    Check whether the base artifacts in target_base_path are fresh.
+
+    Returns a dict with keys:
+        status: FRESH | STALE_TIME | STALE_SHA | MISSING
+        recommendation: reuse | docs_generate | full_build
+        message: human-readable explanation
+        artifact_age_hours: float or None
+        base_sha: str or None (DBT_GIT_SHA from manifest metadata)
+        current_sha: str or None (current HEAD SHA)
+        threshold_hours: float
+    """
+    import json
+    import time
+
+    manifest_path = Path(target_base_path) / "manifest.json"
+    result: dict = {
+        "status": None,
+        "recommendation": None,
+        "message": None,
+        "artifact_age_hours": None,
+        "base_sha": None,
+        "current_sha": None,
+        "threshold_hours": freshness_threshold_hours,
+    }
+
+    if not manifest_path.exists():
+        result["status"] = "MISSING"
+        result["recommendation"] = "full_build"
+        result["message"] = (
+            f"Base artifacts not found at '{target_base_path}/manifest.json'. "
+            "Run: git stash; git checkout <base-branch>; dbt build --target-path target-base; "
+            "git checkout <working-branch>; git stash pop"
+        )
+        return result
+
+    # Compute artifact age from mtime
+    mtime = manifest_path.stat().st_mtime
+    now = time.time()
+    artifact_age_hours = (now - mtime) / 3600.0
+    result["artifact_age_hours"] = artifact_age_hours
+
+    # Time-based freshness check
+    if artifact_age_hours > freshness_threshold_hours:
+        result["status"] = "STALE_TIME"
+        result["recommendation"] = "docs_generate"
+        result["message"] = (
+            f"Base artifacts are stale ({artifact_age_hours:.1f} hours old, "
+            f"threshold: {freshness_threshold_hours}h). "
+            "Run: dbt docs generate --target-path target-base"
+        )
+        return result
+
+    # SHA-based freshness check (best-effort: skip if field absent or git unavailable)
+    try:
+        with open(manifest_path) as f:
+            manifest_data = json.load(f)
+        base_sha = manifest_data.get("metadata", {}).get("env", {}).get("DBT_GIT_SHA")
+        result["base_sha"] = base_sha
+
+        if base_sha is not None:
+            from recce.git import current_commit_hash
+
+            current_sha = current_commit_hash()
+            result["current_sha"] = current_sha
+            if current_sha and base_sha != current_sha:
+                result["status"] = "STALE_SHA"
+                result["recommendation"] = "docs_generate"
+                result["message"] = (
+                    f"Base artifacts are stale (generated at {base_sha[:7]}, "
+                    f"current HEAD: {current_sha[:7]}). "
+                    "Run: dbt docs generate --target-path target-base"
+                )
+                return result
+    except Exception:
+        # Best-effort: if manifest is unreadable or git is unavailable, skip SHA check
+        pass
+
+    result["status"] = "FRESH"
+    result["recommendation"] = "reuse"
+    result["message"] = (
+        f"Base artifacts are fresh ({artifact_age_hours:.1f} hours old). "
+        "Reuse existing artifacts."
+ ) + return result + + +@cli.command(name="check-base", cls=TrackCommand) +@click.option( + "--target-base-path", + default="target-base", + show_default=True, + help="Path to the base artifacts directory.", + type=click.Path(), +) +@click.option( + "--target-path", + default="target", + show_default=True, + help="Path to the current target artifacts directory.", + type=click.Path(), +) +@click.option( + "--format", + "output_format", + default="json", + show_default=True, + type=click.Choice(["json", "text"]), + help="Output format: json (default) or text for human-readable output.", +) +@click.option( + "--freshness-threshold-hours", + default=48.0, + show_default=True, + type=float, + help="Age threshold in hours after which artifacts are considered STALE_TIME.", +) +def check_base(target_base_path, target_path, output_format, freshness_threshold_hours): + """Check freshness of base artifacts for diff operations. + + Exits 0 when status is FRESH; exits 1 when STALE_TIME, STALE_SHA, or MISSING. + """ + import json + + result = check_base_freshness( + target_base_path=target_base_path, + target_path=target_path, + freshness_threshold_hours=freshness_threshold_hours, + ) + + if output_format == "json": + click.echo(json.dumps(result, indent=2)) + else: + from rich.console import Console + + console = Console() + status = result["status"] + age = result.get("artifact_age_hours") + msg = result.get("message", "") + color_map = { + "FRESH": "green", + "STALE_TIME": "yellow", + "STALE_SHA": "yellow", + "MISSING": "red", + } + color = color_map.get(status, "white") + console.print(f"[{color}]Status: {status}[/{color}]") + if age is not None: + console.print(f"Age: {age:.1f}h (threshold: {result['threshold_hours']}h)") + console.print(f"Recommendation: {result['recommendation']}") + if msg: + console.print(msg) + + if result["status"] != "FRESH": + raise SystemExit(1) + + if __name__ == "__main__": cli() diff --git a/recce/mcp_server.py b/recce/mcp_server.py index 3034db31b..42a73a041 100644 --- a/recce/mcp_server.py +++ b/recce/mcp_server.py @@ -71,7 +71,9 @@ class CloudBackend: "histogram_diff": "histogram_diff", } - def __init__(self, session_id: str, api_token: str, cloud_host: str = RECCE_CLOUD_API_HOST): + def __init__( + self, session_id: str, api_token: str, cloud_host: str = RECCE_CLOUD_API_HOST + ): self.session_id = session_id self.api_token = api_token self.cloud_host = cloud_host.rstrip("/") @@ -82,7 +84,9 @@ async def create(cls, session_id: str, api_token: str): backend = cls(session_id=session_id, api_token=api_token) spawn_response = await backend._request("POST", "instance", json={}) if isinstance(spawn_response, dict): - backend.instance_status = spawn_response.get("status") or spawn_response.get("instance_status") + backend.instance_status = spawn_response.get( + "status" + ) or spawn_response.get("instance_status") return backend def _url(self, api_name: str) -> str: @@ -94,7 +98,9 @@ async def _request(self, method: str, api_name: str, **kwargs): **kwargs.pop("headers", {}), "Authorization": f"Bearer {self.api_token}", } - response = await asyncio.to_thread(requests.request, method, url, headers=headers, **kwargs) + response = await asyncio.to_thread( + requests.request, method, url, headers=headers, **kwargs + ) if response.status_code == 405: raise InstanceSpawningError() if response.status_code < 200 or response.status_code >= 300: @@ -172,7 +178,9 @@ async def _tool_query(self, arguments: Dict[str, Any]) -> Dict[str, Any]: params = {k: v for k, v in arguments.items() 
if k != "base"} return await self._tool_run_backed(run_type, params) - async def _tool_run_backed(self, run_type: str, params: Dict[str, Any]) -> Dict[str, Any]: + async def _tool_run_backed( + self, run_type: str, params: Dict[str, Any] + ) -> Dict[str, Any]: run = await self._request( "POST", "runs", @@ -205,7 +213,11 @@ async def _tool_run_check(self, arguments: Dict[str, Any]) -> Dict[str, Any]: check_id = arguments.get("check_id") if not check_id: raise ValueError("check_id is required") - run = await self._request("POST", f"checks/{quote(str(check_id), safe='')}/run", json={"nowait": False}) + run = await self._request( + "POST", + f"checks/{quote(str(check_id), safe='')}/run", + json={"nowait": False}, + ) if self._run_succeeded(run): await self._auto_approve(check_id) return run @@ -230,12 +242,20 @@ async def _tool_create_check(self, arguments: Dict[str, Any]) -> Dict[str, Any]: # executable run). lineage_diff/schema_diff are recorded server-side via # POST /checks/{id}/run, mirroring local _create_metadata_run. if check_id and check_type != "simple": - run = await self._request("POST", f"checks/{quote(str(check_id), safe='')}/run", json={"nowait": False}) + run = await self._request( + "POST", + f"checks/{quote(str(check_id), safe='')}/run", + json={"nowait": False}, + ) run_executed = True run_error = run.get("error") if self._run_succeeded(run): await self._auto_approve(check_id) - result = {"check_id": str(check_id), "created": True, "run_executed": run_executed} + result = { + "check_id": str(check_id), + "created": True, + "run_executed": run_executed, + } if run_error: result["run_error"] = run_error return result @@ -248,7 +268,11 @@ async def _auto_approve(self, check_id) -> None: post-success side-effect, not part of the run contract. 
""" try: - await self._request("PATCH", f"checks/{quote(str(check_id), safe='')}", json={"is_checked": True}) + await self._request( + "PATCH", + f"checks/{quote(str(check_id), safe='')}", + json={"is_checked": True}, + ) except (RecceCloudException, InstanceSpawningError) as e: logger.warning(f"[MCP] Auto-approve failed for check {check_id}: {e}") @@ -257,9 +281,17 @@ async def _tool_lineage_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: lineage = info.get("lineage", {}) nodes = lineage.get("nodes", {}) selected = await self._selected_nodes(arguments, nodes) - impacted = set((await self._request("POST", "select", json={"select": "state:modified+"})).get("nodes", [])) + impacted = set( + ( + await self._request( + "POST", "select", json={"select": "state:modified+"} + ) + ).get("nodes", []) + ) - selected_nodes = {node_id: node for node_id, node in nodes.items() if node_id in selected} + selected_nodes = { + node_id: node for node_id, node in nodes.items() if node_id in selected + } id_to_idx = {node_id: idx for idx, node_id in enumerate(selected_nodes.keys())} nodes_df = DataFrame.from_data( columns={ @@ -291,8 +323,13 @@ async def _tool_lineage_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: target = edge.get("target") if source in id_to_idx and target in id_to_idx: edge_rows.append((id_to_idx[source], id_to_idx[target])) - edges_df = DataFrame.from_data(columns={"from": "integer", "to": "integer"}, data=edge_rows) - return {"nodes": nodes_df.model_dump(mode="json"), "edges": edges_df.model_dump(mode="json")} + edges_df = DataFrame.from_data( + columns={"from": "integer", "to": "integer"}, data=edge_rows + ) + return { + "nodes": nodes_df.model_dump(mode="json"), + "edges": edges_df.model_dump(mode="json"), + } async def _tool_schema_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: info = await self._request("GET", "info") @@ -316,10 +353,19 @@ async def _tool_schema_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, Any]: info = await self._request("GET", "info") nodes = info.get("lineage", {}).get("nodes", {}) - select = arguments.get("select", "state:modified.body+ state:modified.macros+ state:modified.contract+") - impacted_node_ids = set((await self._request("POST", "select", json={"select": select})).get("nodes", [])) + select = arguments.get( + "select", + "state:modified.body+ state:modified.macros+ state:modified.contract+", + ) + impacted_node_ids = set( + (await self._request("POST", "select", json={"select": select})).get( + "nodes", [] + ) + ) modified_node_ids = set( - (await self._request("POST", "select", json={"select": "state:modified"})).get("nodes", []) + ( + await self._request("POST", "select", json={"select": "state:modified"}) + ).get("nodes", []) ) impacted_models = [] @@ -329,12 +375,16 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An continue entry = { "name": node.get("name"), - "change_status": node.get("change_status") if node_id in modified_node_ids else None, + "change_status": node.get("change_status") + if node_id in modified_node_ids + else None, "materialized": node.get("materialized"), "row_count": None, "schema_changes": [ {"column": column, "change_status": status} - for column, status in ((node.get("change") or {}).get("columns") or {}).items() + for column, status in ( + (node.get("change") or {}).get("columns") or {} + ).items() ], "value_diff": None, "affected_row_count": None, @@ -372,7 +422,9 @@ 
async def _selected_nodes(self, arguments: Dict[str, Any], nodes: Dict[str, Any] for key in ("select", "exclude", "packages", "view_mode") if arguments.get(key) is not None } - return set((await self._request("POST", "select", json=payload)).get("nodes", [])) + return set( + (await self._request("POST", "select", json=payload)).get("nodes", []) + ) return set(nodes.keys()) @staticmethod @@ -392,7 +444,9 @@ def _redact_sensitive_args(arguments: Dict[str, Any]) -> Dict[str, Any]: """ if not isinstance(arguments, dict): return arguments - return {k: ("***" if k in SENSITIVE_ARG_KEYS and v else v) for k, v in arguments.items()} + return { + k: ("***" if k in SENSITIVE_ARG_KEYS and v else v) for k, v in arguments.items() + } def _truncate_strings(obj: Any, max_length: int = 200) -> Any: @@ -542,7 +596,9 @@ def _setup_handlers(self): @self.server.list_tools() async def list_tools() -> List[Tool]: """List all available tools based on server mode""" - logger.info(f"[MCP] list_tools called (mode: {self.mode.value if self.mode else 'server'})") + logger.info( + f"[MCP] list_tools called (mode: {self.mode.value if self.mode else 'server'})" + ) tools = [] # Always available in all modes @@ -1227,10 +1283,19 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: } # Unconfigured-mode gate: when neither a local context nor a cloud # backend is set, only set_backend and get_server_info are usable. - if self.context is None and self.backend is None and name not in {"set_backend", "get_server_info"}: - raise ValueError("No backend configured. Call set_backend(mode='local'|'cloud', ...) first.") + if ( + self.context is None + and self.backend is None + and name not in {"set_backend", "get_server_info"} + ): + raise ValueError( + "No backend configured. Call set_backend(mode='local'|'cloud', ...) first." + ) - if self.mode != RecceServerMode.server and name in blocked_tools_in_non_server: + if ( + self.mode != RecceServerMode.server + and name in blocked_tools_in_non_server + ): # Allowed tools = all registered minus blocked allowed_tools = sorted( { @@ -1249,7 +1314,11 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: if name == "set_backend": result = await self._tool_set_backend(arguments) - elif name == "get_server_info" and self.context is None and self.backend is None: + elif ( + name == "get_server_info" + and self.context is None + and self.backend is None + ): result = { "mode": "none", "configured": False, @@ -1304,7 +1373,9 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: logger.info(f"[MCP] Tool response for {name} ({duration_ms:.2f}ms):") # Truncate large responses for console readability if len(response_json) > 1000: - logger.debug(f"[MCP] {response_json[:1000]}... (truncated, {len(response_json)} chars total)") + logger.debug( + f"[MCP] {response_json[:1000]}... 
(truncated, {len(response_json)} chars total)" + ) else: logger.debug(f"[MCP] {response_json}") @@ -1312,9 +1383,13 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: except Exception as e: duration_ms = (time.perf_counter() - start_time) * 1000 error_msg = str(e) - self.mcp_logger.log_tool_call(name, log_arguments, {}, duration_ms, error=error_msg) + self.mcp_logger.log_tool_call( + name, log_arguments, {}, duration_ms, error=error_msg + ) - is_expected_cloud_error = isinstance(e, (RecceCloudException, InstanceSpawningError)) + is_expected_cloud_error = isinstance( + e, (RecceCloudException, InstanceSpawningError) + ) classification = self._classify_db_error(error_msg) if classification: logger.warning( @@ -1327,9 +1402,13 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: attributes={"tool": name, "error_type": classification}, ) elif is_expected_cloud_error: - logger.warning(f"[MCP] Expected cloud error in tool {name} ({duration_ms:.2f}ms): {error_msg}") + logger.warning( + f"[MCP] Expected cloud error in tool {name} ({duration_ms:.2f}ms): {error_msg}" + ) else: - logger.error(f"[MCP] Error executing tool {name} ({duration_ms:.2f}ms): {error_msg}") + logger.error( + f"[MCP] Error executing tool {name} ({duration_ms:.2f}ms): {error_msg}" + ) logger.exception("[MCP] Full traceback:") # Re-raise so MCP SDK sets isError=True in the protocol response @@ -1579,7 +1658,9 @@ async def _tool_value_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: result = result.model_dump(mode="json") return self._maybe_add_single_env_warning(result) - async def _tool_value_diff_detail(self, arguments: Dict[str, Any]) -> Dict[str, Any]: + async def _tool_value_diff_detail( + self, arguments: Dict[str, Any] + ) -> Dict[str, Any]: """Execute value diff detail task""" task = ValueDiffDetailTask(params=arguments) result = await asyncio.get_event_loop().run_in_executor(None, task.execute) @@ -1617,7 +1698,9 @@ async def _tool_histogram_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any if not col_info: col_info = columns.get(column_name.lower()) if not col_info or not col_info.get("type"): - raise ValueError(f"Cannot determine column type for '{column_name}' in model '{model}'") + raise ValueError( + f"Cannot determine column type for '{column_name}' in model '{model}'" + ) params = {**arguments, "column_type": col_info["type"]} task = HistogramDiffTask(params=params) @@ -1675,7 +1758,10 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An model_entry = { "name": name, "change_status": ( - change_status if node_id in modified_node_ids or change_status in ("added", "removed") else None + change_status + if node_id in modified_node_ids + or change_status in ("added", "removed") + else None ), "materialized": materialized, "row_count": None, @@ -1687,12 +1773,16 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An not_impacted_models.append(name) # Step 2a: Row count diff (skip removed models; include views for delta detection) - countable_models = [m for m in impacted_models if m["change_status"] != "removed"] + countable_models = [ + m for m in impacted_models if m["change_status"] != "removed" + ] if countable_models: countable_names = [m["name"] for m in countable_models] try: task = RowCountDiffTask(params={"node_names": countable_names}) - row_count_result = await asyncio.get_event_loop().run_in_executor(None, task.execute) + row_count_result = await 
asyncio.get_event_loop().run_in_executor( + None, task.execute + ) for model in countable_models: name = model["name"] @@ -1707,7 +1797,9 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An "base": base, "current": curr, "delta": delta, - "delta_pct": round(delta_pct, 1) if delta_pct is not None else None, + "delta_pct": round(delta_pct, 1) + if delta_pct is not None + else None, } elif curr is not None: # Added model (no base) @@ -1738,7 +1830,9 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An continue base_cols = set(base_nodes.get(node_id, {}).get("columns", {}).keys()) - curr_cols = set(current_nodes.get(node_id, {}).get("columns", {}).keys()) + curr_cols = set( + current_nodes.get(node_id, {}).get("columns", {}).keys() + ) changes = [] for col in curr_cols - base_cols: @@ -1778,8 +1872,12 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An continue # only PK column, no value diff to compute # Build relations for base and current schemas - base_rel = self.context.adapter.create_relation(model["name"], base=True) - curr_rel = self.context.adapter.create_relation(model["name"], base=False) + base_rel = self.context.adapter.create_relation( + model["name"], base=True + ) + curr_rel = self.context.adapter.create_relation( + model["name"], base=False + ) if not base_rel or not curr_rel: continue @@ -1815,10 +1913,20 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An f'COUNT(CASE WHEN b."{pk}" IS NOT NULL AND c."{pk}" IS NOT NULL ' f'AND b."{col}" IS DISTINCT FROM c."{col}" THEN 1 END) AS "{col}__changed"' ) - col_type = columns_info[col].get("type", "").upper().split("(")[0].strip() + col_type = ( + columns_info[col] + .get("type", "") + .upper() + .split("(")[0] + .strip() + ) if col_type in numeric_types: - per_col_parts.append(f'AVG(b."{col}") AS "{col}__base_mean"') - per_col_parts.append(f'AVG(c."{col}") AS "{col}__curr_mean"') + per_col_parts.append( + f'AVG(b."{col}") AS "{col}__base_mean"' + ) + per_col_parts.append( + f'AVG(c."{col}") AS "{col}__curr_mean"' + ) per_col_sql = ",\n ".join(per_col_parts) @@ -1854,14 +1962,24 @@ def _run_value_diff_query(adapter, query): for col in non_pk_cols: col_changed = int(row[col_idx]) col_idx += 1 - col_type = columns_info[col].get("type", "").upper().split("(")[0].strip() + col_type = ( + columns_info[col] + .get("type", "") + .upper() + .split("(")[0] + .strip() + ) base_mean = None current_mean = None if col_type in numeric_types: raw_base = row[col_idx] raw_curr = row[col_idx + 1] - base_mean = float(raw_base) if raw_base is not None else None - current_mean = float(raw_curr) if raw_curr is not None else None + base_mean = ( + float(raw_base) if raw_base is not None else None + ) + current_mean = ( + float(raw_curr) if raw_curr is not None else None + ) col_idx += 2 columns_result[col] = { "affected_row_count": col_changed, @@ -1892,7 +2010,10 @@ def _run_value_diff_query(adapter, query): # affected_row_count: value_diff total (priority) or abs(row_count.delta) (fallback) if model["value_diff"] is not None: model["affected_row_count"] = model["value_diff"]["affected_row_count"] - elif model["row_count"] is not None and model["row_count"].get("delta") is not None: + elif ( + model["row_count"] is not None + and model["row_count"].get("delta") is not None + ): model["affected_row_count"] = abs(model["row_count"]["delta"]) else: model["affected_row_count"] = None @@ -1911,7 +2032,10 @@ def 
_run_value_diff_query(adapter, query): if model["data_impact"] == "potential": model["affected_row_count"] = None - if model["affected_row_count"] is not None and model["affected_row_count"] > max_affected: + if ( + model["affected_row_count"] is not None + and model["affected_row_count"] > max_affected + ): max_affected = model["affected_row_count"] # next_action: only for "potential" models — confirmed/none need no follow-up @@ -1961,7 +2085,9 @@ def _run_value_diff_query(adapter, query): and model["row_count"]["delta_pct"] is not None and abs(model["row_count"]["delta_pct"]) <= 5 ): - total_matched = (model["row_count"]["current"] or 0) - vd["rows_added"] + total_matched = (model["row_count"]["current"] or 0) - vd[ + "rows_added" + ] if total_matched > 0 and vd["rows_changed"] / total_matched > 0.2: top_cols = [ col @@ -1977,8 +2103,12 @@ def _run_value_diff_query(adapter, query): if sentry_metrics: duration = time.time() - start_time - sentry_metrics.distribution("mcp.impact_analysis.duration", duration, unit="second") - sentry_metrics.distribution("mcp.impact_analysis.impacted_count", len(impacted_models)) + sentry_metrics.distribution( + "mcp.impact_analysis.duration", duration, unit="second" + ) + sentry_metrics.distribution( + "mcp.impact_analysis.impacted_count", len(impacted_models) + ) result = { "_guidance": ( @@ -2039,6 +2169,13 @@ async def _tool_get_server_info(self, arguments: Dict[str, Any]) -> Dict[str, An "single_env": self.single_env, } + # Include base_status so agents can programmatically detect stale state. + # Values: "FRESH" | "STALE_TIME" | "STALE_SHA" | "MISSING" | "single_env" | "unknown" + if self.single_env: + result["base_status"] = "single_env" + else: + result["base_status"] = getattr(self, "_base_status", "unknown") + # Add git and pull_request info if state_loader is available if context.state_loader: try: @@ -2079,19 +2216,27 @@ async def _tool_set_backend(self, arguments: Dict[str, Any]) -> Dict[str, Any]: api_token = get_recce_api_token() if not api_token: - raise ValueError("Recce Cloud API token not found. Run `recce connect-to-cloud` first.") + raise ValueError( + "Recce Cloud API token not found. Run `recce connect-to-cloud` first." + ) # Best-effort export of local state before swapping away. 
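            # A failed export only logs a warning; the swap to cloud proceeds regardless.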
if self.context is not None and self.state_loader is not None: try: self.state_loader.export(self.context.export_state()) except Exception as e: - logger.warning(f"[MCP] Failed to export local state on swap to cloud: {e}") + logger.warning( + f"[MCP] Failed to export local state on swap to cloud: {e}" + ) - new_backend = await CloudBackend.create(session_id=session_id, api_token=api_token) + new_backend = await CloudBackend.create( + session_id=session_id, api_token=api_token + ) self.backend = new_backend self.api_token = api_token - logger.info(f"[MCP] Backend switched to cloud (session_id={session_id})") + logger.info( + f"[MCP] Backend switched to cloud (session_id={session_id})" + ) return { "mode": "cloud", "session_id": session_id, @@ -2122,13 +2267,18 @@ async def _tool_set_backend(self, arguments: Dict[str, Any]) -> Dict[str, Any]: else: self.single_env = not base_path.is_dir() - load_kwargs = {"target_path": target_path, "target_base_path": effective_base} + load_kwargs = { + "target_path": target_path, + "target_base_path": effective_base, + } if project_dir: load_kwargs["project_dir"] = project_dir self.context = load_context(**load_kwargs) self._local_cache_key = cache_key - logger.info(f"[MCP] Loaded local context (project_dir={project_dir}, single_env={self.single_env})") + logger.info( + f"[MCP] Loaded local context (project_dir={project_dir}, single_env={self.single_env})" + ) self.backend = None return { @@ -2269,7 +2419,9 @@ async def _tool_run_check(self, arguments: Dict[str, Any]) -> Dict[str, Any]: if run_succeeded: check_dao.update_check_by_id(check_id, PatchCheckIn(is_checked=True)) logger.info(f"Auto-approved check {check_id} (triggered_by={triggered_by})") - await asyncio.get_event_loop().run_in_executor(None, export_persistent_state) + await asyncio.get_event_loop().run_in_executor( + None, export_persistent_state + ) return run_dump @@ -2336,7 +2488,9 @@ async def _tool_create_check(self, arguments: Dict[str, Any]) -> Dict[str, Any]: except Exception as e: run_error = str(e) else: - run, future = submit_run(check_type, params=params, check_id=check_id, triggered_by=triggered_by) + run, future = submit_run( + check_type, params=params, check_id=check_id, triggered_by=triggered_by + ) await future # submit_run's future always resolves (errors caught internally). # Check run.status, not the return value. 
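
A minimal sketch of the contract this hunk relies on (submit_run and the
status-based check are from this diff; run_succeeded is a stand-in predicate,
named after the local variable used in _tool_run_check):

    run, future = submit_run(check_type, params=params, check_id=check_id, triggered_by=triggered_by)
    await future  # always resolves; task errors are captured on the run record
    if not run_succeeded(run):  # inspect run.status, not the awaited value
        ...  # surface run_error to the caller instead of raising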
@@ -2387,12 +2541,17 @@ async def run(self): if msg is not None: console.print(f"[yellow]On shutdown:[/yellow] {msg}") else: - if hasattr(self.state_loader, "state_file") and self.state_loader.state_file: + if ( + hasattr(self.state_loader, "state_file") + and self.state_loader.state_file + ): console.print( f"[yellow]On shutdown:[/yellow] State exported to '{self.state_loader.state_file}'" ) else: - console.print("[yellow]On shutdown:[/yellow] State exported successfully") + console.print( + "[yellow]On shutdown:[/yellow] State exported successfully" + ) except Exception as e: logger.exception(f"Failed to export state on shutdown: {e}") @@ -2417,10 +2576,16 @@ async def run_sse(self, host: str = "localhost", port: int = 8000): async def handle_sse_request(request: Request): """Handle SSE connection (GET /sse) following official MCP example""" - client_info = f"{request.client.host}:{request.client.port}" if request.client else "unknown" + client_info = ( + f"{request.client.host}:{request.client.port}" + if request.client + else "unknown" + ) logger.info(f"[MCP HTTP] SSE connection established from {client_info}") try: - async with sse.connect_sse(request.scope, request.receive, request._send) as streams: + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: await self.server.run( streams[0], streams[1], @@ -2536,7 +2701,9 @@ async def run_mcp_server( if not session: raise ValueError("--session is required when --cloud is provided") if not api_token: - raise ValueError("Recce Cloud API token not found. Run `recce connect-to-cloud` first.") + raise ValueError( + "Recce Cloud API token not found. Run `recce connect-to-cloud` first." + ) backend = await CloudBackend.create(session_id=session, api_token=api_token) server = RecceMCPServer( @@ -2571,6 +2738,32 @@ async def run_mcp_server( ) server._local_cache_key = cache_key + # Freshness check (M2, AC-3): warn on stale base artifacts at startup. + # Lazy import to avoid circular import; best-effort — startup never fails here. + try: + from recce.cli import check_base_freshness + + _tb = kwargs.get("target_base_path", "target-base") + _tp = kwargs.get("target_path", "target") + _freshness = check_base_freshness( + target_base_path=_tb, + target_path=_tp, + ) + server._base_status = _freshness.get("status", "FRESH") + if server._base_status in ("STALE_TIME", "STALE_SHA"): + _age = _freshness.get("artifact_age_hours") or 0 + import sys + + print( + f"[Warning] Base artifacts are stale ({_age:.1f} hours old). " + "Diffs may not reflect the latest base.\n" + "Run: dbt docs generate --target-path target-base" + " (or use recce check-base for full diagnosis)", + file=sys.stderr, + ) + except Exception as _e: + logger.debug(f"[MCP] Base freshness check skipped: {_e}") + # Run in either stdio or SSE mode if sse: # SSE mode: lifespan handler in Starlette manages shutdown and state export diff --git a/tests/test_check_base.py b/tests/test_check_base.py new file mode 100644 index 000000000..53ad53274 --- /dev/null +++ b/tests/test_check_base.py @@ -0,0 +1,124 @@ +""" +Tests for check_base_freshness() helper (M2, AC-3). 
+ +Coverage: + - test_status_fresh — mtime within threshold → FRESH + - test_status_stale_time — mtime > 48 h → STALE_TIME + - test_status_stale_sha — SHA mismatch → STALE_SHA + - test_status_missing — no manifest.json → MISSING + - test_sha_absent_no_raise — R9 best-effort: missing DBT_GIT_SHA field → FRESH, no exception +""" + +import json +import time +from unittest.mock import patch + +import pytest + +from recce.cli import check_base_freshness + + +@pytest.fixture() +def fresh_manifest_dir(tmp_path): + """Create a target-base/ directory with a freshly-written manifest.json.""" + target_base = tmp_path / "target-base" + target_base.mkdir() + manifest = { + "metadata": { + "generated_at": "2024-01-01T00:00:00.000000Z", + "env": { + "DBT_GIT_SHA": "abc1234def5678901234567890123456789012ab", + }, + } + } + (target_base / "manifest.json").write_text(json.dumps(manifest)) + return target_base + + +@pytest.fixture() +def old_manifest_dir(tmp_path): + """Create a target-base/ directory with a manifest.json whose mtime is 73 h ago.""" + target_base = tmp_path / "target-base" + target_base.mkdir() + manifest_path = target_base / "manifest.json" + manifest_path.write_text(json.dumps({"metadata": {"env": {}}})) + # Back-date mtime by 73 hours + old_time = time.time() - (73 * 3600) + import os + + os.utime(manifest_path, (old_time, old_time)) + return target_base + + +def test_status_fresh(fresh_manifest_dir): + """Manifest mtime within threshold and matching SHA → FRESH.""" + manifest_sha = "abc1234def5678901234567890123456789012ab" + with patch("recce.git.current_commit_hash", return_value=manifest_sha): + result = check_base_freshness( + target_base_path=str(fresh_manifest_dir), + freshness_threshold_hours=48.0, + ) + assert result["status"] == "FRESH" + assert result["recommendation"] == "reuse" + assert result["artifact_age_hours"] is not None + assert result["artifact_age_hours"] < 48.0 + + +def test_status_stale_time(old_manifest_dir): + """Manifest mtime > 48 h threshold → STALE_TIME, message contains 'stale'.""" + result = check_base_freshness( + target_base_path=str(old_manifest_dir), + freshness_threshold_hours=48.0, + ) + assert result["status"] == "STALE_TIME" + assert result["recommendation"] == "docs_generate" + assert "stale" in result["message"].lower() + assert result["artifact_age_hours"] > 48.0 + + +def test_status_stale_sha(fresh_manifest_dir): + """SHA in manifest differs from current HEAD → STALE_SHA, message contains 'stale'.""" + different_sha = "9999999deadbeef0000000000000000000000000" + with patch("recce.git.current_commit_hash", return_value=different_sha): + result = check_base_freshness( + target_base_path=str(fresh_manifest_dir), + freshness_threshold_hours=48.0, + ) + assert result["status"] == "STALE_SHA" + assert result["recommendation"] == "docs_generate" + assert "stale" in result["message"].lower() + + +def test_status_missing(tmp_path): + """No manifest.json in target_base_path → MISSING, recommendation full_build.""" + non_existent = tmp_path / "target-base-empty" + result = check_base_freshness( + target_base_path=str(non_existent), + freshness_threshold_hours=48.0, + ) + assert result["status"] == "MISSING" + assert result["recommendation"] == "full_build" + assert result["artifact_age_hours"] is None + + +def test_sha_absent_no_raise(tmp_path): + """R9 best-effort: DBT_GIT_SHA field absent in manifest → FRESH, no exception raised.""" + target_base = tmp_path / "target-base" + target_base.mkdir() + # Manifest has no DBT_GIT_SHA in metadata.env + 
manifest_no_sha = { + "metadata": { + "generated_at": "2024-01-01T00:00:00.000000Z", + "env": {}, # DBT_GIT_SHA is absent + } + } + (target_base / "manifest.json").write_text(json.dumps(manifest_no_sha)) + + # Should not raise, and should fall through to FRESH (time check passes) + result = check_base_freshness( + target_base_path=str(target_base), + freshness_threshold_hours=48.0, + ) + # With no DBT_GIT_SHA, SHA check is skipped → FRESH (time check passed) + assert result["status"] == "FRESH" + assert result["base_sha"] is None From ef1fa71194bb2a4e30c1c350ce13a20d70bfafa7 Mon Sep 17 00:00:00 2001 From: even-wei Date: Thu, 7 May 2026 15:48:45 +0800 Subject: [PATCH 2/3] fix(cli,mcp): address PR #1353 review feedback Critical: - base_status enum: standardize on all-lowercase ("fresh" / "stale_time" / "stale_sha" / "missing" / "single_env" / "unknown"). Document the full enum in get_server_info MCP tool description so LLM agents have a stable contract. - check-base: honor --project-dir (and DBT_PROJECT_DIR envvar) like every other dbt-aware command. Resolves target-base-path relative to project-dir unless absolute. Warnings: - MCP startup now also warns on `missing` (not just stale_*), with a distinct rebuild-path message; `print(..., file=sys.stderr)` swapped for `logger.warning` to match the rest of mcp_server.py. - check-base exit codes split: 0 fresh, 1 missing, 2 stale_*. Documented in the docstring so shell automation can branch without parsing JSON. - Suggested-fix messages now interpolate target_base_path into the --target-path flag (no more hardcoded "target-base" misleading users on non-default paths). - Add 6 CliRunner tests covering JSON schema, text rendering, exit-code mapping per status, and --project-dir resolution. Nitpicks: - Drop unused target_path parameter from check_base_freshness(). - Bare `except Exception:` now logs at debug for diagnosability. Refs: #1353 Co-Authored-By: Claude Sonnet 4.5 Signed-off-by: even-wei --- recce/cli.py | 80 ++++++++++------- recce/mcp_server.py | 48 ++++++---- tests/test_check_base.py | 187 ++++++++++++++++++++++++++++++++++----- 3 files changed, 246 insertions(+), 69 deletions(-) diff --git a/recce/cli.py b/recce/cli.py index 470e2beb6..44d142cb0 100644 --- a/recce/cli.py +++ b/recce/cli.py @@ -3110,15 +3110,14 @@ def clear_cache(cache_db): def check_base_freshness( target_base_path: str = "target-base", - target_path: str = "target", freshness_threshold_hours: float = 48.0, ) -> dict: """ Check whether the base artifacts in target_base_path are fresh. Returns a dict with keys: - status: FRESH | STALE_TIME | STALE_SHA | MISSING - recommendation: reuse | docs_generate | full_build + status: "fresh" | "stale_time" | "stale_sha" | "missing" + recommendation: "reuse" | "docs_generate" | "full_build" message: human-readable explanation artifact_age_hours: float or None base_sha: str or None (DBT_GIT_SHA from manifest metadata) @@ -3126,8 +3125,11 @@ def check_base_freshness( threshold_hours: float """ import json + import logging import time + _logger = logging.getLogger("recce") + manifest_path = Path(target_base_path) / "manifest.json" result: dict = { "status": None, @@ -3140,11 +3142,11 @@ def check_base_freshness( } if not manifest_path.exists(): - result["status"] = "MISSING" + result["status"] = "missing" result["recommendation"] = "full_build" result["message"] = ( f"Base artifacts not found at '{target_base_path}/manifest.json'. 
" - "Run: git stash; git checkout ; dbt build --target-path target-base; " + f"Run: git stash; git checkout ; dbt build --target-path {target_base_path}; " "git checkout ; git stash pop" ) return result @@ -3157,12 +3159,12 @@ def check_base_freshness( # Time-based freshness check if artifact_age_hours > freshness_threshold_hours: - result["status"] = "STALE_TIME" + result["status"] = "stale_time" result["recommendation"] = "docs_generate" result["message"] = ( f"Base artifacts are stale ({artifact_age_hours:.1f} hours old, " f"threshold: {freshness_threshold_hours}h). " - "Run: dbt docs generate --target-path target-base" + f"Run: dbt docs generate --target-path {target_base_path}" ) return result @@ -3179,19 +3181,20 @@ def check_base_freshness( current_sha = current_commit_hash() result["current_sha"] = current_sha if current_sha and base_sha != current_sha: - result["status"] = "STALE_SHA" + result["status"] = "stale_sha" result["recommendation"] = "docs_generate" result["message"] = ( f"Base artifacts are stale (generated at {base_sha[:7]}, " f"current HEAD: {current_sha[:7]}). " - "Run: dbt docs generate --target-path target-base" + f"Run: dbt docs generate --target-path {target_base_path}" ) return result - except Exception: - # Best-effort: if manifest is unreadable or git is unavailable, skip SHA check - pass + except Exception as e: + # Best-effort: if manifest is unreadable or git is unavailable, skip SHA check. + # Log at debug so the reason is recoverable without surfacing in normal output. + _logger.debug(f"check_base_freshness: SHA check skipped ({e})") - result["status"] = "FRESH" + result["status"] = "fresh" result["recommendation"] = "reuse" result["message"] = ( f"Base artifacts are fresh ({artifact_age_hours:.1f} hours old). " @@ -3202,17 +3205,17 @@ def check_base_freshness( @cli.command(name="check-base", cls=TrackCommand) @click.option( - "--target-base-path", - default="target-base", - show_default=True, - help="Path to the base artifacts directory.", + "--project-dir", + help="Which directory to look in for the dbt_project.yml file. " + "target-base-path is resolved relative to this when not absolute.", type=click.Path(), + envvar="DBT_PROJECT_DIR", ) @click.option( - "--target-path", - default="target", + "--target-base-path", + default="target-base", show_default=True, - help="Path to the current target artifacts directory.", + help="Path to the base artifacts directory (relative to --project-dir if not absolute).", type=click.Path(), ) @click.option( @@ -3228,18 +3231,29 @@ def check_base_freshness( default=48.0, show_default=True, type=float, - help="Age threshold in hours after which artifacts are considered STALE_TIME.", + help="Age threshold in hours after which artifacts are considered stale_time.", ) -def check_base(target_base_path, target_path, output_format, freshness_threshold_hours): +def check_base(project_dir, target_base_path, output_format, freshness_threshold_hours): """Check freshness of base artifacts for diff operations. - Exits 0 when status is FRESH; exits 1 when STALE_TIME, STALE_SHA, or MISSING. + Exit codes: + 0 - fresh + 1 - missing (rebuild required: full_build) + 2 - stale_time / stale_sha (regenerate when convenient: docs_generate) """ import json + # Honor --project-dir / DBT_PROJECT_DIR like every other dbt-aware command. + # An absolute target-base-path bypasses the join. 
+ project_dir_path = Path(project_dir) if project_dir else Path("./") + resolved_target_base = ( + Path(target_base_path) + if Path(target_base_path).is_absolute() + else project_dir_path / target_base_path + ) + result = check_base_freshness( - target_base_path=target_base_path, - target_path=target_path, + target_base_path=str(resolved_target_base), freshness_threshold_hours=freshness_threshold_hours, ) @@ -3253,10 +3267,10 @@ def check_base(target_base_path, target_path, output_format, freshness_threshold age = result.get("artifact_age_hours") msg = result.get("message", "") color_map = { - "FRESH": "green", - "STALE_TIME": "yellow", - "STALE_SHA": "yellow", - "MISSING": "red", + "fresh": "green", + "stale_time": "yellow", + "stale_sha": "yellow", + "missing": "red", } color = color_map.get(status, "white") console.print(f"[{color}]Status: {status}[/{color}]") @@ -3266,8 +3280,14 @@ def check_base(target_base_path, target_path, output_format, freshness_threshold if msg: console.print(msg) - if result["status"] != "FRESH": + # Distinct exit codes so shell automation can branch without parsing JSON. + status = result["status"] + if status == "fresh": + return + if status == "missing": raise SystemExit(1) + # stale_time, stale_sha (and any future non-fresh status) + raise SystemExit(2) if __name__ == "__main__": diff --git a/recce/mcp_server.py b/recce/mcp_server.py index 42a73a041..61902fd4e 100644 --- a/recce/mcp_server.py +++ b/recce/mcp_server.py @@ -738,7 +738,15 @@ async def list_tools() -> List[Tool]: description="Get server context information including current backend mode " "('local', 'cloud', or 'none' when unconfigured), adapter type, git branch, " "supported tasks, and review mode status. Useful for diagnostics and " - "understanding which diff tools are available.", + "understanding which diff tools are available.\n\n" + "The 'base_status' field is a single-cased enum (all lowercase) reporting " + "whether the local target-base/ artifacts are usable for diff operations:\n" + " - 'fresh': artifacts are within the freshness threshold and SHA matches; safe to diff.\n" + " - 'stale_time': artifacts older than the freshness threshold; diff results may be outdated.\n" + " - 'stale_sha': artifacts were generated against a different git SHA than current HEAD.\n" + " - 'missing': no manifest.json under target-base/; diffs will fail until rebuilt.\n" + " - 'single_env': server is running in single-env mode; diff is not applicable.\n" + " - 'unknown': freshness check did not run (e.g., cloud mode or check skipped).", inputSchema={ "type": "object", "properties": {}, @@ -2170,7 +2178,8 @@ async def _tool_get_server_info(self, arguments: Dict[str, Any]) -> Dict[str, An } # Include base_status so agents can programmatically detect stale state. - # Values: "FRESH" | "STALE_TIME" | "STALE_SHA" | "MISSING" | "single_env" | "unknown" + # All-lowercase enum (matches the recommendation field's casing): + # "fresh" | "stale_time" | "stale_sha" | "missing" | "single_env" | "unknown" if self.single_env: result["base_status"] = "single_env" else: @@ -2738,28 +2747,31 @@ async def run_mcp_server( ) server._local_cache_key = cache_key - # Freshness check (M2, AC-3): warn on stale base artifacts at startup. + # Freshness check (M2, AC-3): warn on stale or missing base artifacts at startup. # Lazy import to avoid circular import; best-effort — startup never fails here. 
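    # The resulting status lands on server._base_status, which get_server_info
    # reports as "base_status" so agents need not parse the warning text.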
try: from recce.cli import check_base_freshness _tb = kwargs.get("target_base_path", "target-base") - _tp = kwargs.get("target_path", "target") - _freshness = check_base_freshness( - target_base_path=_tb, - target_path=_tp, - ) - server._base_status = _freshness.get("status", "FRESH") - if server._base_status in ("STALE_TIME", "STALE_SHA"): + _freshness = check_base_freshness(target_base_path=_tb) + server._base_status = _freshness.get("status", "fresh") + if server._base_status in ("stale_time", "stale_sha"): _age = _freshness.get("artifact_age_hours") or 0 - import sys - - print( - f"[Warning] Base artifacts are stale ({_age:.1f} hours old). " - "Diffs may not reflect the latest base.\n" - "Run: dbt docs generate --target-path target-base" - " (or use recce check-base for full diagnosis)", - file=sys.stderr, + logger.warning( + f"Base artifacts are stale ({_age:.1f} hours old). " + "Diffs may not reflect the latest base. " + f"Run: dbt docs generate --target-path {_tb} " + "(or use `recce check-base` for full diagnosis)" + ) + elif server._base_status == "missing": + # MISSING is the most actionable failure of the three: diffs + # will fail outright, not silently mislead. Surface it loudly + # alongside the rebuild path. + logger.warning( + f"Base artifacts not found at '{_tb}/manifest.json'. " + "Diff tools will fail until base artifacts are built. " + f"Run: dbt build --target-path {_tb} " + "(or use `recce check-base` for full diagnosis)" ) except Exception as _e: logger.debug(f"[MCP] Base freshness check skipped: {_e}") diff --git a/tests/test_check_base.py b/tests/test_check_base.py index 53ad53274..f98abf355 100644 --- a/tests/test_check_base.py +++ b/tests/test_check_base.py @@ -1,21 +1,32 @@ """ -Tests for check_base_freshness() helper (M2, AC-3). +Tests for check_base_freshness() helper and the `recce check-base` CLI (M2, AC-3). 
Coverage: - - test_status_fresh — mtime within threshold → FRESH - - test_status_stale_time — mtime > 48 h → STALE_TIME - - test_status_stale_sha — SHA mismatch → STALE_SHA - - test_status_missing — no manifest.json → MISSING - - test_sha_absent_no_raise — R9 best-effort: missing DBT_GIT_SHA field → FRESH, no exception + Helper (`check_base_freshness()`): + - test_status_fresh — mtime within threshold → fresh + - test_status_stale_time — mtime > 48 h → stale_time + - test_status_stale_sha — SHA mismatch → stale_sha + - test_status_missing — no manifest.json → missing + - test_sha_absent_no_raise — R9 best-effort: missing DBT_GIT_SHA field → fresh + + CLI (`recce check-base`): + - test_cli_json_schema_fresh — JSON shape includes the documented fields + - test_cli_text_format_renders — --format text prints status line + - test_cli_exit_code_fresh — fresh → exit 0 + - test_cli_exit_code_missing — missing → exit 1 + - test_cli_exit_code_stale_time — stale_time → exit 2 + - test_cli_project_dir_resolves — --project-dir joins onto target-base-path """ import json +import os import time from unittest.mock import patch import pytest +from click.testing import CliRunner -from recce.cli import check_base_freshness +from recce.cli import check_base, check_base_freshness @pytest.fixture() @@ -44,65 +55,71 @@ def old_manifest_dir(tmp_path): manifest_path.write_text(json.dumps({"metadata": {"env": {}}})) # Back-date mtime by 73 hours old_time = time.time() - (73 * 3600) - import os - os.utime(manifest_path, (old_time, old_time)) return target_base +# --------------------------------------------------------------------------- +# Helper tests — exercise check_base_freshness() directly +# --------------------------------------------------------------------------- + + def test_status_fresh(fresh_manifest_dir): - """Manifest mtime within threshold and matching SHA → FRESH.""" + """Manifest mtime within threshold and matching SHA → fresh.""" manifest_sha = "abc1234def5678901234567890123456789012ab" with patch("recce.git.current_commit_hash", return_value=manifest_sha): result = check_base_freshness( target_base_path=str(fresh_manifest_dir), freshness_threshold_hours=48.0, ) - assert result["status"] == "FRESH" + assert result["status"] == "fresh" assert result["recommendation"] == "reuse" assert result["artifact_age_hours"] is not None assert result["artifact_age_hours"] < 48.0 def test_status_stale_time(old_manifest_dir): - """Manifest mtime > 48 h threshold → STALE_TIME, message contains 'stale'.""" + """Manifest mtime > 48 h threshold → stale_time, message contains 'stale'.""" result = check_base_freshness( target_base_path=str(old_manifest_dir), freshness_threshold_hours=48.0, ) - assert result["status"] == "STALE_TIME" + assert result["status"] == "stale_time" assert result["recommendation"] == "docs_generate" assert "stale" in result["message"].lower() assert result["artifact_age_hours"] > 48.0 def test_status_stale_sha(fresh_manifest_dir): - """SHA in manifest differs from current HEAD → STALE_SHA, message contains 'stale'.""" + """SHA in manifest differs from current HEAD → stale_sha, message contains 'stale'.""" different_sha = "9999999deadbeef0000000000000000000000000" with patch("recce.git.current_commit_hash", return_value=different_sha): result = check_base_freshness( target_base_path=str(fresh_manifest_dir), freshness_threshold_hours=48.0, ) - assert result["status"] == "STALE_SHA" + assert result["status"] == "stale_sha" assert result["recommendation"] == "docs_generate" assert "stale" in 
result["message"].lower() def test_status_missing(tmp_path): - """No manifest.json in target_base_path → MISSING, recommendation full_build.""" + """No manifest.json in target_base_path → missing, recommendation full_build.""" non_existent = tmp_path / "target-base-empty" result = check_base_freshness( target_base_path=str(non_existent), freshness_threshold_hours=48.0, ) - assert result["status"] == "MISSING" + assert result["status"] == "missing" assert result["recommendation"] == "full_build" assert result["artifact_age_hours"] is None + # Suggested command must reference the user-supplied target_base_path, + # not the hardcoded default. + assert str(non_existent) in result["message"] def test_sha_absent_no_raise(tmp_path): - """R9 best-effort: DBT_GIT_SHA field absent in manifest → FRESH, no exception raised.""" + """R9 best-effort: DBT_GIT_SHA field absent in manifest → fresh, no exception raised.""" target_base = tmp_path / "target-base" target_base.mkdir() # Manifest has no DBT_GIT_SHA in metadata.env @@ -114,11 +131,139 @@ def test_sha_absent_no_raise(tmp_path): } (target_base / "manifest.json").write_text(json.dumps(manifest_no_sha)) - # Should not raise, and should fall through to FRESH (time check passes) + # Should not raise, and should fall through to fresh (time check passes) result = check_base_freshness( target_base_path=str(target_base), freshness_threshold_hours=48.0, ) - # With no DBT_GIT_SHA, SHA check is skipped → FRESH (time check passed) - assert result["status"] == "FRESH" + # With no DBT_GIT_SHA, SHA check is skipped → fresh (time check passed) + assert result["status"] == "fresh" assert result["base_sha"] is None + + +# --------------------------------------------------------------------------- +# CLI tests — exercise `recce check-base` end-to-end via Click's CliRunner +# --------------------------------------------------------------------------- + + +def test_cli_json_schema_fresh(fresh_manifest_dir): + """--format json (default) emits all documented fields and lowercase status.""" + manifest_sha = "abc1234def5678901234567890123456789012ab" + runner = CliRunner() + with patch("recce.git.current_commit_hash", return_value=manifest_sha): + result = runner.invoke( + check_base, + ["--target-base-path", str(fresh_manifest_dir)], + ) + assert result.exit_code == 0, result.output + payload = json.loads(result.output) + # Documented schema — every key must be present. 
+ expected_keys = { + "status", + "recommendation", + "message", + "artifact_age_hours", + "base_sha", + "current_sha", + "threshold_hours", + } + assert expected_keys.issubset(payload.keys()) + assert payload["status"] == "fresh" + assert payload["recommendation"] == "reuse" + + +def test_cli_text_format_renders(fresh_manifest_dir): + """--format text emits a human-readable status line.""" + manifest_sha = "abc1234def5678901234567890123456789012ab" + runner = CliRunner() + with patch("recce.git.current_commit_hash", return_value=manifest_sha): + result = runner.invoke( + check_base, + [ + "--target-base-path", + str(fresh_manifest_dir), + "--format", + "text", + ], + ) + assert result.exit_code == 0, result.output + assert "Status: fresh" in result.output + assert "Recommendation: reuse" in result.output + + +def test_cli_exit_code_fresh(fresh_manifest_dir): + """fresh → exit 0 (no SystemExit raised).""" + manifest_sha = "abc1234def5678901234567890123456789012ab" + runner = CliRunner() + with patch("recce.git.current_commit_hash", return_value=manifest_sha): + result = runner.invoke( + check_base, + ["--target-base-path", str(fresh_manifest_dir)], + ) + assert result.exit_code == 0 + + +def test_cli_exit_code_missing(tmp_path): + """missing → exit 1 (rebuild required).""" + non_existent = tmp_path / "target-base-empty" + runner = CliRunner() + result = runner.invoke( + check_base, + ["--target-base-path", str(non_existent)], + ) + assert result.exit_code == 1 + payload = json.loads(result.output) + assert payload["status"] == "missing" + assert payload["recommendation"] == "full_build" + + +def test_cli_exit_code_stale_time(old_manifest_dir): + """stale_time → exit 2 (regenerate when convenient).""" + runner = CliRunner() + result = runner.invoke( + check_base, + [ + "--target-base-path", + str(old_manifest_dir), + "--freshness-threshold-hours", + "48", + ], + ) + assert result.exit_code == 2 + payload = json.loads(result.output) + assert payload["status"] == "stale_time" + assert payload["recommendation"] == "docs_generate" + + +def test_cli_project_dir_resolves(tmp_path): + """--project-dir joins onto a relative --target-base-path before resolving.""" + project_dir = tmp_path / "my_dbt_project" + project_dir.mkdir() + target_base = project_dir / "target-base" + target_base.mkdir() + manifest = { + "metadata": { + "env": { + "DBT_GIT_SHA": "abc1234def5678901234567890123456789012ab", + }, + } + } + (target_base / "manifest.json").write_text(json.dumps(manifest)) + + manifest_sha = "abc1234def5678901234567890123456789012ab" + runner = CliRunner() + # Invoke from tmp_path with --project-dir; relative target-base-path "target-base" + # should resolve under project_dir. + with patch("recce.git.current_commit_hash", return_value=manifest_sha): + result = runner.invoke( + check_base, + [ + "--project-dir", + str(project_dir), + "--target-base-path", + "target-base", + ], + ) + assert result.exit_code == 0, result.output + payload = json.loads(result.output) + assert payload["status"] == "fresh" From b50c5b342f548bc332198454372134c62614a5cc Mon Sep 17 00:00:00 2001 From: even-wei Date: Fri, 8 May 2026 16:01:35 +0800 Subject: [PATCH 3/3] fix(mcp): honor --project-dir in MCP startup base freshness check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Round-2 review finding: cli.py was fixed to join --project-dir onto a relative --target-base-path, but the parallel call site in mcp_server.py startup was not updated. 
So `recce mcp-server --project-dir /foo/bar` looks for ./target-base/manifest.json relative to CWD instead of /foo/bar/target-base/manifest.json — spurious "missing" warnings, or worse, silently picking up a stale manifest from another project that happens to live in CWD. Extracted resolve_target_base_path() next to check_base_freshness in cli.py so the join logic lives in exactly one place. CLI's check_base and MCP's run_mcp_server startup both call the helper, and the resolution can no longer drift across the two call sites. - cli.py: new helper resolve_target_base_path(); check_base uses it - mcp_server.py: startup freshness check uses the helper, joining kwargs["project_dir"] with kwargs["target_base_path"] - tests/test_check_base.py: 4 new tests - test_resolve_relative_joins_with_project_dir - test_resolve_absolute_bypasses_project_dir - test_resolve_no_project_dir_uses_cwd - test_resolve_mcp_startup_finds_artifacts_under_project_dir (mirrors test_cli_project_dir_resolves against the helper) Refs round-2 review: https://github.com/DataRecce/recce/pull/1353#issuecomment-4385928949 Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: even-wei --- recce/cli.py | 445 +++++++++++---------------------------- recce/mcp_server.py | 259 ++++++----------------- tests/test_check_base.py | 65 +++++- 3 files changed, 256 insertions(+), 513 deletions(-) diff --git a/recce/cli.py b/recce/cli.py index 9375d1a33..b839079a2 100644 --- a/recce/cli.py +++ b/recce/cli.py @@ -140,9 +140,7 @@ def _add_options(func): help="Which target to load for the given profile.", type=click.STRING, ), - click.option( - "--profile", help="Which existing profile to load.", type=click.STRING - ), + click.option("--profile", help="Which existing profile to load.", type=click.STRING), click.option( "--project-dir", help="Which directory to look in for the dbt_project.yml file.", @@ -271,9 +269,7 @@ def _execute_sql(context, sql_template, base=False): try: import pandas as pd except ImportError: - print( - "'pandas' package not found. You can install it using the command: 'pip install pandas'." - ) + print("'pandas' package not found. You can install it using the command: 'pip install pandas'.") exit(1) from recce.adapter.dbt_adapter import DbtAdapter @@ -283,9 +279,7 @@ def _execute_sql(context, sql_template, base=False): sql = dbt_adapter.generate_sql(sql_template, base) response, result = dbt_adapter.execute(sql, fetch=True, auto_begin=True) table = result - df = pd.DataFrame( - [row.values() for row in table.rows], columns=table.column_names - ) + df = pd.DataFrame([row.values() for row in table.rows], columns=table.column_names) return df @@ -390,14 +384,10 @@ def init(cache_db, **kwargs): cloud_token = kwargs.get("cloud_token") or kwargs.get("api_token") if not cloud_token: - console.print( - "[[red]Error[/red]] --cloud requires --cloud-token or --api-token (or GITHUB_TOKEN env var)." - ) + console.print("[[red]Error[/red]] --cloud requires --cloud-token or --api-token (or GITHUB_TOKEN env var).") exit(1) if not session_id: - console.print( - "[[red]Error[/red]] --cloud requires --session-id (or RECCE_SESSION_ID env var)." 
- ) + console.print("[[red]Error[/red]] --cloud requires --session-id (or RECCE_SESSION_ID env var).") exit(1) cloud_client = RecceCloud(token=cloud_token) @@ -415,33 +405,25 @@ def init(cache_db, **kwargs): console.print(f"[[red]Error[/red]] Failed to get session: {e}") exit(1) if session_info.get("status") == "error": - console.print( - f"[[red]Error[/red]] Failed to get session: {session_info.get('message', 'Access denied')}" - ) + console.print(f"[[red]Error[/red]] Failed to get session: {session_info.get('message', 'Access denied')}") exit(1) cloud_org_id = session_info.get("org_id") cloud_project_id = session_info.get("project_id") if not cloud_org_id or not cloud_project_id: - console.print( - f"[[red]Error[/red]] Session {session_id} missing org_id or project_id." - ) + console.print(f"[[red]Error[/red]] Session {session_id} missing org_id or project_id.") exit(1) # Download artifacts to local target directories console.print("Downloading artifacts from Cloud...") try: - download_urls = cloud_client.get_download_urls_by_session_id( - cloud_org_id, cloud_project_id, session_id - ) + download_urls = cloud_client.get_download_urls_by_session_id(cloud_org_id, cloud_project_id, session_id) except RecceCloudException as e: console.print(f"[[red]Error[/red]] Failed to get download URLs: {e}") exit(1) project_dir_path = Path(kwargs.get("project_dir") or "./") target_path = project_dir_path / kwargs.get("target_path", "target") - target_base_path = project_dir_path / kwargs.get( - "target_base_path", "target-base" - ) + target_base_path = project_dir_path / kwargs.get("target_base_path", "target-base") target_path.mkdir(parents=True, exist_ok=True) target_base_path.mkdir(parents=True, exist_ok=True) @@ -462,9 +444,7 @@ def init(cache_db, **kwargs): f" [[yellow]Warning[/yellow]] Failed to download {filename}: HTTP {resp.status_code}" ) except requests.RequestException as e: - console.print( - f" [[yellow]Warning[/yellow]] Failed to download {filename}: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to download {filename}: {e}") # Download base session artifacts try: @@ -472,9 +452,7 @@ def init(cache_db, **kwargs): cloud_org_id, cloud_project_id, session_id=session_id ) except RecceCloudException as e: - console.print( - f" [[yellow]Warning[/yellow]] Failed to get base session URLs: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to get base session URLs: {e}") base_download_urls = {} for artifact_key, filename in [ ("manifest_url", "manifest.json"), @@ -486,17 +464,13 @@ def init(cache_db, **kwargs): resp = requests.get(url, timeout=_METADATA_TIMEOUT) if resp.status_code == 200: (target_base_path / filename).write_bytes(resp.content) - console.print( - f" Downloaded base {filename} to {target_base_path}" - ) + console.print(f" Downloaded base {filename} to {target_base_path}") else: console.print( f" [[yellow]Warning[/yellow]] Failed to download base {filename}: HTTP {resp.status_code}" ) except requests.RequestException as e: - console.print( - f" [[yellow]Warning[/yellow]] Failed to download base {filename}: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to download base {filename}: {e}") # Download existing CLL cache for warm start. # Try current session first, then fall back to production (base) session. 
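
A condensed sketch of the warm-start order implemented in the hunks that follow
(the helper name and its bytes-written return convention are from this diff; how
cll_cache_url and base_cache_url are pulled from the download payloads is an
assumption here):

    # Try the current session's CLL cache, then the base session's; a return of
    # 0 bytes from _stream_download_to_file means that source had nothing usable.
    cache_downloaded = False
    for url in (cll_cache_url, base_cache_url):
        if url and _stream_download_to_file(url, Path(cache_db)) > 0:
            cache_downloaded = True
            break
    # With no cache downloaded, the init loop recomputes CLL per node from scratch.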
@@ -511,9 +485,7 @@ def _stream_download_to_file(url: str, dest: Path) -> int: if resp.status_code != 200: return 0 total = 0 - with tempfile.NamedTemporaryFile( - dir=dest.parent, delete=False, suffix=".tmp" - ) as tmp: + with tempfile.NamedTemporaryFile(dir=dest.parent, delete=False, suffix=".tmp") as tmp: tmp_path = Path(tmp.name) try: for chunk in resp.iter_content(chunk_size=8192): @@ -535,14 +507,10 @@ def _stream_download_to_file(url: str, dest: Path) -> int: try: nbytes = _stream_download_to_file(cll_cache_url, Path(cache_db)) if nbytes > 0: - console.print( - f" Downloaded CLL cache from session ({nbytes / 1024 / 1024:.1f} MB)" - ) + console.print(f" Downloaded CLL cache from session ({nbytes / 1024 / 1024:.1f} MB)") cache_downloaded = True except requests.RequestException as e: - console.print( - f" [[yellow]Warning[/yellow]] Failed to download CLL cache: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to download CLL cache: {e}") if not cache_downloaded: # Fall back to production (base) session cache @@ -551,19 +519,13 @@ def _stream_download_to_file(url: str, dest: Path) -> int: try: nbytes = _stream_download_to_file(base_cache_url, Path(cache_db)) if nbytes > 0: - console.print( - f" Downloaded CLL cache from base session ({nbytes / 1024 / 1024:.1f} MB)" - ) + console.print(f" Downloaded CLL cache from base session ({nbytes / 1024 / 1024:.1f} MB)") cache_downloaded = True except requests.RequestException as e: - console.print( - f" [[yellow]Warning[/yellow]] Failed to download base CLL cache: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to download base CLL cache: {e}") if not cache_downloaded: - console.print( - " [dim]No existing CLL cache found — will compute from scratch[/dim]" - ) + console.print(" [dim]No existing CLL cache found — will compute from scratch[/dim]") if cache_db is None: cache_db = _DEFAULT_DB_PATH @@ -580,9 +542,7 @@ def _stream_download_to_file(url: str, dest: Path) -> int: if not is_cloud: project_dir_path = Path(kwargs.get("project_dir") or "./") target_path = project_dir_path / kwargs.get("target_path", "target") - target_base_path = project_dir_path / kwargs.get( - "target_base_path", "target-base" - ) + target_base_path = project_dir_path / kwargs.get("target_base_path", "target-base") has_target = (target_path / "manifest.json").is_file() has_base = (target_base_path / "manifest.json").is_file() @@ -599,14 +559,10 @@ def _stream_download_to_file(url: str, dest: Path) -> int: # If only one env exists, use it for both (so load_context doesn't fail) context_kwargs = {**kwargs} if has_target and not has_base: - console.print( - "[dim]Only target/ found — building cache for current environment only.[/dim]" - ) + console.print("[dim]Only target/ found — building cache for current environment only.[/dim]") context_kwargs["target_base_path"] = kwargs.get("target_path", "target") elif has_base and not has_target: - console.print( - "[dim]Only target-base/ found — building cache for base environment only.[/dim]" - ) + console.print("[dim]Only target-base/ found — building cache for base environment only.[/dim]") context_kwargs["target_path"] = kwargs.get("target_base_path", "target-base") try: @@ -637,8 +593,7 @@ def _stream_download_to_file(url: str, dest: Path) -> int: curr_ids = [ nid for nid in dbt_adapter.curr_manifest.nodes - if dbt_adapter.curr_manifest.nodes[nid].resource_type - in ("model", "snapshot") + if dbt_adapter.curr_manifest.nodes[nid].resource_type in ("model", "snapshot") ] envs.append(("current", curr_ids, 
False)) @@ -646,26 +601,18 @@ def _stream_download_to_file(url: str, dest: Path) -> int: base_ids = [ nid for nid in dbt_adapter.base_manifest.nodes - if dbt_adapter.base_manifest.nodes[nid].resource_type - in ("model", "snapshot") + if dbt_adapter.base_manifest.nodes[nid].resource_type in ("model", "snapshot") ] envs.append(("base", base_ids, True)) with Progress(console=console, transient=True) as progress: for env_name, node_ids, is_base in envs: - console.print( - f"\n[bold]{env_name}[/bold] environment: {len(node_ids)} models" - ) + console.print(f"\n[bold]{env_name}[/bold] environment: {len(node_ids)} models") t_start = time.perf_counter() - manifest = ( - dbt_adapter.base_manifest if is_base else dbt_adapter.curr_manifest - ) + manifest = dbt_adapter.base_manifest if is_base else dbt_adapter.curr_manifest catalog = dbt_adapter.base_catalog if is_base else dbt_adapter.curr_catalog - adapter_type = ( - getattr(manifest.metadata, "adapter_type", None) - or dbt_adapter.adapter.type() - ) + adapter_type = getattr(manifest.metadata, "adapter_type", None) or dbt_adapter.adapter.type() success = 0 fail = 0 @@ -685,12 +632,8 @@ def _stream_download_to_file(url: str, dest: Path) -> int: col_names = list(catalog.nodes[nid].columns.keys()) checksum = DbtAdapter._get_node_checksum(manifest, nid) - parent_checksums = [ - DbtAdapter._get_node_checksum(manifest, pid) for pid in p_list - ] - content_key = DbtAdapter._make_node_content_key( - checksum, parent_checksums, col_names, adapter_type - ) + parent_checksums = [DbtAdapter._get_node_checksum(manifest, pid) for pid in p_list] + content_key = DbtAdapter._make_node_content_key(checksum, parent_checksums, col_names, adapter_type) cached_json = cache.get_node(nid, content_key) if cached_json: cache_hits += 1 @@ -704,17 +647,13 @@ def _stream_download_to_file(url: str, dest: Path) -> int: fail += 1 progress.advance(task) continue - batch_to_store.append( - (nid, content_key, DbtAdapter._serialize_cll_data(cll_data)) - ) + batch_to_store.append((nid, content_key, DbtAdapter._serialize_cll_data(cll_data))) success += 1 except Exception as e: fail += 1 if fail <= 3: console.print(f" [dim red] skip: {nid}: {e}[/dim red]") - logger.debug( - "[recce init] CLL computation failed for %s: %s", nid, e - ) + logger.debug("[recce init] CLL computation failed for %s: %s", nid, e) progress.advance(task) if batch_to_store: @@ -740,9 +679,7 @@ def _stream_download_to_file(url: str, dest: Path) -> int: dbt_adapter.get_cll_cached.cache_clear() if fail > 3: - console.print( - f" [dim]... and {fail - 3} more skipped (see logs for details)[/dim]" - ) + console.print(f" [dim]... and {fail - 3} more skipped (see logs for details)[/dim]") # Build and save the full CLL map as JSON. # The per-node SQLite cache is warm from the loop above, so this is fast. @@ -773,9 +710,7 @@ def _stream_download_to_file(url: str, dest: Path) -> int: console.print(f" [[yellow]Warning[/yellow]] Failed to build CLL map: {e}") stats = cache.stats - console.print( - f"\nCache saved to [bold]{cache_db}[/bold] ({stats['entries']} entries)" - ) + console.print(f"\nCache saved to [bold]{cache_db}[/bold] ({stats['entries']} entries)") # In cloud mode, emit per_node.db — a pure-artifact SQLite that Cloud # streams to serve lineage without proxying to an ephemeral Recce instance. 
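The cache-hit test in the loop above keys on `DbtAdapter._make_node_content_key`, which folds the node's own checksum, every parent checksum, the column list, and the adapter type into a single key: change any of those anywhere upstream and the key changes, so a stale CLL entry can never be served. A sketch of that style of content key, assuming the same inputs but not the real helper's encoding:

```python
import hashlib
import json


def make_content_key(checksum: str, parent_checksums: list[str],
                     col_names: list[str], adapter_type: str) -> str:
    """Illustrative content-addressed cache key (not Recce's actual format)."""
    payload = json.dumps(
        {
            "checksum": checksum,
            "parents": sorted(parent_checksums),  # order-independent
            "columns": sorted(col_names),
            "adapter": adapter_type,  # the SQL dialect changes parsed lineage
        },
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```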
@@ -807,9 +742,7 @@ def _stream_download_to_file(url: str, dest: Path) -> int: ) except Exception as e: logger.warning("[recce init] Failed to emit metadata artifacts: %s", e) - console.print( - f" [[yellow]Warning[/yellow]] Failed to emit metadata artifacts: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to emit metadata artifacts: {e}") info_path = None lineage_diff_path = None @@ -818,14 +751,10 @@ def _stream_download_to_file(url: str, dest: Path) -> int: upload_failures: list[str] = [] upload_urls: Optional[dict] = None try: - upload_urls = cloud_client.get_upload_urls_by_session_id( - cloud_org_id, cloud_project_id, session_id - ) + upload_urls = cloud_client.get_upload_urls_by_session_id(cloud_org_id, cloud_project_id, session_id) except Exception as e: logger.warning("[recce init] Cloud upload failed: %s", e) - console.print( - f" [[yellow]Warning[/yellow]] Cloud upload failed: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Cloud upload failed: {e}") if upload_urls is not None: # Emit per_node.db only when Cloud declares support for it. @@ -877,10 +806,7 @@ def _stream_download_to_file(url: str, dest: Path) -> int: def _to_dict(artifact): return ( artifact.to_dict() - if ( - artifact is not None - and hasattr(artifact, "to_dict") - ) + if (artifact is not None and hasattr(artifact, "to_dict")) else artifact ) @@ -892,20 +818,14 @@ def _to_dict(artifact): ) in envs_to_emit: if manifest is None: continue - manifest_dict = ( - manifest.to_dict() - if hasattr(manifest, "to_dict") - else manifest - ) + manifest_dict = manifest.to_dict() if hasattr(manifest, "to_dict") else manifest catalog_dict = _to_dict(catalog) cross_catalog_dict = _to_dict(cross_catalog) - node_rows, column_rows, edge_rows, test_rows = ( - extract_rows_from_artifacts( - manifest_dict, - catalog_dict, - env_name, - cross_env_catalog=cross_catalog_dict, - ) + node_rows, column_rows, edge_rows, test_rows = extract_rows_from_artifacts( + manifest_dict, + catalog_dict, + env_name, + cross_env_catalog=cross_catalog_dict, ) writer.write_nodes(node_rows) writer.write_columns(column_rows) @@ -913,17 +833,10 @@ def _to_dict(artifact): writer.write_tests(test_rows) pn_elapsed = time.perf_counter() - t_pn_start pn_size_mb = per_node_db_path.stat().st_size / 1024 / 1024 - console.print( - f" per_node.db emitted " - f"({pn_size_mb:.1f} MB, {pn_elapsed:.1f}s)" - ) + console.print(f" per_node.db emitted " f"({pn_size_mb:.1f} MB, {pn_elapsed:.1f}s)") except Exception as e: - logger.warning( - "[recce init] Failed to emit per_node.db: %s", e - ) - console.print( - f" [[yellow]Warning[/yellow]] Failed to emit per_node.db: {e}" - ) + logger.warning("[recce init] Failed to emit per_node.db: %s", e) + console.print(f" [[yellow]Warning[/yellow]] Failed to emit per_node.db: {e}") per_node_db_path = None else: console.print( @@ -954,9 +867,7 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append("cll_map.json") - console.print( - f" [[yellow]Warning[/yellow]] Failed to upload cll_map.json: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to upload cll_map.json: {e}") elif not cll_map_upload_url: console.print( " [[yellow]Warning[/yellow]] No cll_map_url in upload URLs " @@ -964,19 +875,13 @@ def _to_dict(artifact): ) # Upload per_node.db (only when Cloud supports it AND we emitted). 
- if ( - per_node_db_upload_url - and per_node_db_path - and per_node_db_path.is_file() - ): + if per_node_db_upload_url and per_node_db_path and per_node_db_path.is_file(): try: with open(per_node_db_path, "rb") as f: resp = requests.put( per_node_db_upload_url, data=f, - headers={ - "Content-Type": "application/octet-stream" - }, + headers={"Content-Type": "application/octet-stream"}, timeout=_UPLOAD_TIMEOUT, ) if resp.status_code in (200, 204): @@ -992,9 +897,7 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append("per_node.db") - console.print( - f" [[yellow]Warning[/yellow]] Failed to upload per_node.db: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to upload per_node.db: {e}") # Upload CLL cache. cll_cache.db is load-bearing across sessions — # build_full_cll_map reuses its warm entries on subsequent runs — @@ -1006,9 +909,7 @@ def _to_dict(artifact): resp = requests.put( cll_cache_upload_url, data=f, - headers={ - "Content-Type": "application/octet-stream" - }, + headers={"Content-Type": "application/octet-stream"}, timeout=_UPLOAD_TIMEOUT, ) if resp.status_code in (200, 204): @@ -1024,13 +925,9 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append("cll_cache.db") - console.print( - f" [[yellow]Warning[/yellow]] Failed to upload cll_cache.db: {e}" - ) + console.print(f" [[yellow]Warning[/yellow]] Failed to upload cll_cache.db: {e}") elif not cll_cache_upload_url: - logger.debug( - "No cll_cache_url in upload URLs — cache upload not supported yet" - ) + logger.debug("No cll_cache_url in upload URLs — cache upload not supported yet") # Upload info.json and lineage_diff.json. Graceful # degradation: if Cloud hasn't added the info_url / @@ -1042,11 +939,7 @@ def _to_dict(artifact): ] for display_name, local_path, url_key in metadata_uploads: metadata_upload_url = upload_urls.get(url_key) - if ( - metadata_upload_url - and local_path is not None - and local_path.is_file() - ): + if metadata_upload_url and local_path is not None and local_path.is_file(): try: with open(local_path, "rb") as f: resp = requests.put( @@ -1057,9 +950,7 @@ def _to_dict(artifact): ) if resp.status_code in (200, 204): size_kb = local_path.stat().st_size / 1024 - console.print( - f" Uploaded {display_name} ({size_kb:.1f} KB)" - ) + console.print(f" Uploaded {display_name} ({size_kb:.1f} KB)") else: upload_failures.append(display_name) console.print( @@ -1068,12 +959,8 @@ def _to_dict(artifact): ) except requests.RequestException as e: upload_failures.append(display_name) - console.print( - f" [[yellow]Warning[/yellow]] Failed to upload {display_name}: {e}" - ) - elif metadata_upload_url and ( - local_path is None or not local_path.is_file() - ): + console.print(f" [[yellow]Warning[/yellow]] Failed to upload {display_name}: {e}") + elif metadata_upload_url and (local_path is None or not local_path.is_file()): # URL present but local artifact missing — emit failed # partway (e.g., info.json written but lineage_diff.json # write raised). Record the failure so the summary @@ -1104,9 +991,7 @@ def _to_dict(artifact): shutil.rmtree(per_node_scratch, ignore_errors=True) shutil.rmtree(metadata_scratch, ignore_errors=True) else: - console.print( - "Run [bold]recce server --enable-cll-cache[/bold] to use the cached lineage." 
- ) + console.print("Run [bold]recce server --enable-cll-cache[/bold] to use the cached lineage.") @cli.command(cls=TrackCommand) @@ -1135,32 +1020,22 @@ def check_artifacts(env_name, target_path): manifest_path = target_path / "manifest.json" manifest_is_ready = manifest_path.is_file() if manifest_is_ready: - console.print( - f"[[green]OK[/green]] Manifest JSON file exists : {manifest_path}" - ) + console.print(f"[[green]OK[/green]] Manifest JSON file exists : {manifest_path}") else: - console.print( - f"[[red]MISS[/red]] Manifest JSON file not found: {manifest_path}" - ) + console.print(f"[[red]MISS[/red]] Manifest JSON file not found: {manifest_path}") catalog_path = target_path / "catalog.json" catalog_is_ready = catalog_path.is_file() if catalog_is_ready: - console.print( - f"[[green]OK[/green]] Catalog JSON file exists: {catalog_path}" - ) + console.print(f"[[green]OK[/green]] Catalog JSON file exists: {catalog_path}") else: - console.print( - f"[[red]MISS[/red]] Catalog JSON file not found: {catalog_path}" - ) + console.print(f"[[red]MISS[/red]] Catalog JSON file not found: {catalog_path}") return [True, manifest_is_ready, catalog_is_ready] project_dir_path = Path(kwargs.get("project_dir") or "./") target_path = project_dir_path.joinpath(Path(kwargs.get("target_path", "target"))) - target_base_path = project_dir_path.joinpath( - Path(kwargs.get("target_base_path", "target-base")) - ) + target_base_path = project_dir_path.joinpath(Path(kwargs.get("target_base_path", "target-base"))) curr_is_ready = check_artifacts("Development", target_path) base_is_ready = check_artifacts("Base", target_base_path) @@ -1182,9 +1057,7 @@ def check_artifacts(env_name, target_path): if all(curr_is_ready) and all(base_is_ready) and conn_is_ready: console.print("[[green]OK[/green]] Ready to launch! Type 'recce server'.") elif all(curr_is_ready) and conn_is_ready: - console.print( - "[[orange3]OK[/orange3]] Ready to launch with [i]limited features[/i]. Type 'recce server'." - ) + console.print("[[orange3]OK[/orange3]] Ready to launch with [i]limited features[/i]. Type 'recce server'.") if not curr_is_ready[0]: console.print( @@ -1215,9 +1088,7 @@ def check_artifacts(env_name, target_path): ) if not conn_is_ready: - console.print( - "[[orange3]TIP[/orange3]] Run 'dbt debug' to check the connection." - ) + console.print("[[orange3]TIP[/orange3]] Run 'dbt debug' to check the connection.") @cli.command(hidden=True, cls=TrackCommand) @@ -1304,12 +1175,8 @@ def diff( @cli.command(cls=TrackCommand) @click.argument("state_file", required=False) -@click.option( - "--host", default="localhost", show_default=True, help="The host to bind to." -) -@click.option( - "--port", default=8000, show_default=True, help="The port to bind to.", type=int -) +@click.option("--host", default="localhost", show_default=True, help="The host to bind to.") +@click.option("--port", default=8000, show_default=True, help="The port to bind to.", type=int) @click.option( "--lifetime", default=0, @@ -1325,9 +1192,7 @@ def diff( type=int, ) @click.option("--review", is_flag=True, help="Open the state file in the review mode.") -@click.option( - "--single-env", is_flag=True, help="Launch in single environment mode directly." 
-) +@click.option("--single-env", is_flag=True, help="Launch in single environment mode directly.") @click.option( "--enable-cll-cache", is_flag=True, @@ -1433,9 +1298,7 @@ def server(host, port, lifetime, idle_timeout=0, state_file=None, **kwargs): # Check Single Environment Onboarding Mode if not in cloud mode and not in review mode if not is_cloud and not is_review: project_dir_path = Path(kwargs.get("project_dir") or "./") - target_base_path = project_dir_path.joinpath( - Path(kwargs.get("target_base_path", "target-base")) - ) + target_base_path = project_dir_path.joinpath(Path(kwargs.get("target_base_path", "target-base"))) if not target_base_path.is_dir(): # Mark as single env onboarding mode if user provides the target-path only flag["single_env_onboarding"] = True @@ -1554,9 +1417,7 @@ def server(host, port, lifetime, idle_timeout=0, state_file=None, **kwargs): ) @click.option("--state-file", help="Path of the import state file.", type=click.Path()) @click.option("--summary", help="Path of the summary markdown file.", type=click.Path()) -@click.option( - "--skip-query", is_flag=True, help="Skip running the queries for the checks." -) +@click.option("--skip-query", is_flag=True, help="Skip running the queries for the checks.") @click.option("--skip-check", is_flag=True, help="Skip running the checks.") @click.option( "--git-current-branch", @@ -1756,9 +1617,7 @@ def connect_to_cloud(): connect_url, callback_port = prepare_connection_url(public_key) console.rule("Connecting to Recce Cloud") - console.print( - "Attempting to automatically open the Recce Cloud authorization page in your default browser." - ) + console.print("Attempting to automatically open the Recce Cloud authorization page in your default browser.") console.print("If the browser does not open, please open the following URL:") console.print(connect_url) webbrowser.open(connect_url) @@ -1830,15 +1689,11 @@ def purge(**kwargs): ) state_loader.load() except Exception: - console.print( - "[[yellow]Skip[/yellow]] Cannot access existing state file from cloud. Purge it directly." - ) + console.print("[[yellow]Skip[/yellow]] Cannot access existing state file from cloud. Purge it directly.") if state_loader is None: try: - if force_to_purge is True or click.confirm( - "\nDo you want to purge the state file?" - ): + if force_to_purge is True or click.confirm("\nDo you want to purge the state file?"): rc, err_msg = RecceCloudStateManager(cloud_options).purge_cloud_state() if rc is True: console.rule("Purged Successfully") @@ -1857,19 +1712,13 @@ def purge(**kwargs): pr_info = info.get("pull_request") console.print("[green]State File hosted by[/green]", info.get("source")) - console.print( - "[green]GitHub Repository[/green]", info.get("pull_request").repository - ) + console.print("[green]GitHub Repository[/green]", info.get("pull_request").repository) console.print(f"[green]GitHub Pull Request[/green]\n{pr_info.title} #{pr_info.id}") - console.print( - f"Branch merged into [blue]{pr_info.base_branch}[/blue] from [blue]{pr_info.branch}[/blue]" - ) + console.print(f"Branch merged into [blue]{pr_info.base_branch}[/blue] from [blue]{pr_info.branch}[/blue]") console.print(pr_info.url) try: - if force_to_purge is True or click.confirm( - "\nDo you want to purge the state file?" 
- ): + if force_to_purge is True or click.confirm("\nDo you want to purge the state file?"): response = state_loader.purge() if response is True: console.rule("Purged Successfully") @@ -1948,9 +1797,7 @@ def upload(state_file, **kwargs): cloud_state_file_exists = state_manager.check_cloud_state_exists() - if cloud_state_file_exists and not click.confirm( - "\nDo you want to overwrite the existing state file?" - ): + if cloud_state_file_exists and not click.confirm("\nDo you want to overwrite the existing state file?"): return 0 console.print(state_manager.upload_state_to_cloud(state_loader.state)) @@ -2112,9 +1959,7 @@ def _download_artifacts(branch, cloud_token, console, kwargs, password, target_p ) except Exception as e: console.rule("Failed to Download", style="red") - console.print( - "[[red]Error[/red]] Failed to download the dbt artifacts from cloud." - ) + console.print("[[red]Error[/red]] Failed to download the dbt artifacts from cloud.") reason = str(e) if ( @@ -2127,9 +1972,7 @@ def _download_artifacts(branch, cloud_token, console, kwargs, password, target_p ) elif "The specified key does not exist" in reason: console.print("Reason: The dbt artifacts is not found in the cloud.") - console.print( - "Please upload the dbt artifacts to the cloud before downloading it." - ) + console.print("Please upload the dbt artifacts to the cloud before downloading it.") else: console.print(f"Reason: {reason}") rc = 1 @@ -2191,9 +2034,7 @@ def download_artifacts(**kwargs): password = kwargs.get("password") target_path = kwargs.get("target_path") branch = kwargs.get("branch") or current_branch() - return _download_artifacts( - branch, cloud_token, console, kwargs, password, target_path - ) + return _download_artifacts(branch, cloud_token, console, kwargs, password, target_path) @cloud.command(cls=TrackCommand) @@ -2253,14 +2094,11 @@ def download_base_artifacts(**kwargs): # If recce can't infer default branch from "GITHUB_BASE_REF" and current_default_branch() if branch is None: console.print( - "[[red]Error[/red]] Please provide your base branch name with '--branch' to download the base " - "artifacts." + "[[red]Error[/red]] Please provide your base branch name with '--branch' to download the base " "artifacts." ) exit(1) - return _download_artifacts( - branch, cloud_token, console, kwargs, password, target_path - ) + return _download_artifacts(branch, cloud_token, console, kwargs, password, target_path) @cloud.command(cls=TrackCommand) @@ -2305,32 +2143,22 @@ def delete_artifacts(**kwargs): force = kwargs.get("force", False) if not force: - if not click.confirm( - f'Do you want to delete artifacts from branch "{branch}"?' - ): + if not click.confirm(f'Do you want to delete artifacts from branch "{branch}"?'): console.print("Deletion cancelled.") return 0 try: - delete_dbt_artifacts( - branch=branch, token=cloud_token, debug=kwargs.get("debug", False) - ) - console.print( - f"[[green]Success[/green]] Artifacts deleted from branch: {branch}" - ) + delete_dbt_artifacts(branch=branch, token=cloud_token, debug=kwargs.get("debug", False)) + console.print(f"[[green]Success[/green]] Artifacts deleted from branch: {branch}") return 0 except click.exceptions.Abort: pass except RecceCloudException as e: - console.print( - "[[red]Error[/red]] Failed to delete the dbt artifacts from cloud." 
- ) + console.print("[[red]Error[/red]] Failed to delete the dbt artifacts from cloud.") console.print(f"Reason: {e.reason}") exit(1) except Exception as e: - console.print( - "[[red]Error[/red]] Failed to delete the dbt artifacts from cloud." - ) + console.print("[[red]Error[/red]] Failed to delete the dbt artifacts from cloud.") console.print(f"Reason: {e}") exit(1) @@ -2381,9 +2209,7 @@ def list_organizations(**kwargs): table.add_column("Display Name", style="yellow") for org in organizations: - table.add_row( - str(org.get("id", "")), org.get("name", ""), org.get("display_name", "") - ) + table.add_row(str(org.get("id", "")), org.get("name", ""), org.get("display_name", "")) console.print(table) @@ -2447,12 +2273,8 @@ def list_projects(**kwargs): organization = kwargs.get("organization") if not organization: - console.print( - "[[red]Error[/red]] Organization ID is required. Please provide it via:" - ) - console.print( - " --organization or set RECCE_ORGANIZATION_ID environment variable" - ) + console.print("[[red]Error[/red]] Organization ID is required. Please provide it via:") + console.print(" --organization or set RECCE_ORGANIZATION_ID environment variable") exit(1) try: @@ -2555,18 +2377,12 @@ def list_sessions(**kwargs): # Validate required parameters if not organization: - console.print( - "[[red]Error[/red]] Organization ID is required. Please provide it via:" - ) - console.print( - " --organization or set RECCE_ORGANIZATION_ID environment variable" - ) + console.print("[[red]Error[/red]] Organization ID is required. Please provide it via:") + console.print(" --organization or set RECCE_ORGANIZATION_ID environment variable") exit(1) if not project: - console.print( - "[[red]Error[/red]] Project ID is required. Please provide it via:" - ) + console.print("[[red]Error[/red]] Project ID is required. Please provide it via:") console.print(" --project or set RECCE_PROJECT_ID environment variable") exit(1) @@ -2691,10 +2507,7 @@ def share(state_file, **kwargs): try: response = state_manager.share_state(state_file_name, state_loader.state) if response.get("status") == "error": - console.print( - "[[red]Error[/red]] Failed to share the state.\n" - f"Reason: {response.get('message')}" - ) + console.print("[[red]Error[/red]] Failed to share the state.\n" f"Reason: {response.get('message')}") else: console.print(f"Shared Link: {response.get('share_url')}") except RecceCloudException as e: @@ -2784,9 +2597,7 @@ def upload_session(**kwargs): ) except Exception as e: console.rule("Failed to Upload Session", style="red") - console.print( - f"[[red]Error[/red]] Failed to upload the dbt artifacts to the session {session_id}." - ) + console.print(f"[[red]Error[/red]] Failed to upload the dbt artifacts to the session {session_id}.") console.print(f"Reason: {e}") rc = 1 return rc @@ -2809,12 +2620,8 @@ def snapshot(**kwargs): @cli.command(hidden=True, cls=TrackCommand) @click.argument("state_file", required=True) -@click.option( - "--host", default="localhost", show_default=True, help="The host to bind to." 
-) -@click.option( - "--port", default=8000, show_default=True, help="The port to bind to.", type=int -) +@click.option("--host", default="localhost", show_default=True, help="The host to bind to.") +@click.option("--port", default=8000, show_default=True, help="The port to bind to.", type=int) @click.option( "--lifetime", default=0, @@ -2850,9 +2657,7 @@ def read_only(ctx, state_file=None, **kwargs): default="localhost", help="Host to bind to in SSE mode (default: localhost)", ) -@click.option( - "--port", default=8000, type=int, help="Port to bind to in SSE mode (default: 8000)" -) +@click.option("--port", default=8000, type=int, help="Port to bind to in SSE mode (default: 8000)") @click.option( "--session", "cloud_session", @@ -3000,12 +2805,8 @@ def mcp_server(state_file, sse, host, port, **kwargs): # Skipped in cloud-session and cloud-snapshot modes — they bring their own state. if not is_cloud_session and not is_cloud_snapshot: project_dir_path = Path(kwargs.get("project_dir") or "./") - target_path = project_dir_path.joinpath( - Path(kwargs.get("target_path", "target")) - ) - target_base_path = project_dir_path.joinpath( - Path(kwargs.get("target_base_path", "target-base")) - ) + target_path = project_dir_path.joinpath(Path(kwargs.get("target_path", "target"))) + target_base_path = project_dir_path.joinpath(Path(kwargs.get("target_base_path", "target-base"))) if target_path.is_dir() and not target_base_path.is_dir(): kwargs["single_env"] = True kwargs["target_base_path"] = kwargs.get("target_path") @@ -3013,9 +2814,7 @@ def mcp_server(state_file, sse, host, port, **kwargs): "[yellow]Base artifacts not found. " "Starting in single-environment mode (diffs will show no changes).[/yellow]" ) - console.print( - "To enable diffing: dbt docs generate --target-path target-base" - ) + console.print("To enable diffing: dbt docs generate --target-path target-base") # Don't let env-derived cloud=True trigger run_mcp_server's CloudBackend branch # — cloud-snapshot runs as a local MCP against the downloaded state file. @@ -3048,11 +2847,7 @@ def mcp_server(state_file, sse, host, port, **kwargs): ) # Run the server (stdio or SSE based on --sse flag) - asyncio.run( - run_mcp_server( - sse=sse, host=host, port=port, session=cloud_session, **kwargs - ) - ) + asyncio.run(run_mcp_server(sse=sse, host=host, port=port, session=cloud_session, **kwargs)) except (asyncio.CancelledError, KeyboardInterrupt): # Graceful shutdown (e.g., Ctrl+C) console.print("[yellow]MCP Server interrupted[/yellow]") @@ -3149,6 +2944,27 @@ def clear_cache(cache_db): pass +def resolve_target_base_path( + project_dir: str | None, + target_base_path: str, +) -> str: + """Resolve ``target_base_path`` against ``project_dir`` like every dbt-aware + command does. + + An absolute ``target_base_path`` bypasses the join. Relative paths are + joined onto ``project_dir`` (or CWD if ``project_dir`` is None). + + Shared by the ``recce check-base`` CLI command and the ``recce mcp-server`` + startup freshness check so they cannot drift — having two copies of the + join logic was the root cause of the round-2 review finding (CLI was + fixed, MCP startup was missed). 
+ """ + base = Path(target_base_path) + if base.is_absolute(): + return str(base) + return str(Path(project_dir or "./") / base) + + def check_base_freshness( target_base_path: str = "target-base", freshness_threshold_hours: float = 48.0, @@ -3237,10 +3053,7 @@ def check_base_freshness( result["status"] = "fresh" result["recommendation"] = "reuse" - result["message"] = ( - f"Base artifacts are fresh ({artifact_age_hours:.1f} hours old). " - "Reuse existing artifacts." - ) + result["message"] = f"Base artifacts are fresh ({artifact_age_hours:.1f} hours old). " "Reuse existing artifacts." return result @@ -3285,16 +3098,10 @@ def check_base(project_dir, target_base_path, output_format, freshness_threshold import json # Honor --project-dir / DBT_PROJECT_DIR like every other dbt-aware command. - # An absolute target-base-path bypasses the join. - project_dir_path = Path(project_dir) if project_dir else Path("./") - resolved_target_base = ( - Path(target_base_path) - if Path(target_base_path).is_absolute() - else project_dir_path / target_base_path - ) + resolved_target_base = resolve_target_base_path(project_dir, target_base_path) result = check_base_freshness( - target_base_path=str(resolved_target_base), + target_base_path=resolved_target_base, freshness_threshold_hours=freshness_threshold_hours, ) diff --git a/recce/mcp_server.py b/recce/mcp_server.py index 61902fd4e..9b59ad1d3 100644 --- a/recce/mcp_server.py +++ b/recce/mcp_server.py @@ -71,9 +71,7 @@ class CloudBackend: "histogram_diff": "histogram_diff", } - def __init__( - self, session_id: str, api_token: str, cloud_host: str = RECCE_CLOUD_API_HOST - ): + def __init__(self, session_id: str, api_token: str, cloud_host: str = RECCE_CLOUD_API_HOST): self.session_id = session_id self.api_token = api_token self.cloud_host = cloud_host.rstrip("/") @@ -84,9 +82,7 @@ async def create(cls, session_id: str, api_token: str): backend = cls(session_id=session_id, api_token=api_token) spawn_response = await backend._request("POST", "instance", json={}) if isinstance(spawn_response, dict): - backend.instance_status = spawn_response.get( - "status" - ) or spawn_response.get("instance_status") + backend.instance_status = spawn_response.get("status") or spawn_response.get("instance_status") return backend def _url(self, api_name: str) -> str: @@ -98,9 +94,7 @@ async def _request(self, method: str, api_name: str, **kwargs): **kwargs.pop("headers", {}), "Authorization": f"Bearer {self.api_token}", } - response = await asyncio.to_thread( - requests.request, method, url, headers=headers, **kwargs - ) + response = await asyncio.to_thread(requests.request, method, url, headers=headers, **kwargs) if response.status_code == 405: raise InstanceSpawningError() if response.status_code < 200 or response.status_code >= 300: @@ -178,9 +172,7 @@ async def _tool_query(self, arguments: Dict[str, Any]) -> Dict[str, Any]: params = {k: v for k, v in arguments.items() if k != "base"} return await self._tool_run_backed(run_type, params) - async def _tool_run_backed( - self, run_type: str, params: Dict[str, Any] - ) -> Dict[str, Any]: + async def _tool_run_backed(self, run_type: str, params: Dict[str, Any]) -> Dict[str, Any]: run = await self._request( "POST", "runs", @@ -281,17 +273,9 @@ async def _tool_lineage_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: lineage = info.get("lineage", {}) nodes = lineage.get("nodes", {}) selected = await self._selected_nodes(arguments, nodes) - impacted = set( - ( - await self._request( - "POST", "select", json={"select": 
"state:modified+"} - ) - ).get("nodes", []) - ) + impacted = set((await self._request("POST", "select", json={"select": "state:modified+"})).get("nodes", [])) - selected_nodes = { - node_id: node for node_id, node in nodes.items() if node_id in selected - } + selected_nodes = {node_id: node for node_id, node in nodes.items() if node_id in selected} id_to_idx = {node_id: idx for idx, node_id in enumerate(selected_nodes.keys())} nodes_df = DataFrame.from_data( columns={ @@ -323,9 +307,7 @@ async def _tool_lineage_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: target = edge.get("target") if source in id_to_idx and target in id_to_idx: edge_rows.append((id_to_idx[source], id_to_idx[target])) - edges_df = DataFrame.from_data( - columns={"from": "integer", "to": "integer"}, data=edge_rows - ) + edges_df = DataFrame.from_data(columns={"from": "integer", "to": "integer"}, data=edge_rows) return { "nodes": nodes_df.model_dump(mode="json"), "edges": edges_df.model_dump(mode="json"), @@ -357,15 +339,9 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An "select", "state:modified.body+ state:modified.macros+ state:modified.contract+", ) - impacted_node_ids = set( - (await self._request("POST", "select", json={"select": select})).get( - "nodes", [] - ) - ) + impacted_node_ids = set((await self._request("POST", "select", json={"select": select})).get("nodes", [])) modified_node_ids = set( - ( - await self._request("POST", "select", json={"select": "state:modified"}) - ).get("nodes", []) + (await self._request("POST", "select", json={"select": "state:modified"})).get("nodes", []) ) impacted_models = [] @@ -375,16 +351,12 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An continue entry = { "name": node.get("name"), - "change_status": node.get("change_status") - if node_id in modified_node_ids - else None, + "change_status": node.get("change_status") if node_id in modified_node_ids else None, "materialized": node.get("materialized"), "row_count": None, "schema_changes": [ {"column": column, "change_status": status} - for column, status in ( - (node.get("change") or {}).get("columns") or {} - ).items() + for column, status in ((node.get("change") or {}).get("columns") or {}).items() ], "value_diff": None, "affected_row_count": None, @@ -422,9 +394,7 @@ async def _selected_nodes(self, arguments: Dict[str, Any], nodes: Dict[str, Any] for key in ("select", "exclude", "packages", "view_mode") if arguments.get(key) is not None } - return set( - (await self._request("POST", "select", json=payload)).get("nodes", []) - ) + return set((await self._request("POST", "select", json=payload)).get("nodes", [])) return set(nodes.keys()) @staticmethod @@ -444,9 +414,7 @@ def _redact_sensitive_args(arguments: Dict[str, Any]) -> Dict[str, Any]: """ if not isinstance(arguments, dict): return arguments - return { - k: ("***" if k in SENSITIVE_ARG_KEYS and v else v) for k, v in arguments.items() - } + return {k: ("***" if k in SENSITIVE_ARG_KEYS and v else v) for k, v in arguments.items()} def _truncate_strings(obj: Any, max_length: int = 200) -> Any: @@ -596,9 +564,7 @@ def _setup_handlers(self): @self.server.list_tools() async def list_tools() -> List[Tool]: """List all available tools based on server mode""" - logger.info( - f"[MCP] list_tools called (mode: {self.mode.value if self.mode else 'server'})" - ) + logger.info(f"[MCP] list_tools called (mode: {self.mode.value if self.mode else 'server'})") tools = [] # Always available in all modes @@ 
-1291,19 +1257,10 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: } # Unconfigured-mode gate: when neither a local context nor a cloud # backend is set, only set_backend and get_server_info are usable. - if ( - self.context is None - and self.backend is None - and name not in {"set_backend", "get_server_info"} - ): - raise ValueError( - "No backend configured. Call set_backend(mode='local'|'cloud', ...) first." - ) + if self.context is None and self.backend is None and name not in {"set_backend", "get_server_info"}: + raise ValueError("No backend configured. Call set_backend(mode='local'|'cloud', ...) first.") - if ( - self.mode != RecceServerMode.server - and name in blocked_tools_in_non_server - ): + if self.mode != RecceServerMode.server and name in blocked_tools_in_non_server: # Allowed tools = all registered minus blocked allowed_tools = sorted( { @@ -1322,11 +1279,7 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: if name == "set_backend": result = await self._tool_set_backend(arguments) - elif ( - name == "get_server_info" - and self.context is None - and self.backend is None - ): + elif name == "get_server_info" and self.context is None and self.backend is None: result = { "mode": "none", "configured": False, @@ -1381,9 +1334,7 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: logger.info(f"[MCP] Tool response for {name} ({duration_ms:.2f}ms):") # Truncate large responses for console readability if len(response_json) > 1000: - logger.debug( - f"[MCP] {response_json[:1000]}... (truncated, {len(response_json)} chars total)" - ) + logger.debug(f"[MCP] {response_json[:1000]}... (truncated, {len(response_json)} chars total)") else: logger.debug(f"[MCP] {response_json}") @@ -1391,13 +1342,9 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: except Exception as e: duration_ms = (time.perf_counter() - start_time) * 1000 error_msg = str(e) - self.mcp_logger.log_tool_call( - name, log_arguments, {}, duration_ms, error=error_msg - ) + self.mcp_logger.log_tool_call(name, log_arguments, {}, duration_ms, error=error_msg) - is_expected_cloud_error = isinstance( - e, (RecceCloudException, InstanceSpawningError) - ) + is_expected_cloud_error = isinstance(e, (RecceCloudException, InstanceSpawningError)) classification = self._classify_db_error(error_msg) if classification: logger.warning( @@ -1410,13 +1357,9 @@ async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: attributes={"tool": name, "error_type": classification}, ) elif is_expected_cloud_error: - logger.warning( - f"[MCP] Expected cloud error in tool {name} ({duration_ms:.2f}ms): {error_msg}" - ) + logger.warning(f"[MCP] Expected cloud error in tool {name} ({duration_ms:.2f}ms): {error_msg}") else: - logger.error( - f"[MCP] Error executing tool {name} ({duration_ms:.2f}ms): {error_msg}" - ) + logger.error(f"[MCP] Error executing tool {name} ({duration_ms:.2f}ms): {error_msg}") logger.exception("[MCP] Full traceback:") # Re-raise so MCP SDK sets isError=True in the protocol response @@ -1666,9 +1609,7 @@ async def _tool_value_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any]: result = result.model_dump(mode="json") return self._maybe_add_single_env_warning(result) - async def _tool_value_diff_detail( - self, arguments: Dict[str, Any] - ) -> Dict[str, Any]: + async def _tool_value_diff_detail(self, arguments: Dict[str, Any]) -> Dict[str, Any]: """Execute value diff detail task""" 
task = ValueDiffDetailTask(params=arguments) result = await asyncio.get_event_loop().run_in_executor(None, task.execute) @@ -1706,9 +1647,7 @@ async def _tool_histogram_diff(self, arguments: Dict[str, Any]) -> Dict[str, Any if not col_info: col_info = columns.get(column_name.lower()) if not col_info or not col_info.get("type"): - raise ValueError( - f"Cannot determine column type for '{column_name}' in model '{model}'" - ) + raise ValueError(f"Cannot determine column type for '{column_name}' in model '{model}'") params = {**arguments, "column_type": col_info["type"]} task = HistogramDiffTask(params=params) @@ -1766,10 +1705,7 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An model_entry = { "name": name, "change_status": ( - change_status - if node_id in modified_node_ids - or change_status in ("added", "removed") - else None + change_status if node_id in modified_node_ids or change_status in ("added", "removed") else None ), "materialized": materialized, "row_count": None, @@ -1781,16 +1717,12 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An not_impacted_models.append(name) # Step 2a: Row count diff (skip removed models; include views for delta detection) - countable_models = [ - m for m in impacted_models if m["change_status"] != "removed" - ] + countable_models = [m for m in impacted_models if m["change_status"] != "removed"] if countable_models: countable_names = [m["name"] for m in countable_models] try: task = RowCountDiffTask(params={"node_names": countable_names}) - row_count_result = await asyncio.get_event_loop().run_in_executor( - None, task.execute - ) + row_count_result = await asyncio.get_event_loop().run_in_executor(None, task.execute) for model in countable_models: name = model["name"] @@ -1805,9 +1737,7 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An "base": base, "current": curr, "delta": delta, - "delta_pct": round(delta_pct, 1) - if delta_pct is not None - else None, + "delta_pct": round(delta_pct, 1) if delta_pct is not None else None, } elif curr is not None: # Added model (no base) @@ -1838,9 +1768,7 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An continue base_cols = set(base_nodes.get(node_id, {}).get("columns", {}).keys()) - curr_cols = set( - current_nodes.get(node_id, {}).get("columns", {}).keys() - ) + curr_cols = set(current_nodes.get(node_id, {}).get("columns", {}).keys()) changes = [] for col in curr_cols - base_cols: @@ -1880,12 +1808,8 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An continue # only PK column, no value diff to compute # Build relations for base and current schemas - base_rel = self.context.adapter.create_relation( - model["name"], base=True - ) - curr_rel = self.context.adapter.create_relation( - model["name"], base=False - ) + base_rel = self.context.adapter.create_relation(model["name"], base=True) + curr_rel = self.context.adapter.create_relation(model["name"], base=False) if not base_rel or not curr_rel: continue @@ -1921,20 +1845,10 @@ async def _tool_impact_analysis(self, arguments: Dict[str, Any]) -> Dict[str, An f'COUNT(CASE WHEN b."{pk}" IS NOT NULL AND c."{pk}" IS NOT NULL ' f'AND b."{col}" IS DISTINCT FROM c."{col}" THEN 1 END) AS "{col}__changed"' ) - col_type = ( - columns_info[col] - .get("type", "") - .upper() - .split("(")[0] - .strip() - ) + col_type = columns_info[col].get("type", "").upper().split("(")[0].strip() if col_type in numeric_types: - 
per_col_parts.append( - f'AVG(b."{col}") AS "{col}__base_mean"' - ) - per_col_parts.append( - f'AVG(c."{col}") AS "{col}__curr_mean"' - ) + per_col_parts.append(f'AVG(b."{col}") AS "{col}__base_mean"') + per_col_parts.append(f'AVG(c."{col}") AS "{col}__curr_mean"') per_col_sql = ",\n ".join(per_col_parts) @@ -1970,24 +1884,14 @@ def _run_value_diff_query(adapter, query): for col in non_pk_cols: col_changed = int(row[col_idx]) col_idx += 1 - col_type = ( - columns_info[col] - .get("type", "") - .upper() - .split("(")[0] - .strip() - ) + col_type = columns_info[col].get("type", "").upper().split("(")[0].strip() base_mean = None current_mean = None if col_type in numeric_types: raw_base = row[col_idx] raw_curr = row[col_idx + 1] - base_mean = ( - float(raw_base) if raw_base is not None else None - ) - current_mean = ( - float(raw_curr) if raw_curr is not None else None - ) + base_mean = float(raw_base) if raw_base is not None else None + current_mean = float(raw_curr) if raw_curr is not None else None col_idx += 2 columns_result[col] = { "affected_row_count": col_changed, @@ -2018,10 +1922,7 @@ def _run_value_diff_query(adapter, query): # affected_row_count: value_diff total (priority) or abs(row_count.delta) (fallback) if model["value_diff"] is not None: model["affected_row_count"] = model["value_diff"]["affected_row_count"] - elif ( - model["row_count"] is not None - and model["row_count"].get("delta") is not None - ): + elif model["row_count"] is not None and model["row_count"].get("delta") is not None: model["affected_row_count"] = abs(model["row_count"]["delta"]) else: model["affected_row_count"] = None @@ -2040,10 +1941,7 @@ def _run_value_diff_query(adapter, query): if model["data_impact"] == "potential": model["affected_row_count"] = None - if ( - model["affected_row_count"] is not None - and model["affected_row_count"] > max_affected - ): + if model["affected_row_count"] is not None and model["affected_row_count"] > max_affected: max_affected = model["affected_row_count"] # next_action: only for "potential" models — confirmed/none need no follow-up @@ -2093,9 +1991,7 @@ def _run_value_diff_query(adapter, query): and model["row_count"]["delta_pct"] is not None and abs(model["row_count"]["delta_pct"]) <= 5 ): - total_matched = (model["row_count"]["current"] or 0) - vd[ - "rows_added" - ] + total_matched = (model["row_count"]["current"] or 0) - vd["rows_added"] if total_matched > 0 and vd["rows_changed"] / total_matched > 0.2: top_cols = [ col @@ -2111,12 +2007,8 @@ def _run_value_diff_query(adapter, query): if sentry_metrics: duration = time.time() - start_time - sentry_metrics.distribution( - "mcp.impact_analysis.duration", duration, unit="second" - ) - sentry_metrics.distribution( - "mcp.impact_analysis.impacted_count", len(impacted_models) - ) + sentry_metrics.distribution("mcp.impact_analysis.duration", duration, unit="second") + sentry_metrics.distribution("mcp.impact_analysis.impacted_count", len(impacted_models)) result = { "_guidance": ( @@ -2225,27 +2117,19 @@ async def _tool_set_backend(self, arguments: Dict[str, Any]) -> Dict[str, Any]: api_token = get_recce_api_token() if not api_token: - raise ValueError( - "Recce Cloud API token not found. Run `recce connect-to-cloud` first." - ) + raise ValueError("Recce Cloud API token not found. Run `recce connect-to-cloud` first.") # Best-effort export of local state before swapping away. 
if self.context is not None and self.state_loader is not None: try: self.state_loader.export(self.context.export_state()) except Exception as e: - logger.warning( - f"[MCP] Failed to export local state on swap to cloud: {e}" - ) + logger.warning(f"[MCP] Failed to export local state on swap to cloud: {e}") - new_backend = await CloudBackend.create( - session_id=session_id, api_token=api_token - ) + new_backend = await CloudBackend.create(session_id=session_id, api_token=api_token) self.backend = new_backend self.api_token = api_token - logger.info( - f"[MCP] Backend switched to cloud (session_id={session_id})" - ) + logger.info(f"[MCP] Backend switched to cloud (session_id={session_id})") return { "mode": "cloud", "session_id": session_id, @@ -2285,9 +2169,7 @@ async def _tool_set_backend(self, arguments: Dict[str, Any]) -> Dict[str, Any]: self.context = load_context(**load_kwargs) self._local_cache_key = cache_key - logger.info( - f"[MCP] Loaded local context (project_dir={project_dir}, single_env={self.single_env})" - ) + logger.info(f"[MCP] Loaded local context (project_dir={project_dir}, single_env={self.single_env})") self.backend = None return { @@ -2428,9 +2310,7 @@ async def _tool_run_check(self, arguments: Dict[str, Any]) -> Dict[str, Any]: if run_succeeded: check_dao.update_check_by_id(check_id, PatchCheckIn(is_checked=True)) logger.info(f"Auto-approved check {check_id} (triggered_by={triggered_by})") - await asyncio.get_event_loop().run_in_executor( - None, export_persistent_state - ) + await asyncio.get_event_loop().run_in_executor(None, export_persistent_state) return run_dump @@ -2497,9 +2377,7 @@ async def _tool_create_check(self, arguments: Dict[str, Any]) -> Dict[str, Any]: except Exception as e: run_error = str(e) else: - run, future = submit_run( - check_type, params=params, check_id=check_id, triggered_by=triggered_by - ) + run, future = submit_run(check_type, params=params, check_id=check_id, triggered_by=triggered_by) await future # submit_run's future always resolves (errors caught internally). # Check run.status, not the return value. 
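The comment in `_tool_create_check` above flags an easy-to-miss contract: `submit_run`'s future always resolves because the worker catches its own exceptions, so callers must inspect `run.status` rather than rely on `await` raising. A self-contained toy that reproduces that contract; the `Run` fields and the "failed" status value are assumptions for illustration only:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class Run:
    status: str = "pending"
    error: str | None = None


def submit_run_sketch(should_fail: bool):
    """Toy stand-in for submit_run: failures land on run.status, never raise."""
    run = Run()

    async def work():
        try:
            if should_fail:
                raise RuntimeError("query failed")
            run.status = "finished"
        except Exception as e:  # caught internally, so the future still resolves
            run.status = "failed"
            run.error = str(e)

    return run, asyncio.ensure_future(work())


async def main():
    run, future = submit_run_sketch(should_fail=True)
    await future  # resolves normally even though the work errored
    assert (run.status, run.error) == ("failed", "query failed")


asyncio.run(main())
```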
@@ -2550,17 +2428,12 @@ async def run(self): if msg is not None: console.print(f"[yellow]On shutdown:[/yellow] {msg}") else: - if ( - hasattr(self.state_loader, "state_file") - and self.state_loader.state_file - ): + if hasattr(self.state_loader, "state_file") and self.state_loader.state_file: console.print( f"[yellow]On shutdown:[/yellow] State exported to '{self.state_loader.state_file}'" ) else: - console.print( - "[yellow]On shutdown:[/yellow] State exported successfully" - ) + console.print("[yellow]On shutdown:[/yellow] State exported successfully") except Exception as e: logger.exception(f"Failed to export state on shutdown: {e}") @@ -2585,16 +2458,10 @@ async def run_sse(self, host: str = "localhost", port: int = 8000): async def handle_sse_request(request: Request): """Handle SSE connection (GET /sse) following official MCP example""" - client_info = ( - f"{request.client.host}:{request.client.port}" - if request.client - else "unknown" - ) + client_info = f"{request.client.host}:{request.client.port}" if request.client else "unknown" logger.info(f"[MCP HTTP] SSE connection established from {client_info}") try: - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async with sse.connect_sse(request.scope, request.receive, request._send) as streams: await self.server.run( streams[0], streams[1], @@ -2710,9 +2577,7 @@ async def run_mcp_server( if not session: raise ValueError("--session is required when --cloud is provided") if not api_token: - raise ValueError( - "Recce Cloud API token not found. Run `recce connect-to-cloud` first." - ) + raise ValueError("Recce Cloud API token not found. Run `recce connect-to-cloud` first.") backend = await CloudBackend.create(session_id=session, api_token=api_token) server = RecceMCPServer( @@ -2750,9 +2615,17 @@ async def run_mcp_server( # Freshness check (M2, AC-3): warn on stale or missing base artifacts at startup. # Lazy import to avoid circular import; best-effort — startup never fails here. try: - from recce.cli import check_base_freshness - - _tb = kwargs.get("target_base_path", "target-base") + from recce.cli import check_base_freshness, resolve_target_base_path + + # Honor --project-dir like the CLI does. Without this, MCP + # startup looks for ./target-base/manifest.json relative to + # CWD, missing artifacts that exist at --project-dir-relative + # paths (or worse, picking up a stale manifest from another + # project that happens to live in CWD). 
+ _tb = resolve_target_base_path( + kwargs.get("project_dir"), + kwargs.get("target_base_path", "target-base"), + ) _freshness = check_base_freshness(target_base_path=_tb) server._base_status = _freshness.get("status", "fresh") if server._base_status in ("stale_time", "stale_sha"): diff --git a/tests/test_check_base.py b/tests/test_check_base.py index f98abf355..51e4260c0 100644 --- a/tests/test_check_base.py +++ b/tests/test_check_base.py @@ -16,17 +16,24 @@ - test_cli_exit_code_missing — missing → exit 1 - test_cli_exit_code_stale_time — stale_time → exit 2 - test_cli_project_dir_resolves — --project-dir joins onto target-base-path + + Helper (`resolve_target_base_path`): + - test_resolve_relative_joins_with_project_dir + - test_resolve_absolute_bypasses_project_dir + - test_resolve_no_project_dir_uses_cwd + - test_resolve_mcp_startup_finds_artifacts_under_project_dir """ import json import os import time +from pathlib import Path from unittest.mock import patch import pytest from click.testing import CliRunner -from recce.cli import check_base, check_base_freshness +from recce.cli import check_base, check_base_freshness, resolve_target_base_path @pytest.fixture() @@ -267,3 +274,59 @@ def test_cli_project_dir_resolves(tmp_path): assert result.exit_code == 0, result.output payload = json.loads(result.output) assert payload["status"] == "fresh" + + +# --------------------------------------------------------------------------- +# resolve_target_base_path() — shared by CLI and MCP startup so the join logic +# cannot drift (round-2 review: MCP startup was missing the join after the CLI +# was fixed). +# --------------------------------------------------------------------------- + + +def test_resolve_relative_joins_with_project_dir(): + """A relative target-base-path is joined under project-dir.""" + resolved = resolve_target_base_path("/foo/bar", "target-base") + assert Path(resolved) == Path("/foo/bar") / "target-base" + + +def test_resolve_absolute_bypasses_project_dir(): + """An absolute target-base-path bypasses the join entirely.""" + resolved = resolve_target_base_path("/foo/bar", "/tmp/abs/target-base") + assert Path(resolved) == Path("/tmp/abs/target-base") + + +def test_resolve_no_project_dir_uses_cwd(): + """When project_dir is None, resolution is relative to CWD ('./').""" + resolved = resolve_target_base_path(None, "target-base") + # Don't compare against a CWD-dependent absolute path; just verify the + # relative path semantics: joining ./ with target-base. + assert Path(resolved) == Path("./") / "target-base" + + +def test_resolve_mcp_startup_finds_artifacts_under_project_dir(tmp_path): + """Regression for the round-2 review finding: MCP startup must use the + same resolution as the CLI so artifacts under --project-dir are found. + + Mirrors test_cli_project_dir_resolves: builds a fresh manifest at + {project_dir}/target-base/manifest.json, then asserts that the resolution + helper produces a path whose freshness check returns 'fresh'. Without the + helper, MCP startup would look at ./target-base relative to CWD and miss + the artifact entirely. 
+ """ + project_dir = tmp_path / "my_dbt_project" + project_dir.mkdir() + target_base = project_dir / "target-base" + target_base.mkdir() + manifest_sha = "abc1234def5678901234567890123456789012ab" + manifest = {"metadata": {"env": {"DBT_GIT_SHA": manifest_sha}}} + (target_base / "manifest.json").write_text(json.dumps(manifest)) + + # The MCP startup-equivalent invocation: pass project_dir + relative + # target_base_path to the shared helper, then run the freshness check. + resolved = resolve_target_base_path(str(project_dir), "target-base") + assert Path(resolved) == target_base + + with patch("recce.git.current_commit_hash", return_value=manifest_sha): + result = check_base_freshness(target_base_path=resolved) + + assert result["status"] == "fresh", result