Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,7 @@ lint.select = [
"RUF" # ruff-specific rules
]

# TODO: Remove these ignores when fixing #129 (code modernization)
lint.ignore = [
"UP006", # Use `list` instead of `List` for type annotations
"UP035", # Import from `collections.abc` instead of `typing`
"UP045", # Use `X | None` instead of `Optional[X]`
"RUF059", # Unused unpacked variable
"RUF100", # Unused noqa directive
]
Expand Down
3 changes: 1 addition & 2 deletions src/postgres_mcp/database_health/database_health.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
from enum import Enum
from typing import List

import mcp.types as types

Expand All @@ -14,7 +13,7 @@
from .sequence_health_calc import SequenceHealthCalc
from .vacuum_health_calc import VacuumHealthCalc

ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource]
ResponseType = list[types.TextContent | types.ImageContent | types.EmbeddedResource]

logger = logging.getLogger(__name__)

Expand Down
7 changes: 3 additions & 4 deletions src/postgres_mcp/database_health/replication_calc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Optional

from ..sql import SqlDriver

Expand All @@ -14,15 +13,15 @@ class ReplicationSlot:
@dataclass
class ReplicationMetrics:
is_replica: bool
replication_lag_seconds: Optional[float]
replication_lag_seconds: float | None
is_replicating: bool
replication_slots: list[ReplicationSlot]


class ReplicationCalc:
def __init__(self, sql_driver: SqlDriver):
self.sql_driver = sql_driver
self._server_version: Optional[int] = None
self._server_version: int | None = None
self._feature_support: dict[str, bool] = {}

async def replication_health_check(self) -> str:
Expand Down Expand Up @@ -85,7 +84,7 @@ async def _is_replica(self) -> bool:
result_list = [dict(x.cells) for x in result] if result is not None else []
return bool(result_list[0]["pg_is_in_recovery"]) if result_list else False

async def _get_replication_lag(self) -> Optional[float]:
async def _get_replication_lag(self) -> float | None:
"""Get replication lag in seconds."""
if not self._feature_supported("replication_lag"):
return None
Expand Down
6 changes: 2 additions & 4 deletions src/postgres_mcp/index/presentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import logging
import os
from typing import Any
from typing import Dict
from typing import List

import humanize

Expand Down Expand Up @@ -169,7 +167,7 @@ async def _execute_analysis(
logger.error(f"Error analyzing queries: {e}", exc_info=True)
return {"error": f"Error analyzing queries: {e}"}

def _build_recommendations_list(self, session: IndexTuningResult) -> List[Dict[str, Any]]:
def _build_recommendations_list(self, session: IndexTuningResult) -> list[dict[str, Any]]:
recommendations = []
for index_apply_order, rec in enumerate(session.recommendations):
rec_dict = {
Expand Down Expand Up @@ -200,7 +198,7 @@ def _build_recommendations_list(self, session: IndexTuningResult) -> List[Dict[s
recommendations.append(rec_dict)
return recommendations

async def _generate_query_impact(self, session: IndexTuningResult) -> List[Dict[str, Any]]:
async def _generate_query_impact(self, session: IndexTuningResult) -> list[dict[str, Any]]:
"""
Generate the query impact section showing before/after explain plans.

Expand Down
6 changes: 2 additions & 4 deletions src/postgres_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import sys
from enum import Enum
from typing import Any
from typing import List
from typing import Literal
from typing import Union

import mcp.types as types
from mcp.server.fastmcp import FastMCP
Expand Down Expand Up @@ -41,7 +39,7 @@
PG_STAT_STATEMENTS = "pg_stat_statements"
HYPOPG_EXTENSION = "hypopg"

ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource]
ResponseType = list[types.TextContent | types.ImageContent | types.EmbeddedResource]

logger = logging.getLogger(__name__)

Expand All @@ -59,7 +57,7 @@ class AccessMode(str, Enum):
shutdown_in_progress = False


async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]:
async def get_sql_driver() -> SqlDriver | SafeSqlDriver:
"""Get the appropriate SQL driver based on the current access mode."""
base_driver = SqlDriver(conn=db_connection)

Expand Down
3 changes: 1 addition & 2 deletions src/postgres_mcp/sql/safe_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import re
from typing import Any
from typing import ClassVar
from typing import Optional

import pglast
from pglast.ast import A_ArrayExpr
Expand Down Expand Up @@ -982,7 +981,7 @@ async def execute_query(
query: LiteralString,
params: list[Any] | None = None,
force_readonly: bool = True, # do not use value passed in
) -> Optional[list[SqlDriver.RowResult]]: # noqa: UP007
) -> list[SqlDriver.RowResult] | None:
"""Execute a query after validating it is safe"""
self._validate(query)

Expand Down
15 changes: 6 additions & 9 deletions src/postgres_mcp/sql/sql_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
import re
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from urllib.parse import urlparse
from urllib.parse import urlunparse

Expand Down Expand Up @@ -62,13 +59,13 @@ def obfuscate_password(text: str | None) -> str | None:
class DbConnPool:
"""Database connection manager using psycopg's connection pool."""

def __init__(self, connection_url: Optional[str] = None):
def __init__(self, connection_url: str | None = None):
self.connection_url = connection_url
self.pool: AsyncConnectionPool | None = None
self._is_valid = False
self._last_error = None

async def pool_connect(self, connection_url: Optional[str] = None) -> AsyncConnectionPool:
async def pool_connect(self, connection_url: str | None = None) -> AsyncConnectionPool:
"""Initialize connection pool with retry logic."""
# If we already have a valid pool, return it
if self.pool and self._is_valid:
Expand Down Expand Up @@ -131,7 +128,7 @@ def is_valid(self) -> bool:
return self._is_valid

@property
def last_error(self) -> Optional[str]:
def last_error(self) -> str | None:
"""Get the last error message."""
return self._last_error

Expand All @@ -143,7 +140,7 @@ class SqlDriver:
class RowResult:
"""Simple class to match the Griptape RowResult interface."""

cells: Dict[str, Any]
cells: dict[str, Any]

def __init__(
self,
Expand Down Expand Up @@ -184,7 +181,7 @@ async def execute_query(
query: LiteralString,
params: list[Any] | None = None,
force_readonly: bool = False,
) -> Optional[List[RowResult]]:
) -> list[RowResult] | None:
"""
Execute a query and return results.

Expand Down Expand Up @@ -221,7 +218,7 @@ async def execute_query(

raise e

async def _execute_with_connection(self, connection, query, params, force_readonly) -> Optional[List[RowResult]]:
async def _execute_with_connection(self, connection, query, params, force_readonly) -> list[RowResult] | None:
"""Execute query with the given connection."""
transaction_started = False
try:
Expand Down
3 changes: 1 addition & 2 deletions src/postgres_mcp/top_queries/top_queries_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from dataclasses import dataclass
from typing import Literal
from typing import LiteralString
from typing import Union
from typing import cast

from ..sql import SafeSqlDriver
Expand Down Expand Up @@ -82,7 +81,7 @@ def _get_pg_stat_statements_columns(pg_version: int) -> PgStatStatementsColumns:
class TopQueriesCalc:
"""Tool for retrieving the slowest SQL queries."""

def __init__(self, sql_driver: Union[SqlDriver, SafeSqlDriver]):
def __init__(self, sql_driver: SqlDriver | SafeSqlDriver):
self.sql_driver = sql_driver

async def get_top_queries_by_time(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str:
Expand Down