diff --git a/pyproject.toml b/pyproject.toml index 3160a03a..6bb526e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,11 +65,7 @@ lint.select = [ "RUF" # ruff-specific rules ] -# TODO: Remove these ignores when fixing #129 (code modernization) lint.ignore = [ - "UP006", # Use `list` instead of `List` for type annotations - "UP035", # Import from `collections.abc` instead of `typing` - "UP045", # Use `X | None` instead of `Optional[X]` "RUF059", # Unused unpacked variable "RUF100", # Unused noqa directive ] diff --git a/src/postgres_mcp/database_health/database_health.py b/src/postgres_mcp/database_health/database_health.py index ecf94924..4b1e6bf0 100644 --- a/src/postgres_mcp/database_health/database_health.py +++ b/src/postgres_mcp/database_health/database_health.py @@ -2,7 +2,6 @@ import logging from enum import Enum -from typing import List import mcp.types as types @@ -14,7 +13,7 @@ from .sequence_health_calc import SequenceHealthCalc from .vacuum_health_calc import VacuumHealthCalc -ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] +ResponseType = list[types.TextContent | types.ImageContent | types.EmbeddedResource] logger = logging.getLogger(__name__) diff --git a/src/postgres_mcp/database_health/replication_calc.py b/src/postgres_mcp/database_health/replication_calc.py index 42b7882c..93c2fa20 100644 --- a/src/postgres_mcp/database_health/replication_calc.py +++ b/src/postgres_mcp/database_health/replication_calc.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional from ..sql import SqlDriver @@ -14,7 +13,7 @@ class ReplicationSlot: @dataclass class ReplicationMetrics: is_replica: bool - replication_lag_seconds: Optional[float] + replication_lag_seconds: float | None is_replicating: bool replication_slots: list[ReplicationSlot] @@ -22,7 +21,7 @@ class ReplicationMetrics: class ReplicationCalc: def __init__(self, sql_driver: SqlDriver): self.sql_driver = sql_driver - self._server_version: Optional[int] = None + self._server_version: int | None = None self._feature_support: dict[str, bool] = {} async def replication_health_check(self) -> str: @@ -85,7 +84,7 @@ async def _is_replica(self) -> bool: result_list = [dict(x.cells) for x in result] if result is not None else [] return bool(result_list[0]["pg_is_in_recovery"]) if result_list else False - async def _get_replication_lag(self) -> Optional[float]: + async def _get_replication_lag(self) -> float | None: """Get replication lag in seconds.""" if not self._feature_supported("replication_lag"): return None diff --git a/src/postgres_mcp/index/presentation.py b/src/postgres_mcp/index/presentation.py index 5e806c2a..32abe76e 100644 --- a/src/postgres_mcp/index/presentation.py +++ b/src/postgres_mcp/index/presentation.py @@ -3,8 +3,6 @@ import logging import os from typing import Any -from typing import Dict -from typing import List import humanize @@ -169,7 +167,7 @@ async def _execute_analysis( logger.error(f"Error analyzing queries: {e}", exc_info=True) return {"error": f"Error analyzing queries: {e}"} - def _build_recommendations_list(self, session: IndexTuningResult) -> List[Dict[str, Any]]: + def _build_recommendations_list(self, session: IndexTuningResult) -> list[dict[str, Any]]: recommendations = [] for index_apply_order, rec in enumerate(session.recommendations): rec_dict = { @@ -200,7 +198,7 @@ def _build_recommendations_list(self, session: IndexTuningResult) -> List[Dict[s recommendations.append(rec_dict) return recommendations - async def _generate_query_impact(self, session: IndexTuningResult) -> List[Dict[str, Any]]: + async def _generate_query_impact(self, session: IndexTuningResult) -> list[dict[str, Any]]: """ Generate the query impact section showing before/after explain plans. diff --git a/src/postgres_mcp/server.py b/src/postgres_mcp/server.py index f3ba8f8b..2f567dec 100644 --- a/src/postgres_mcp/server.py +++ b/src/postgres_mcp/server.py @@ -7,9 +7,7 @@ import sys from enum import Enum from typing import Any -from typing import List from typing import Literal -from typing import Union import mcp.types as types from mcp.server.fastmcp import FastMCP @@ -41,7 +39,7 @@ PG_STAT_STATEMENTS = "pg_stat_statements" HYPOPG_EXTENSION = "hypopg" -ResponseType = List[types.TextContent | types.ImageContent | types.EmbeddedResource] +ResponseType = list[types.TextContent | types.ImageContent | types.EmbeddedResource] logger = logging.getLogger(__name__) @@ -59,7 +57,7 @@ class AccessMode(str, Enum): shutdown_in_progress = False -async def get_sql_driver() -> Union[SqlDriver, SafeSqlDriver]: +async def get_sql_driver() -> SqlDriver | SafeSqlDriver: """Get the appropriate SQL driver based on the current access mode.""" base_driver = SqlDriver(conn=db_connection) diff --git a/src/postgres_mcp/sql/safe_sql.py b/src/postgres_mcp/sql/safe_sql.py index 37382f0b..92acc3ff 100644 --- a/src/postgres_mcp/sql/safe_sql.py +++ b/src/postgres_mcp/sql/safe_sql.py @@ -5,7 +5,6 @@ import re from typing import Any from typing import ClassVar -from typing import Optional import pglast from pglast.ast import A_ArrayExpr @@ -982,7 +981,7 @@ async def execute_query( query: LiteralString, params: list[Any] | None = None, force_readonly: bool = True, # do not use value passed in - ) -> Optional[list[SqlDriver.RowResult]]: # noqa: UP007 + ) -> list[SqlDriver.RowResult] | None: """Execute a query after validating it is safe""" self._validate(query) diff --git a/src/postgres_mcp/sql/sql_driver.py b/src/postgres_mcp/sql/sql_driver.py index 5beacb03..1fa08631 100644 --- a/src/postgres_mcp/sql/sql_driver.py +++ b/src/postgres_mcp/sql/sql_driver.py @@ -4,9 +4,6 @@ import re from dataclasses import dataclass from typing import Any -from typing import Dict -from typing import List -from typing import Optional from urllib.parse import urlparse from urllib.parse import urlunparse @@ -62,13 +59,13 @@ def obfuscate_password(text: str | None) -> str | None: class DbConnPool: """Database connection manager using psycopg's connection pool.""" - def __init__(self, connection_url: Optional[str] = None): + def __init__(self, connection_url: str | None = None): self.connection_url = connection_url self.pool: AsyncConnectionPool | None = None self._is_valid = False self._last_error = None - async def pool_connect(self, connection_url: Optional[str] = None) -> AsyncConnectionPool: + async def pool_connect(self, connection_url: str | None = None) -> AsyncConnectionPool: """Initialize connection pool with retry logic.""" # If we already have a valid pool, return it if self.pool and self._is_valid: @@ -131,7 +128,7 @@ def is_valid(self) -> bool: return self._is_valid @property - def last_error(self) -> Optional[str]: + def last_error(self) -> str | None: """Get the last error message.""" return self._last_error @@ -143,7 +140,7 @@ class SqlDriver: class RowResult: """Simple class to match the Griptape RowResult interface.""" - cells: Dict[str, Any] + cells: dict[str, Any] def __init__( self, @@ -184,7 +181,7 @@ async def execute_query( query: LiteralString, params: list[Any] | None = None, force_readonly: bool = False, - ) -> Optional[List[RowResult]]: + ) -> list[RowResult] | None: """ Execute a query and return results. @@ -221,7 +218,7 @@ async def execute_query( raise e - async def _execute_with_connection(self, connection, query, params, force_readonly) -> Optional[List[RowResult]]: + async def _execute_with_connection(self, connection, query, params, force_readonly) -> list[RowResult] | None: """Execute query with the given connection.""" transaction_started = False try: diff --git a/src/postgres_mcp/top_queries/top_queries_calc.py b/src/postgres_mcp/top_queries/top_queries_calc.py index f3fcba1c..58198cff 100644 --- a/src/postgres_mcp/top_queries/top_queries_calc.py +++ b/src/postgres_mcp/top_queries/top_queries_calc.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Literal from typing import LiteralString -from typing import Union from typing import cast from ..sql import SafeSqlDriver @@ -82,7 +81,7 @@ def _get_pg_stat_statements_columns(pg_version: int) -> PgStatStatementsColumns: class TopQueriesCalc: """Tool for retrieving the slowest SQL queries.""" - def __init__(self, sql_driver: Union[SqlDriver, SafeSqlDriver]): + def __init__(self, sql_driver: SqlDriver | SafeSqlDriver): self.sql_driver = sql_driver async def get_top_queries_by_time(self, limit: int = 10, sort_by: Literal["total", "mean"] = "mean") -> str: