diff --git a/Dockerfile b/Dockerfile index b1dbff6..01c8f07 100644 --- a/Dockerfile +++ b/Dockerfile @@ -51,7 +51,6 @@ RUN SITE_PACKAGES_DIR=$(find /usr/local/lib -name "site-packages" -type d | head # Create startup script RUN echo '#!/bin/sh\n\ - CONFIG_FILE="config.json"\n\ if [ ! -f .env ]; then\n\ echo "Generating new token..."\n\ generate-new-token\n\ @@ -59,9 +58,9 @@ RUN echo '#!/bin/sh\n\ fi\n\ if [ ! -f certs/cert.pem ] || [ ! -f certs/key.pem ]; then\n\ echo "Generating self-signed certificates..."\n\ - generate-certificate --config-file="$CONFIG_FILE"\n\ + generate-certificate\n\ fi\n\ - exec python-template-server --config-file="$CONFIG_FILE"' > /app/start.sh && \ + exec python-template-server' > /app/start.sh && \ chmod +x /app/start.sh && \ chown template_server_user:template_server_user /app/start.sh diff --git a/python_template_server/authentication_handler.py b/python_template_server/authentication_handler.py index 50cd4fb..7dcc828 100644 --- a/python_template_server/authentication_handler.py +++ b/python_template_server/authentication_handler.py @@ -7,13 +7,12 @@ import dotenv -from python_template_server.config import ROOT_DIR -from python_template_server.constants import ENV_FILE_NAME, ENV_VAR_NAME, TOKEN_LENGTH +from python_template_server.constants import ENV_FILE_PATH, ENV_VAR_NAME, TOKEN_LENGTH +from python_template_server.logging_setup import setup_logging +setup_logging() logger = logging.getLogger(__name__) -ENV_FILE = ROOT_DIR / ENV_FILE_NAME - def generate_token() -> str: """Generate a secure random token. @@ -39,10 +38,10 @@ def save_hashed_token(token: str) -> None: """ hashed = hash_token(token) - if not ENV_FILE.exists(): - ENV_FILE.touch() + if not ENV_FILE_PATH.exists(): + ENV_FILE_PATH.touch() - dotenv.set_key(ENV_FILE, ENV_VAR_NAME, hashed) + dotenv.set_key(ENV_FILE_PATH, ENV_VAR_NAME, hashed) def load_hashed_token() -> str: @@ -50,7 +49,7 @@ def load_hashed_token() -> str: :return str: The hashed token string, or an empty string if not found """ - dotenv.load_dotenv(ENV_FILE) + dotenv.load_dotenv(ENV_FILE_PATH) return os.getenv(ENV_VAR_NAME, "") diff --git a/python_template_server/certificate_handler.py b/python_template_server/certificate_handler.py index e81ac3a..eb53821 100644 --- a/python_template_server/certificate_handler.py +++ b/python_template_server/certificate_handler.py @@ -11,9 +11,11 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from python_template_server.config import load_config, parse_args +from python_template_server.logging_setup import setup_logging +from python_template_server.main import ExampleServer from python_template_server.models import CertificateConfigModel +setup_logging() logger = logging.getLogger(__name__) @@ -136,8 +138,7 @@ def generate_self_signed_certificate() -> None: :raise SystemExit: If certificate generation fails """ try: - args = parse_args() - config = load_config(args.config_file) + config = ExampleServer().config handler = CertificateHandler(config.certificate) handler.generate_self_signed_cert() except (OSError, PermissionError): diff --git a/python_template_server/config.py b/python_template_server/config.py deleted file mode 100644 index f1b90af..0000000 --- a/python_template_server/config.py +++ /dev/null @@ -1,112 +0,0 @@ -"""Configuration handling for the server.""" - -import argparse -import json -import logging -import sys -from logging.handlers import RotatingFileHandler - -from pydantic import ValidationError -from pyhere import here - -from python_template_server.constants import ( - CONFIG_FILE_NAME, - LOG_BACKUP_COUNT, - LOG_DATE_FORMAT, - LOG_DIR_NAME, - LOG_FILE_NAME, - LOG_FORMAT, - LOG_LEVEL, - LOG_MAX_BYTES, -) -from python_template_server.models import TemplateServerConfig - -ROOT_DIR = here() -CONFIG_DIR = ROOT_DIR / "configuration" -LOG_DIR = ROOT_DIR / LOG_DIR_NAME -LOG_FILE_PATH = LOG_DIR / LOG_FILE_NAME - - -def setup_logging() -> None: - """Configure logging with both console and rotating file handlers. - - Creates a logs directory if it doesn't exist and sets up: - - Console handler for stdout - - Rotating file handler with size-based rotation - """ - # Create logs directory if it doesn't exist - LOG_DIR.mkdir(exist_ok=True) - - # Get the root logger - root_logger = logging.getLogger() - root_logger.setLevel(getattr(logging, LOG_LEVEL)) - - # Remove any existing handlers - root_logger.handlers.clear() - - # Create formatter - formatter = logging.Formatter(LOG_FORMAT, datefmt=LOG_DATE_FORMAT) - - # Console handler - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(getattr(logging, LOG_LEVEL)) - console_handler.setFormatter(formatter) - root_logger.addHandler(console_handler) - - # Rotating file handler - file_handler = RotatingFileHandler( - LOG_FILE_PATH, maxBytes=LOG_MAX_BYTES, backupCount=LOG_BACKUP_COUNT, encoding="utf-8" - ) - file_handler.setLevel(getattr(logging, LOG_LEVEL)) - file_handler.setFormatter(formatter) - root_logger.addHandler(file_handler) - - -# Setup logging on module import -setup_logging() -logger = logging.getLogger(__name__) - - -def load_config(config_file: str) -> TemplateServerConfig: - """Load configuration from the config.json file. - - :param str config_file: Name of the configuration file - :return TemplateServerConfig: The validated configuration model - :raise SystemExit: If configuration file is missing, invalid JSON, or fails validation - """ - config_path = CONFIG_DIR / config_file - if not config_path.exists(): - logger.error("Configuration file not found: %s", config_path) - sys.exit(1) - - config_data = {} - try: - with config_path.open() as f: - config_data = json.load(f) - except json.JSONDecodeError: - logger.exception("JSON parsing error: %s", config_path) - sys.exit(1) - except OSError: - logger.exception("JSON read error: %s", config_path) - sys.exit(1) - - try: - return TemplateServerConfig.model_validate(config_data) - except ValidationError: - logger.exception("Invalid configuration in: %s", config_path) - sys.exit(1) - - -def parse_args() -> argparse.Namespace: - """Parse command-line arguments. - - :return argparse.Namespace: Parsed arguments - """ - parser = argparse.ArgumentParser(description="Python Template Server") - parser.add_argument( - "--config-file", - type=str, - default=CONFIG_FILE_NAME, - help="Path to the configuration file (default: config.json)", - ) - return parser.parse_known_args()[0] diff --git a/python_template_server/constants.py b/python_template_server/constants.py index 58c7414..efaec85 100644 --- a/python_template_server/constants.py +++ b/python_template_server/constants.py @@ -1,22 +1,34 @@ """Constants used across the server.""" +from pyhere import here + # General constants +ROOT_DIR = here() +CONFIG_DIR_NAME = "configuration" +LOG_DIR_NAME = "logs" +CONFIG_DIR = ROOT_DIR / CONFIG_DIR_NAME +LOG_DIR = ROOT_DIR / LOG_DIR_NAME + +CONFIG_FILE_NAME = "config.json" +LOG_FILE_NAME = "server.log" +ENV_FILE_NAME = ".env" + +CONFIG_FILE_PATH = CONFIG_DIR / CONFIG_FILE_NAME +LOG_FILE_PATH = LOG_DIR / LOG_FILE_NAME +ENV_FILE_PATH = ROOT_DIR / ENV_FILE_NAME + BYTES_TO_MB = 1024 * 1024 # Main constants PACKAGE_NAME = "python-template-server" API_PREFIX = "/api" API_KEY_HEADER_NAME = "X-API-Key" -CONFIG_FILE_NAME = "config.json" # Authentication constants -ENV_FILE_NAME = ".env" ENV_VAR_NAME = "API_TOKEN_HASH" TOKEN_LENGTH = 32 # Logging constants -LOG_DIR_NAME = "logs" -LOG_FILE_NAME = "server.log" LOG_MAX_BYTES = 10 * BYTES_TO_MB # 10 MB LOG_BACKUP_COUNT = 5 LOG_FORMAT = "[%(asctime)s] (%(levelname)s) %(module)s: %(message)s" diff --git a/python_template_server/logging_setup.py b/python_template_server/logging_setup.py new file mode 100644 index 0000000..ec72d5b --- /dev/null +++ b/python_template_server/logging_setup.py @@ -0,0 +1,50 @@ +"""Logging setup for the server.""" + +import logging +import sys +from logging.handlers import RotatingFileHandler + +from python_template_server.constants import ( + LOG_BACKUP_COUNT, + LOG_DATE_FORMAT, + LOG_DIR, + LOG_FILE_PATH, + LOG_FORMAT, + LOG_LEVEL, + LOG_MAX_BYTES, +) + + +def setup_logging() -> None: + """Configure logging with both console and rotating file handlers. + + Creates a logs directory if it doesn't exist and sets up: + - Console handler for stdout + - Rotating file handler with size-based rotation + """ + # Create logs directory if it doesn't exist + LOG_DIR.mkdir(exist_ok=True) + + # Get the root logger + root_logger = logging.getLogger() + root_logger.setLevel(getattr(logging, LOG_LEVEL)) + + # Remove any existing handlers + root_logger.handlers.clear() + + # Create formatter + formatter = logging.Formatter(LOG_FORMAT, datefmt=LOG_DATE_FORMAT) + + # Console handler + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(getattr(logging, LOG_LEVEL)) + console_handler.setFormatter(formatter) + root_logger.addHandler(console_handler) + + # Rotating file handler + file_handler = RotatingFileHandler( + LOG_FILE_PATH, maxBytes=LOG_MAX_BYTES, backupCount=LOG_BACKUP_COUNT, encoding="utf-8" + ) + file_handler.setLevel(getattr(logging, LOG_LEVEL)) + file_handler.setFormatter(formatter) + root_logger.addHandler(file_handler) diff --git a/python_template_server/main.py b/python_template_server/main.py index 17927cf..bb71451 100644 --- a/python_template_server/main.py +++ b/python_template_server/main.py @@ -1,6 +1,7 @@ """FastAPI template server using uvicorn.""" -from python_template_server.config import load_config, parse_args +from typing import Any + from python_template_server.models import TemplateServerConfig from python_template_server.template_server import TemplateServer @@ -8,12 +9,19 @@ class ExampleServer(TemplateServer): """Example server inheriting from TemplateServer.""" - def __init__(self, config: TemplateServerConfig) -> None: + def __init__(self) -> None: """Initialize the ExampleServer by delegating to the template server. :param TemplateServerConfig config: Example server configuration """ - super().__init__(config) + super().__init__() + + def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig: + """Validate configuration from the config.json file. + + :return TemplateServerConfig: Loaded configuration + """ + return super().validate_config(config_data) def setup_routes(self) -> None: """Set up API routes.""" @@ -25,7 +33,5 @@ def run() -> None: :raise SystemExit: If configuration fails to load or SSL certificate files are missing """ - args = parse_args() - config = load_config(args.config_file) - server = ExampleServer(config) + server = ExampleServer() server.run() diff --git a/python_template_server/models.py b/python_template_server/models.py index f3a3ddb..e0e056c 100644 --- a/python_template_server/models.py +++ b/python_template_server/models.py @@ -6,8 +6,6 @@ from pydantic import BaseModel, Field -from python_template_server.constants import API_PREFIX - # Template Server Configuration Models class ServerConfigModel(BaseModel): @@ -26,11 +24,6 @@ def url(self) -> str: """Get the server URL.""" return f"https://{self.address}" - @property - def full_url(self) -> str: - """Get the full server URL including API prefix.""" - return f"{self.url}{API_PREFIX}" - class SecurityConfigModel(BaseModel): """Security headers configuration model.""" diff --git a/python_template_server/template_server.py b/python_template_server/template_server.py index ccdf129..100ec97 100644 --- a/python_template_server/template_server.py +++ b/python_template_server/template_server.py @@ -1,5 +1,6 @@ """Template FastAPI server module.""" +import json import logging import sys from abc import ABC, abstractmethod @@ -15,15 +16,20 @@ from prometheus_client import Counter, Gauge from prometheus_fastapi_instrumentator import Instrumentator from pydantic import BaseModel +from pydantic_core import ValidationError from slowapi import Limiter from slowapi.errors import RateLimitExceeded from slowapi.util import get_remote_address from python_template_server.authentication_handler import load_hashed_token, verify_token -from python_template_server.constants import API_KEY_HEADER_NAME, API_PREFIX, PACKAGE_NAME +from python_template_server.constants import API_KEY_HEADER_NAME, API_PREFIX, CONFIG_DIR, CONFIG_FILE_NAME, PACKAGE_NAME +from python_template_server.logging_setup import setup_logging from python_template_server.middleware import RequestLoggingMiddleware, SecurityHeadersMiddleware from python_template_server.models import GetHealthResponse, ResponseCode, ServerHealthStatus, TemplateServerConfig +setup_logging() +logger = logging.getLogger(__name__) + class TemplateServer(ABC): """Template FastAPI server. @@ -31,29 +37,31 @@ class TemplateServer(ABC): This class provides a template for building FastAPI servers with common features such as request logging, security headers, rate limiting, and Prometheus metrics. - Ensure you implement the `setup_routes` method in subclasses to define API endpoints. + Ensure you implement the `setup_routes` and `validate_config` methods in subclasses. """ - def __init__(self, config: TemplateServerConfig) -> None: + def __init__( + self, package_name: str = PACKAGE_NAME, api_prefix: str = API_PREFIX, config: TemplateServerConfig | None = None + ) -> None: """Initialize the TemplateServer. - :param TemplateServerConfig config: Template server configuration + :param str package_name: The package name for metadata retrieval + :param str api_prefix: The API prefix for the server + :param TemplateServerConfig | None config: Optional pre-loaded configuration """ - self.config = config - self.logger = logging.getLogger(__name__) - self.hashed_token = load_hashed_token() - - package_metadata = metadata(PACKAGE_NAME) - + self.api_prefix = api_prefix + package_metadata = metadata(package_name) self.app = FastAPI( title=package_metadata["Name"], description=package_metadata["Summary"], version=package_metadata["Version"], - root_path=API_PREFIX, + root_path=self.api_prefix, lifespan=self.lifespan, ) self.api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False) + self.config = config or self.load_config() + self.hashed_token = load_hashed_token() self._setup_request_logging() self._setup_security_headers() self._setup_rate_limiting() @@ -66,6 +74,45 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Handle application lifespan events.""" yield + @abstractmethod + def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig: + """Validate configuration data against the TemplateServerConfig model. + + :param dict config_data: The configuration data to validate + :return TemplateServerConfig: The validated configuration model + :raise ValidationError: If the configuration data is invalid + """ + return TemplateServerConfig.model_validate(config_data) + + def load_config(self, config_file: str = CONFIG_FILE_NAME) -> TemplateServerConfig: + """Load configuration from the specified json file. + + :param str config_file: Name of the configuration file + :return TemplateServerConfig: The validated configuration model + :raise SystemExit: If configuration file is missing, invalid JSON, or fails validation + """ + config_path = CONFIG_DIR / config_file + if not config_path.exists(): + logger.error("Configuration file not found: %s", config_path) + sys.exit(1) + + config_data = {} + try: + with config_path.open() as f: + config_data = json.load(f) + except json.JSONDecodeError: + logger.exception("JSON parsing error: %s", config_path) + sys.exit(1) + except OSError: + logger.exception("JSON read error: %s", config_path) + sys.exit(1) + + try: + return self.validate_config(config_data) + except ValidationError: + logger.exception("Invalid configuration in: %s", config_path) + sys.exit(1) + async def _verify_api_key( self, api_key: str | None = Security(APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)) ) -> None: @@ -75,7 +122,7 @@ async def _verify_api_key( :raise HTTPException: If the API key is missing or invalid """ if api_key is None: - self.logger.warning("Missing API key in request!") + logger.warning("Missing API key in request!") self.auth_failure_counter.labels(reason="missing").inc() raise HTTPException( status_code=ResponseCode.UNAUTHORIZED, @@ -84,16 +131,16 @@ async def _verify_api_key( try: if not verify_token(api_key, self.hashed_token): - self.logger.warning("Invalid API key attempt!") + logger.warning("Invalid API key attempt!") self.auth_failure_counter.labels(reason="invalid").inc() raise HTTPException( status_code=ResponseCode.UNAUTHORIZED, detail="Invalid API key", ) - self.logger.debug("API key validated successfully.") + logger.debug("API key validated successfully.") self.auth_success_counter.inc() except ValueError as e: - self.logger.exception("Error verifying API key!") + logger.exception("Error verifying API key!") self.auth_failure_counter.labels(reason="error").inc() raise HTTPException( status_code=ResponseCode.UNAUTHORIZED, @@ -103,7 +150,7 @@ async def _verify_api_key( def _setup_request_logging(self) -> None: """Set up request logging middleware.""" self.app.add_middleware(RequestLoggingMiddleware) - self.logger.info("Request logging enabled") + logger.info("Request logging enabled") def _setup_security_headers(self) -> None: """Set up security headers middleware.""" @@ -113,7 +160,7 @@ def _setup_security_headers(self) -> None: csp=self.config.security.content_security_policy, ) - self.logger.info( + logger.info( "Security headers enabled: HSTS max-age=%s, CSP=%s", self.config.security.hsts_max_age, self.config.security.content_security_policy, @@ -138,7 +185,7 @@ async def _rate_limit_exception_handler(self, request: Request, exc: RateLimitEx def _setup_rate_limiting(self) -> None: """Set up rate limiting middleware.""" if not self.config.rate_limit.enabled: - self.logger.info("Rate limiting is disabled") + logger.info("Rate limiting is disabled") self.limiter = None return @@ -150,7 +197,7 @@ def _setup_rate_limiting(self) -> None: self.app.state.limiter = self.limiter self.app.add_exception_handler(RateLimitExceeded, self._rate_limit_exception_handler) # type: ignore[arg-type] - self.logger.info( + logger.info( "Rate limiting enabled: rate=%s, storage=%s", self.config.rate_limit.rate_limit, self.config.rate_limit.storage_uri or "in-memory", @@ -193,7 +240,7 @@ def _setup_metrics(self) -> None: ["endpoint"], ) - self.logger.info("Prometheus metrics enabled.") + logger.info("Prometheus metrics enabled.") def run(self) -> None: """Run the server using uvicorn. @@ -205,10 +252,10 @@ def run(self) -> None: key_file = self.config.certificate.ssl_key_file_path if not (cert_file.exists() and key_file.exists()): - self.logger.error("SSL certificate files are missing. Expected: '%s' and '%s'", cert_file, key_file) + logger.error("SSL certificate files are missing. Expected: '%s' and '%s'", cert_file, key_file) sys.exit(1) - self.logger.info("Starting server: %s", self.config.server.full_url) + logger.info("Starting server: %s%s", self.config.server.url, self.api_prefix) uvicorn.run( self.app, host=self.config.server.host, @@ -216,9 +263,9 @@ def run(self) -> None: ssl_keyfile=str(key_file), ssl_certfile=str(cert_file), ) - self.logger.info("Server stopped.") + logger.info("Server stopped.") except OSError: - self.logger.exception("Failed to start - ran into an OSError!") + logger.exception("Failed to start - ran into an OSError!") sys.exit(1) def add_unauthenticated_route( diff --git a/tests/test_authentication_handler.py b/tests/test_authentication_handler.py index 09d2843..5d41fa3 100644 --- a/tests/test_authentication_handler.py +++ b/tests/test_authentication_handler.py @@ -13,8 +13,7 @@ save_hashed_token, verify_token, ) -from python_template_server.config import ROOT_DIR -from python_template_server.constants import ENV_FILE_NAME, ENV_VAR_NAME, TOKEN_LENGTH +from python_template_server.constants import ENV_FILE_PATH, ENV_VAR_NAME, TOKEN_LENGTH @pytest.fixture @@ -60,7 +59,7 @@ def test_save_hashed_token( """Test the save_hashed_token function.""" mock_exists.return_value = True save_hashed_token("testtoken") - mock_set_key.assert_called_once_with(ROOT_DIR / ENV_FILE_NAME, ENV_VAR_NAME, mock_hash_token.return_value) + mock_set_key.assert_called_once_with(ENV_FILE_PATH, ENV_VAR_NAME, mock_hash_token.return_value) def test_save_hashed_token_file_creation( self, mock_hash_token: MagicMock, mock_exists: MagicMock, mock_touch: MagicMock, mock_set_key: MagicMock @@ -69,7 +68,7 @@ def test_save_hashed_token_file_creation( mock_exists.return_value = False save_hashed_token("testtoken") mock_touch.assert_called_once() - mock_set_key.assert_called_once_with(ROOT_DIR / ENV_FILE_NAME, ENV_VAR_NAME, mock_hash_token.return_value) + mock_set_key.assert_called_once_with(ENV_FILE_PATH, ENV_VAR_NAME, mock_hash_token.return_value) @pytest.mark.parametrize( ("token", "expected"), diff --git a/tests/test_certificate_handler.py b/tests/test_certificate_handler.py index a143395..e53b476 100644 --- a/tests/test_certificate_handler.py +++ b/tests/test_certificate_handler.py @@ -18,10 +18,15 @@ @pytest.fixture -def mock_load_config() -> Generator[MagicMock, None, None]: - """Mock the load_config function.""" - with patch("python_template_server.certificate_handler.load_config") as mock_config: - yield mock_config +def mock_example_server( + tmp_path: Path, mock_template_server_config: TemplateServerConfig +) -> Generator[MagicMock, None, None]: + """Mock the ExampleServer class.""" + with patch("python_template_server.certificate_handler.ExampleServer") as mock_server: + cert_dir = tmp_path / "certs" + mock_template_server_config.certificate.directory = str(cert_dir) + mock_server.return_value.config = mock_template_server_config + yield mock_server class TestCertificateHandler: @@ -163,23 +168,15 @@ class TestGenerateSelfSignedCertificate: """Unit tests for the generate_self_signed_certificate function.""" def test_generate_self_signed_certificate_success( - self, mock_load_config: MagicMock, mock_template_server_config: TemplateServerConfig, tmp_path: Path + self, mock_example_server: MagicMock, mock_template_server_config: TemplateServerConfig, tmp_path: Path ) -> None: """Test successful certificate generation via wrapper function.""" - # Use tmp_path for certificate directory to avoid permission issues - cert_dir = tmp_path / "certs" - mock_template_server_config.certificate.directory = str(cert_dir) - mock_load_config.return_value = mock_template_server_config - with ( patch.object(CertificateHandler, "write_to_key_file") as mock_write_key, patch.object(CertificateHandler, "write_to_cert_file") as mock_write_cert, ): generate_self_signed_certificate() - # Verify config was loaded - mock_load_config.assert_called_once() - # Verify write methods were called with PEM-encoded data mock_write_key.assert_called_once() key_data = mock_write_key.call_args[0][0] @@ -191,17 +188,12 @@ def test_generate_self_signed_certificate_success( def test_generate_self_signed_certificate_os_error( self, - mock_load_config: MagicMock, + mock_example_server: MagicMock, mock_template_server_config: TemplateServerConfig, mock_sys_exit: MagicMock, tmp_path: Path, ) -> None: """Test certificate generation wrapper handles OSError.""" - # Use tmp_path for certificate directory - cert_dir = tmp_path / "certs" - mock_template_server_config.certificate.directory = str(cert_dir) - mock_load_config.return_value = mock_template_server_config - with patch( "python_template_server.certificate_handler.CertificateHandler.generate_self_signed_cert", side_effect=OSError("Disk error"), @@ -213,17 +205,12 @@ def test_generate_self_signed_certificate_os_error( def test_generate_self_signed_certificate_permission_error( self, - mock_load_config: MagicMock, + mock_example_server: MagicMock, mock_template_server_config: TemplateServerConfig, mock_sys_exit: MagicMock, tmp_path: Path, ) -> None: """Test certificate generation wrapper handles PermissionError.""" - # Use tmp_path for certificate directory - cert_dir = tmp_path / "certs" - mock_template_server_config.certificate.directory = str(cert_dir) - mock_load_config.return_value = mock_template_server_config - with patch( "python_template_server.certificate_handler.CertificateHandler.generate_self_signed_cert", side_effect=PermissionError("No permission"), @@ -235,17 +222,12 @@ def test_generate_self_signed_certificate_permission_error( def test_generate_self_signed_certificate_unexpected_error( self, - mock_load_config: MagicMock, + mock_example_server: MagicMock, mock_template_server_config: TemplateServerConfig, mock_sys_exit: MagicMock, tmp_path: Path, ) -> None: """Test certificate generation wrapper handles unexpected exceptions.""" - # Use tmp_path for certificate directory - cert_dir = tmp_path / "certs" - mock_template_server_config.certificate.directory = str(cert_dir) - mock_load_config.return_value = mock_template_server_config - with patch( "python_template_server.certificate_handler.CertificateHandler.generate_self_signed_cert", side_effect=RuntimeError("Unexpected error"), diff --git a/tests/test_config.py b/tests/test_config.py deleted file mode 100644 index 12d84d5..0000000 --- a/tests/test_config.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Unit tests for the python_template_server.config module.""" - -import json -import logging -from unittest.mock import MagicMock - -import pytest - -from python_template_server.config import load_config, setup_logging -from python_template_server.constants import CONFIG_FILE_NAME -from python_template_server.models import TemplateServerConfig - - -class TestSetupLogging: - """Tests for the setup_logging function.""" - - def test_setup_logging_creates_log_directory(self, mock_mkdir: MagicMock) -> None: - """Test that setup_logging creates the log directory.""" - setup_logging() - mock_mkdir.assert_called_once_with(exist_ok=True) - - def test_setup_logging_configures_handlers(self) -> None: - """Test that setup_logging configures both console and file handlers.""" - expected_handlers = ["StreamHandler", "RotatingFileHandler"] - - setup_logging() - - root_logger = logging.getLogger() - assert len(root_logger.handlers) == len(expected_handlers) - - # Check handler types - handler_types = [type(handler).__name__ for handler in root_logger.handlers] - for expected_handler in expected_handlers: - assert expected_handler in handler_types - - def test_setup_logging_handlers_have_formatters(self) -> None: - """Test that all handlers have formatters configured.""" - setup_logging() - - root_logger = logging.getLogger() - for handler in root_logger.handlers: - assert handler.formatter is not None - assert handler.formatter._fmt is not None - assert "[%(asctime)s]" in handler.formatter._fmt - - -class TestLoadConfig: - """Tests for the load_config function.""" - - def test_load_config_success( - self, - mock_exists: MagicMock, - mock_open_file: MagicMock, - mock_sys_exit: MagicMock, - mock_template_server_config: TemplateServerConfig, - ) -> None: - """Test successful loading of config.""" - mock_exists.return_value = True - mock_open_file.return_value.read.return_value = json.dumps(mock_template_server_config.model_dump()) - - config = load_config(CONFIG_FILE_NAME) - - assert isinstance(config, TemplateServerConfig) - assert config == mock_template_server_config - mock_sys_exit.assert_not_called() - - def test_load_config_file_not_found( - self, - mock_exists: MagicMock, - mock_sys_exit: MagicMock, - ) -> None: - """Test loading config when the file does not exist.""" - mock_exists.return_value = False - - with pytest.raises(SystemExit): - load_config(CONFIG_FILE_NAME) - - mock_sys_exit.assert_called_once_with(1) - - def test_load_config_invalid_json( - self, - mock_exists: MagicMock, - mock_open_file: MagicMock, - mock_sys_exit: MagicMock, - ) -> None: - """Test loading config with invalid JSON content.""" - mock_exists.return_value = True - mock_open_file.return_value.read.return_value = "invalid json" - - with pytest.raises(SystemExit): - load_config(CONFIG_FILE_NAME) - - mock_sys_exit.assert_called_with(1) - - def test_load_config_os_error( - self, - mock_exists: MagicMock, - mock_open_file: MagicMock, - mock_sys_exit: MagicMock, - ) -> None: - """Test loading config that raises an OSError.""" - mock_exists.return_value = True - mock_open_file.side_effect = OSError("File read error") - - with pytest.raises(SystemExit): - load_config(CONFIG_FILE_NAME) - - mock_sys_exit.assert_called_with(1) - - def test_load_config_validation_error( - self, - mock_exists: MagicMock, - mock_open_file: MagicMock, - mock_sys_exit: MagicMock, - ) -> None: - """Test loading config that fails validation.""" - mock_exists.return_value = True - mock_open_file.return_value.read.return_value = json.dumps({"server": {"host": "localhost", "port": 999999}}) - - with pytest.raises(SystemExit): - load_config(CONFIG_FILE_NAME) - - mock_sys_exit.assert_called_once_with(1) diff --git a/tests/test_logging_setup.py b/tests/test_logging_setup.py new file mode 100644 index 0000000..39bca2a --- /dev/null +++ b/tests/test_logging_setup.py @@ -0,0 +1,39 @@ +"""Unit tests for the python_template_server.logging_setup module.""" + +import logging +from unittest.mock import MagicMock + +from python_template_server.logging_setup import setup_logging + + +class TestSetupLogging: + """Tests for the setup_logging function.""" + + def test_setup_logging_creates_log_directory(self, mock_mkdir: MagicMock) -> None: + """Test that setup_logging creates the log directory.""" + setup_logging() + mock_mkdir.assert_called_once_with(exist_ok=True) + + def test_setup_logging_configures_handlers(self) -> None: + """Test that setup_logging configures both console and file handlers.""" + expected_handlers = ["StreamHandler", "RotatingFileHandler"] + + setup_logging() + + root_logger = logging.getLogger() + assert len(root_logger.handlers) == len(expected_handlers) + + # Check handler types + handler_types = [type(handler).__name__ for handler in root_logger.handlers] + for expected_handler in expected_handlers: + assert expected_handler in handler_types + + def test_setup_logging_handlers_have_formatters(self) -> None: + """Test that all handlers have formatters configured.""" + setup_logging() + + root_logger = logging.getLogger() + for handler in root_logger.handlers: + assert handler.formatter is not None + assert handler.formatter._fmt is not None + assert "[%(asctime)s]" in handler.formatter._fmt diff --git a/tests/test_main.py b/tests/test_main.py index e6d3796..5fa8cc0 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -12,26 +12,18 @@ @pytest.fixture -def mock_load_config(mock_template_server_config: TemplateServerConfig) -> Generator[MagicMock, None, None]: - """Mock the load_config function.""" - with patch("python_template_server.main.load_config") as mock_config: - mock_config.return_value = mock_template_server_config - yield mock_config - - -@pytest.fixture -def mock_template_server_class() -> Generator[MagicMock, None, None]: +def mock_template_server_class(mock_template_server_config: TemplateServerConfig) -> Generator[MagicMock, None, None]: """Mock TemplateServer class.""" with patch("python_template_server.main.ExampleServer") as mock_server: + mock_server.load_config.return_value = mock_template_server_config yield mock_server class TestRun: """Unit tests for the run function.""" - def test_run(self, mock_load_config: MagicMock, mock_template_server_class: MagicMock) -> None: + def test_run(self, mock_template_server_class: MagicMock) -> None: """Test successful server run.""" run() - mock_load_config.assert_called_once() mock_template_server_class.return_value.run.assert_called_once() diff --git a/tests/test_models.py b/tests/test_models.py index 7b6a76b..c554cfc 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -34,10 +34,6 @@ def test_url_property(self, mock_server_config: ServerConfigModel) -> None: """Test the url property.""" assert mock_server_config.url == "https://localhost:8080" - def test_full_url_property(self, mock_server_config: ServerConfigModel) -> None: - """Test the full_url property.""" - assert mock_server_config.full_url == "https://localhost:8080/api" - @pytest.mark.parametrize("port", [0, 70000]) def test_port_field(self, mock_server_config_dict: dict, port: int) -> None: """Test the port field validation.""" diff --git a/tests/test_template_server.py b/tests/test_template_server.py index e8fe740..3236d72 100644 --- a/tests/test_template_server.py +++ b/tests/test_template_server.py @@ -5,6 +5,8 @@ import asyncio import json from collections.abc import Generator +from importlib.metadata import PackageMetadata +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -26,6 +28,21 @@ from python_template_server.template_server import TemplateServer +@pytest.fixture(autouse=True) +def mock_package_metadata() -> Generator[MagicMock, None, None]: + """Mock importlib.metadata.metadata to return a mock PackageMetadata.""" + with patch("python_template_server.template_server.metadata") as mock_metadata: + mock_pkg_metadata = MagicMock(spec=PackageMetadata) + metadata_dict = { + "Name": "python-template-server", + "Version": "0.1.0", + "Summary": "A template FastAPI server with authentication, rate limiting and Prometheus metrics.", + } + mock_pkg_metadata.__getitem__.side_effect = lambda key: metadata_dict[key] + mock_metadata.return_value = mock_pkg_metadata + yield mock_metadata + + @pytest.fixture def mock_verify_token() -> Generator[MagicMock, None, None]: """Mock the verify_token function.""" @@ -52,7 +69,7 @@ def mock_timestamp() -> Generator[str, None, None]: @pytest.fixture def mock_template_server(mock_template_server_config: TemplateServerConfig) -> MockTemplateServer: """Provide a MockTemplateServer instance for testing.""" - return MockTemplateServer(mock_template_server_config) + return MockTemplateServer(config=mock_template_server_config) class MockTemplateServer(TemplateServer): @@ -70,6 +87,14 @@ def mock_protected_method(self, request: Request) -> BaseResponse: code=ResponseCode.OK, message="protected endpoint", timestamp=BaseResponse.current_timestamp() ) + def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig: + """Validate configuration from the config.json file. + + :param dict config_data: Configuration data + :return TemplateServerConfig: Loaded configuration + """ + return super().validate_config(config_data) + def setup_routes(self) -> None: """Set up mock routes for testing.""" super().setup_routes() @@ -99,6 +124,85 @@ def test_request_middleware_added(self, mock_template_server: TemplateServer) -> assert SecurityHeadersMiddleware in middlewares +class TestLoadConfig: + """Tests for the load_config function.""" + + def test_load_config_success( + self, + mock_exists: MagicMock, + mock_open_file: MagicMock, + mock_sys_exit: MagicMock, + mock_template_server_config: TemplateServerConfig, + ) -> None: + """Test successful loading of config.""" + mock_exists.return_value = True + mock_open_file.return_value.read.return_value = json.dumps(mock_template_server_config.model_dump()) + + config = MockTemplateServer().config + + assert isinstance(config, TemplateServerConfig) + assert config == mock_template_server_config + mock_sys_exit.assert_not_called() + + def test_load_config_file_not_found( + self, + mock_exists: MagicMock, + mock_sys_exit: MagicMock, + ) -> None: + """Test loading config when the file does not exist.""" + mock_exists.return_value = False + + with pytest.raises(SystemExit): + MockTemplateServer() + + mock_sys_exit.assert_called_once_with(1) + + def test_load_config_invalid_json( + self, + mock_exists: MagicMock, + mock_open_file: MagicMock, + mock_sys_exit: MagicMock, + ) -> None: + """Test loading config with invalid JSON content.""" + mock_exists.return_value = True + mock_open_file.return_value.read.return_value = "invalid json" + + with pytest.raises(SystemExit): + MockTemplateServer() + + mock_sys_exit.assert_called_with(1) + + def test_load_config_os_error( + self, + mock_exists: MagicMock, + mock_open_file: MagicMock, + mock_sys_exit: MagicMock, + ) -> None: + """Test loading config that raises an OSError.""" + mock_exists.return_value = True + mock_open_file.side_effect = OSError("File read error") + + with pytest.raises(SystemExit): + MockTemplateServer() + + mock_sys_exit.assert_called_with(1) + + def test_load_config_validation_error( + self, + mock_exists: MagicMock, + mock_open_file: MagicMock, + mock_sys_exit: MagicMock, + ) -> None: + """Test loading config that fails validation.""" + mock_exists.return_value = True + mock_open_file.return_value.read.return_value = json.dumps({"server": {"host": "localhost", "port": 999999}}) + + with pytest.raises(SystemExit): + MockTemplateServer() + + mock_sys_exit.assert_called_once_with(1) + + class TestVerifyApiKey: """Unit tests for the _verify_api_key method.""" @@ -243,14 +347,14 @@ def test_setup_rate_limiting_enabled(self, mock_template_server_config: Template """Test rate limiting setup when enabled.""" mock_template_server_config.rate_limit.enabled = True - server = MockTemplateServer(mock_template_server_config) + server = MockTemplateServer(config=mock_template_server_config) assert server.limiter is not None assert server.app.state.limiter is not None def test_setup_rate_limiting_disabled(self, mock_template_server_config: TemplateServerConfig) -> None: """Test rate limiting setup when disabled.""" - server = MockTemplateServer(mock_template_server_config) + server = MockTemplateServer(config=mock_template_server_config) assert server.limiter is None @@ -258,7 +362,7 @@ def test_limit_route_with_limiter_enabled(self, mock_template_server_config: Tem """Test _limit_route when rate limiting is enabled.""" mock_template_server_config.rate_limit.enabled = True - server = MockTemplateServer(mock_template_server_config) + server = MockTemplateServer(config=mock_template_server_config) limited_route = server._limit_route(server.mock_unprotected_method) assert limited_route != server.mock_unprotected_method @@ -266,7 +370,7 @@ def test_limit_route_with_limiter_enabled(self, mock_template_server_config: Tem def test_limit_route_with_limiter_disabled(self, mock_template_server_config: TemplateServerConfig) -> None: """Test _limit_route when rate limiting is disabled.""" - server = MockTemplateServer(mock_template_server_config) + server = MockTemplateServer(config=mock_template_server_config) limited_route = server._limit_route(server.mock_unprotected_method) assert limited_route == server.mock_unprotected_method