diff --git a/python_template_server/template_server.py b/python_template_server/template_server.py index 100ec97..eb3520c 100644 --- a/python_template_server/template_server.py +++ b/python_template_server/template_server.py @@ -7,6 +7,7 @@ from collections.abc import AsyncGenerator, Callable from contextlib import asynccontextmanager from importlib.metadata import metadata +from pathlib import Path from typing import Any import uvicorn @@ -41,12 +42,17 @@ class TemplateServer(ABC): """ def __init__( - self, package_name: str = PACKAGE_NAME, api_prefix: str = API_PREFIX, config: TemplateServerConfig | None = None + self, + package_name: str = PACKAGE_NAME, + api_prefix: str = API_PREFIX, + config_filepath: Path = CONFIG_DIR / CONFIG_FILE_NAME, + config: TemplateServerConfig | None = None, ) -> None: """Initialize the TemplateServer. :param str package_name: The package name for metadata retrieval :param str api_prefix: The API prefix for the server + :param Path config_filepath: Path to the configuration file :param TemplateServerConfig | None config: Optional pre-loaded configuration """ self.api_prefix = api_prefix @@ -60,7 +66,7 @@ def __init__( ) self.api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False) - self.config = config or self.load_config() + self.config = config or self.load_config(config_filepath) self.hashed_token = load_hashed_token() self._setup_request_logging() self._setup_security_headers() @@ -84,33 +90,32 @@ def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig: """ return TemplateServerConfig.model_validate(config_data) - def load_config(self, config_file: str = CONFIG_FILE_NAME) -> TemplateServerConfig: + def load_config(self, config_filepath: Path) -> TemplateServerConfig: """Load configuration from the specified json file. - :param str config_file: Name of the configuration file + :param Path config_filepath: Path to the configuration file :return TemplateServerConfig: The validated configuration model :raise SystemExit: If configuration file is missing, invalid JSON, or fails validation """ - config_path = CONFIG_DIR / config_file - if not config_path.exists(): - logger.error("Configuration file not found: %s", config_path) + if not config_filepath.exists(): + logger.error("Configuration file not found: %s", config_filepath) sys.exit(1) config_data = {} try: - with config_path.open() as f: + with config_filepath.open() as f: config_data = json.load(f) except json.JSONDecodeError: - logger.exception("JSON parsing error: %s", config_path) + logger.exception("JSON parsing error: %s", config_filepath) sys.exit(1) except OSError: - logger.exception("JSON read error: %s", config_path) + logger.exception("JSON read error: %s", config_filepath) sys.exit(1) try: return self.validate_config(config_data) except ValidationError: - logger.exception("Invalid configuration in: %s", config_path) + logger.exception("Invalid configuration in: %s", config_filepath) sys.exit(1) async def _verify_api_key( diff --git a/tests/test_template_server.py b/tests/test_template_server.py index 3236d72..a074881 100644 --- a/tests/test_template_server.py +++ b/tests/test_template_server.py @@ -6,6 +6,7 @@ import json from collections.abc import Generator from importlib.metadata import PackageMetadata +from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch @@ -127,7 +128,18 @@ def test_request_middleware_added(self, mock_template_server: TemplateServer) -> class TestLoadConfig: """Tests for the load_config function.""" - def test_load_config_success( + def test_load_config_with_filepath_success(self, mock_template_server_config: TemplateServerConfig) -> None: + """Test that load_config is called with the specified filepath when config is None.""" + with patch.object( + MockTemplateServer, "load_config", return_value=mock_template_server_config + ) as mock_load_config: + custom_filepath = Path("/custom/config.json") + server = MockTemplateServer(config_filepath=custom_filepath) + + mock_load_config.assert_called_once_with(custom_filepath) + assert server.config == mock_template_server_config + + def test_load_config_with_no_filepath_success( self, mock_exists: MagicMock, mock_open_file: MagicMock,