diff --git a/README.md b/README.md index d5005ed6..fb3e55cd 100644 --- a/README.md +++ b/README.md @@ -349,6 +349,7 @@ Postgres MCP Pro Tools: | `analyze_workload_indexes` | Analyzes the database workload to identify resource-intensive queries, then recommends optimal indexes for them. | | `analyze_query_indexes` | Analyzes a list of specific SQL queries (up to 10) and recommends optimal indexes for them. | | `analyze_db_health` | Performs comprehensive health checks including: buffer cache hit rates, connection health, constraint validation, index health (duplicate/unused/invalid), sequence limits, and vacuum health. | +| `execute_sql_xlsx` | Executes a SQL query and exports the results to an Excel (.xlsx) file. Supports a configurable row limit to prevent excessive output. | ## Related Projects diff --git a/pyproject.toml b/pyproject.toml index 3160a03a..6fe8cca7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "attrs>=25.4.0", "psycopg-pool>=3.3.0", "instructor>=1.14.4", + "openpyxl>=3.1.2", ] license = "mit" license-files = ["LICENSE"] diff --git a/src/postgres_mcp/formatter.py b/src/postgres_mcp/formatter.py new file mode 100644 index 00000000..ddd335b9 --- /dev/null +++ b/src/postgres_mcp/formatter.py @@ -0,0 +1,67 @@ +"""Excel formatter for query results.""" + +import json +import os +import tempfile +from datetime import datetime +from typing import Any + + +def _serialize_cell(value: Any) -> Any: + """Serialize non-scalar PostgreSQL types (json/jsonb/array) to JSON strings.""" + if isinstance(value, (dict, list)): + return json.dumps(value, ensure_ascii=False, default=str) + return value + + +def format_to_excel(rows: list[dict], columns: list[str], output_dir: str | None = None) -> str: + """Format query result rows to an Excel file. + + Args: + rows: List of row dictionaries from query results. + columns: List of column names. + output_dir: Output directory (default: system temp / postgres-mcp-results). + + Returns: + Path to the created Excel file. + """ + import uuid + + from openpyxl import Workbook + + if output_dir is None: + output_dir = os.path.join(tempfile.gettempdir(), "postgres-mcp-results") + + os.makedirs(output_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + unique_suffix = uuid.uuid4().hex[:8] + filename = f"query_{timestamp}_{unique_suffix}.xlsx" + filepath = os.path.join(output_dir, filename) + + wb = Workbook() + ws = wb.active + ws.title = "Query Results" + + # Write header row + ws.append(columns) + + # Write data rows, serializing complex types before appending + for row in rows: + ws.append([_serialize_cell(row.get(col)) for col in columns]) + + # Auto-adjust column widths + for column in ws.columns: + max_length = 0 + column_letter = column[0].column_letter + for cell in column: + try: + if cell.value is not None: + max_length = max(max_length, len(str(cell.value))) + except Exception: + pass + adjusted_width = min(max_length + 2, 50) + ws.column_dimensions[column_letter].width = adjusted_width + + wb.save(filepath) + return filepath diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index f3ba8f8b..09e28997 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -24,6 +24,7 @@ from .database_health import DatabaseHealthTool from .database_health import HealthType from .explain import ExplainPlanTool +from .formatter import format_to_excel from .index.index_opt_base import MAX_NUM_INDEX_TUNING_QUERIES from .index.llm_opt import LLMOptimizerTool from .index.presentation import TextPresentation @@ -554,6 +555,49 @@ async def get_top_queries( return format_error_response(str(e)) +# Tool function declaration without decorator - registered dynamically based on access mode (like execute_sql) +@validate_call +async def execute_sql_xlsx( + sql: str = Field(description="SQL query to execute and export to Excel"), + max_rows: int = Field( + description="Maximum number of rows to export. Rows beyond this limit are truncated.", + default=10000, + ge=1, + ), +) -> ResponseType: + """Executes a SQL query and exports results to an Excel file.""" + try: + sql_driver = await get_sql_driver() + + # Inject LIMIT to protect server from large result sets. + # Skip if user already provided a LIMIT clause. + import re + + sql_stripped = sql.strip().rstrip(";") + if not re.search(r"\bLIMIT\b", sql_stripped, re.IGNORECASE): + capped_sql = f"{sql_stripped} LIMIT {max_rows}" + else: + capped_sql = sql + + rows = await sql_driver.execute_query(capped_sql) # type: ignore + if rows is None or len(rows) == 0: + return format_text_response("Query returned no results. No Excel file was created.") + + row_dicts = [r.cells for r in rows] + columns = list(row_dicts[0].keys()) + file_path = format_to_excel(rows=row_dicts, columns=columns) + + result_parts = [ + f"Excel file created: {file_path}", + f"Rows exported: {len(row_dicts)}", + f"Columns: {', '.join(columns)}", + ] + return format_text_response("\n".join(result_parts)) + except Exception as e: + logger.error(f"Error executing query for Excel export: {e}") + return format_error_response(str(e)) + + async def main(): # Parse command line arguments parser = argparse.ArgumentParser(description="PostgreSQL MCP Server") @@ -623,6 +667,28 @@ async def main(): ), ) + # Add the xlsx export tool with a description and annotations appropriate to the access mode + if current_access_mode == AccessMode.UNRESTRICTED: + mcp.add_tool( + execute_sql_xlsx, + description="Executes a SQL query and exports results to an Excel (.xlsx) file. " + "Use this when the user wants to save query results as a spreadsheet.", + annotations=ToolAnnotations( + title="Execute SQL to Excel", + destructiveHint=True, + ), + ) + else: + mcp.add_tool( + execute_sql_xlsx, + description="Executes a read-only SQL query and exports results to an Excel (.xlsx) file. " + "Use this when the user wants to save query results as a spreadsheet.", + annotations=ToolAnnotations( + title="Execute SQL to Excel (Read-Only)", + readOnlyHint=True, + ), + ) + logger.info(f"Starting PostgreSQL MCP Server in {current_access_mode.upper()} mode") # Get database URL from environment variable or command line diff --git a/tests/unit/test_excel_export.py b/tests/unit/test_excel_export.py new file mode 100644 index 00000000..87c539f6 --- /dev/null +++ b/tests/unit/test_excel_export.py @@ -0,0 +1,290 @@ +"""Tests for Excel export functionality (format_to_excel and execute_sql_xlsx tool).""" + +import os +import tempfile +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +import pytest_asyncio + +import postgres_mcp.server as server +from postgres_mcp.formatter import format_to_excel +from postgres_mcp.server import AccessMode + + +class MockRowResult: + """Mock row matching SqlDriver.RowResult interface.""" + + def __init__(self, cells: dict): + self.cells = cells + + +# --------------------------------------------------------------------------- +# format_to_excel tests +# --------------------------------------------------------------------------- + + +def test_format_to_excel_creates_file(): + """format_to_excel creates a valid xlsx file with data.""" + rows = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}] + columns = ["id", "name"] + + with tempfile.TemporaryDirectory() as tmpdir: + path = format_to_excel(rows, columns, output_dir=tmpdir) + assert os.path.exists(path) + assert path.endswith(".xlsx") + assert tmpdir in path + + +def test_format_to_excel_custom_output_dir(): + """format_to_excel uses the provided output directory.""" + rows = [{"x": 10}] + columns = ["x"] + + with tempfile.TemporaryDirectory() as tmpdir: + path = format_to_excel(rows, columns, output_dir=tmpdir) + assert os.path.dirname(path) == tmpdir + + +def test_format_to_excel_default_output_dir(): + """format_to_excel defaults to system temp / postgres-mcp-results.""" + rows = [{"a": 1}] + columns = ["a"] + + path = format_to_excel(rows, columns) + expected_dir = os.path.join(tempfile.gettempdir(), "postgres-mcp-results") + assert path.startswith(expected_dir) + assert os.path.exists(path) + os.unlink(path) + + +def test_format_to_excel_column_width_capped(): + """format_to_excel caps column width at 50.""" + rows = [{"short": "hi", "long": "a" * 100}] + columns = ["short", "long"] + + with tempfile.TemporaryDirectory() as tmpdir: + path = format_to_excel(rows, columns, output_dir=tmpdir) + + from openpyxl import load_workbook + + wb = load_workbook(path) + ws = wb.active + assert ws["A1"].value == "short" + assert ws["B1"].value == "long" + assert ws["A2"].value == "hi" + assert ws.column_dimensions["B"].width <= 50 + + +def test_format_to_excel_handles_none_values(): + """format_to_excel handles None values in row data.""" + rows = [{"id": 1, "name": None}, {"id": None, "name": "test"}] + columns = ["id", "name"] + + with tempfile.TemporaryDirectory() as tmpdir: + path = format_to_excel(rows, columns, output_dir=tmpdir) + + from openpyxl import load_workbook + + wb = load_workbook(path) + ws = wb.active + assert ws["B2"].value is None + assert ws["A3"].value is None + + +def test_format_to_excel_empty_rows(): + """format_to_excel handles empty row list (creates header-only file).""" + rows = [] + columns = ["id", "name"] + + with tempfile.TemporaryDirectory() as tmpdir: + path = format_to_excel(rows, columns, output_dir=tmpdir) + + from openpyxl import load_workbook + + wb = load_workbook(path) + ws = wb.active + assert ws["A1"].value == "id" + assert ws.max_row == 1 # header only + + +def test_format_to_excel_serializes_complex_types(): + """format_to_excel serializes dict and list values to JSON strings.""" + rows = [ + {"id": 1, "json_col": {"key": "value"}, "arr_col": [1, 2, 3]}, + ] + columns = ["id", "json_col", "arr_col"] + + with tempfile.TemporaryDirectory() as tmpdir: + path = format_to_excel(rows, columns, output_dir=tmpdir) + + from openpyxl import load_workbook + + wb = load_workbook(path) + ws = wb.active + # dict -> JSON string + assert ws["B2"].value == '{"key": "value"}' + # list -> JSON string + assert ws["C2"].value == "[1, 2, 3]" + + +def test_format_to_excel_filename_is_unique(): + """format_to_excel generates unique filenames to avoid concurrent overwrites.""" + rows = [{"a": 1}] + columns = ["a"] + + with tempfile.TemporaryDirectory() as tmpdir: + path1 = format_to_excel(rows, columns, output_dir=tmpdir) + path2 = format_to_excel(rows, columns, output_dir=tmpdir) + assert path1 != path2 + os.unlink(path1) + os.unlink(path2) + + +# --------------------------------------------------------------------------- +# execute_sql_xlsx tool tests +# --------------------------------------------------------------------------- + + +@pytest_asyncio.fixture +async def mock_db_connection(): + """Create a mock DB connection pool.""" + conn = MagicMock() + conn.pool_connect = AsyncMock() + conn.close = AsyncMock() + return conn + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_success(mock_db_connection): + """execute_sql_xlsx returns file path and row count on success.""" + mock_driver = AsyncMock() + mock_driver.execute_query = AsyncMock(return_value=[ + MockRowResult({"id": 1, "name": "Alice"}), + MockRowResult({"id": 2, "name": "Bob"}), + ]) + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.get_sql_driver", return_value=mock_driver), + ): + result = await server.execute_sql_xlsx("SELECT * FROM users") + + assert len(result) == 1 + text = result[0].text + assert "Excel file created:" in text + assert "Rows exported: 2" in text + assert "id, name" in text + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_empty_results(mock_db_connection): + """execute_sql_xlsx returns informational text for empty results.""" + mock_driver = AsyncMock() + mock_driver.execute_query = AsyncMock(return_value=[]) + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.get_sql_driver", return_value=mock_driver), + ): + result = await server.execute_sql_xlsx("SELECT * FROM empty_table") + + assert len(result) == 1 + assert "no results" in result[0].text.lower() + assert not result[0].text.startswith("Error:") + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_none_results(mock_db_connection): + """execute_sql_xlsx returns informational text when query returns None.""" + mock_driver = AsyncMock() + mock_driver.execute_query = AsyncMock(return_value=None) + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.get_sql_driver", return_value=mock_driver), + ): + result = await server.execute_sql_xlsx("DELETE FROM users") + + assert len(result) == 1 + assert "no results" in result[0].text.lower() + assert not result[0].text.startswith("Error:") + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_injects_limit(mock_db_connection): + """execute_sql_xlsx injects LIMIT max_rows when not present in SQL.""" + mock_driver = AsyncMock() + mock_driver.execute_query = AsyncMock(return_value=[ + MockRowResult({"id": 1}), + MockRowResult({"id": 2}), + ]) + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.get_sql_driver", return_value=mock_driver), + ): + await server.execute_sql_xlsx("SELECT * FROM users", max_rows=100) + + # Verify LIMIT was injected + called_sql = mock_driver.execute_query.call_args[0][0] + assert "LIMIT 100" in called_sql + assert "SELECT * FROM users LIMIT 100" == called_sql + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_preserves_existing_limit(mock_db_connection): + """execute_sql_xlsx does not inject LIMIT when user already provided one.""" + mock_driver = AsyncMock() + mock_driver.execute_query = AsyncMock(return_value=[ + MockRowResult({"id": 1}), + ]) + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.get_sql_driver", return_value=mock_driver), + ): + await server.execute_sql_xlsx("SELECT * FROM users LIMIT 50", max_rows=100) + + called_sql = mock_driver.execute_query.call_args[0][0] + # Should NOT inject LIMIT when user already has one + assert called_sql == "SELECT * FROM users LIMIT 50" + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_query_error(mock_db_connection): + """execute_sql_xlsx returns error response on query failure.""" + mock_driver = AsyncMock() + mock_driver.execute_query = AsyncMock(side_effect=Exception("Connection lost")) + + with ( + patch("postgres_mcp.server.current_access_mode", AccessMode.UNRESTRICTED), + patch("postgres_mcp.server.db_connection", mock_db_connection), + patch("postgres_mcp.server.get_sql_driver", return_value=mock_driver), + ): + result = await server.execute_sql_xlsx("SELECT * FROM users") + + assert len(result) == 1 + assert result[0].text.startswith("Error:") + assert "Connection lost" in result[0].text + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_max_rows_zero_rejected(): + """execute_sql_xlsx rejects max_rows=0 via validation.""" + with pytest.raises(Exception): # noqa: B017 + await server.execute_sql_xlsx("SELECT 1", max_rows=0) + + +@pytest.mark.asyncio +async def test_execute_sql_xlsx_max_rows_negative_rejected(): + """execute_sql_xlsx rejects negative max_rows via validation.""" + with pytest.raises(Exception): # noqa: B017 + await server.execute_sql_xlsx("SELECT 1", max_rows=-5)