Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions python_template_server/template_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import AsyncGenerator, Callable
from contextlib import asynccontextmanager
from importlib.metadata import metadata
from pathlib import Path
from typing import Any

import uvicorn
Expand Down Expand Up @@ -41,12 +42,17 @@ class TemplateServer(ABC):
"""

def __init__(
self, package_name: str = PACKAGE_NAME, api_prefix: str = API_PREFIX, config: TemplateServerConfig | None = None
self,
package_name: str = PACKAGE_NAME,
api_prefix: str = API_PREFIX,
config_filepath: Path = CONFIG_DIR / CONFIG_FILE_NAME,
config: TemplateServerConfig | None = None,
) -> None:
"""Initialize the TemplateServer.

:param str package_name: The package name for metadata retrieval
:param str api_prefix: The API prefix for the server
:param Path config_filepath: Path to the configuration file
:param TemplateServerConfig | None config: Optional pre-loaded configuration
"""
self.api_prefix = api_prefix
Expand All @@ -60,7 +66,7 @@ def __init__(
)
self.api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)

self.config = config or self.load_config()
self.config = config or self.load_config(config_filepath)
self.hashed_token = load_hashed_token()
self._setup_request_logging()
self._setup_security_headers()
Expand All @@ -84,33 +90,32 @@ def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig:
"""
return TemplateServerConfig.model_validate(config_data)

def load_config(self, config_file: str = CONFIG_FILE_NAME) -> TemplateServerConfig:
def load_config(self, config_filepath: Path) -> TemplateServerConfig:
"""Load configuration from the specified json file.

:param str config_file: Name of the configuration file
:param Path config_filepath: Path to the configuration file
:return TemplateServerConfig: The validated configuration model
:raise SystemExit: If configuration file is missing, invalid JSON, or fails validation
"""
config_path = CONFIG_DIR / config_file
if not config_path.exists():
logger.error("Configuration file not found: %s", config_path)
if not config_filepath.exists():
logger.error("Configuration file not found: %s", config_filepath)
sys.exit(1)

config_data = {}
try:
with config_path.open() as f:
with config_filepath.open() as f:
config_data = json.load(f)
except json.JSONDecodeError:
logger.exception("JSON parsing error: %s", config_path)
logger.exception("JSON parsing error: %s", config_filepath)
sys.exit(1)
except OSError:
logger.exception("JSON read error: %s", config_path)
logger.exception("JSON read error: %s", config_filepath)
sys.exit(1)

try:
return self.validate_config(config_data)
except ValidationError:
logger.exception("Invalid configuration in: %s", config_path)
logger.exception("Invalid configuration in: %s", config_filepath)
sys.exit(1)

async def _verify_api_key(
Expand Down
14 changes: 13 additions & 1 deletion tests/test_template_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
from collections.abc import Generator
from importlib.metadata import PackageMetadata
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

Expand Down Expand Up @@ -127,7 +128,18 @@ def test_request_middleware_added(self, mock_template_server: TemplateServer) ->
class TestLoadConfig:
"""Tests for the load_config function."""

def test_load_config_success(
def test_load_config_with_filepath_success(self, mock_template_server_config: TemplateServerConfig) -> None:
"""Test that load_config is called with the specified filepath when config is None."""
with patch.object(
MockTemplateServer, "load_config", return_value=mock_template_server_config
) as mock_load_config:
custom_filepath = Path("/custom/config.json")
server = MockTemplateServer(config_filepath=custom_filepath)

mock_load_config.assert_called_once_with(custom_filepath)
assert server.config == mock_template_server_config

def test_load_config_with_no_filepath_success(
self,
mock_exists: MagicMock,
mock_open_file: MagicMock,
Expand Down