Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions python_template_server/template_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,35 +268,47 @@ def run(self) -> None:
sys.exit(1)

def add_unauthenticated_route(
self, endpoint: str, handler_function: Callable, response_model: type[BaseModel], methods: list[str]
self,
endpoint: str,
handler_function: Callable,
response_model: type[BaseModel],
methods: list[str],
limited: bool = True, # noqa: FBT001, FBT002
) -> None:
"""Add an unauthenticated API route.

:param str endpoint: The API endpoint path
:param Callable handler_function: The handler function for the endpoint
:param BaseModel response_model: The Pydantic model for the response
:param list[str] methods: The HTTP methods for the endpoint
:param bool limited: Whether to apply rate limiting to this route
"""
self.app.add_api_route(
endpoint,
self._limit_route(handler_function),
self._limit_route(handler_function) if limited else handler_function,
methods=methods,
response_model=response_model,
)

def add_authenticated_route(
self, endpoint: str, handler_function: Callable, response_model: type[BaseModel], methods: list[str]
self,
endpoint: str,
handler_function: Callable,
response_model: type[BaseModel],
methods: list[str],
limited: bool = True, # noqa: FBT001, FBT002
) -> None:
"""Add an authenticated API route.

:param str endpoint: The API endpoint path
:param Callable handler_function: The handler function for the endpoint
:param BaseModel response_model: The Pydantic model for the response
:param list[str] methods: The HTTP methods for the endpoint
:param bool limited: Whether to apply rate limiting to this route
"""
self.app.add_api_route(
endpoint,
self._limit_route(handler_function),
self._limit_route(handler_function) if limited else handler_function,
methods=methods,
response_model=response_model,
dependencies=[Security(self._verify_api_key)],
Expand All @@ -315,7 +327,7 @@ def setup_routes(self) -> None:
```

"""
self.add_unauthenticated_route("/health", self.get_health, GetHealthResponse, ["GET"])
self.add_unauthenticated_route("/health", self.get_health, GetHealthResponse, ["GET"], limited=False)

async def get_health(self, request: Request) -> GetHealthResponse:
"""Get server health.
Expand Down
81 changes: 80 additions & 1 deletion tests/test_template_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,18 @@ def mock_protected_method(self, request: Request) -> BaseResponse:
code=ResponseCode.OK, message="protected endpoint", timestamp=BaseResponse.current_timestamp()
)

def mock_unlimited_unprotected_method(self, request: Request) -> BaseResponse:
"""Mock unlimited unprotected method."""
return BaseResponse(
code=ResponseCode.OK, message="unlimited unprotected endpoint", timestamp=BaseResponse.current_timestamp()
)

def mock_unlimited_protected_method(self, request: Request) -> BaseResponse:
"""Mock unlimited protected method."""
return BaseResponse(
code=ResponseCode.OK, message="unlimited protected endpoint", timestamp=BaseResponse.current_timestamp()
)

def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig:
"""Validate configuration from the config.json file.

Expand All @@ -104,6 +116,20 @@ def setup_routes(self) -> None:
super().setup_routes()
self.add_unauthenticated_route("/unauthenticated-endpoint", self.mock_unprotected_method, BaseResponse, ["GET"])
self.add_authenticated_route("/authenticated-endpoint", self.mock_protected_method, BaseResponse, ["POST"])
self.add_unauthenticated_route(
"/unlimited-unauthenticated-endpoint",
self.mock_unlimited_unprotected_method,
BaseResponse,
["GET"],
limited=False,
)
self.add_authenticated_route(
"/unlimited-authenticated-endpoint",
self.mock_unlimited_protected_method,
BaseResponse,
["POST"],
limited=False,
)


class TestTemplateServer:
Expand Down Expand Up @@ -485,11 +511,64 @@ def test_add_authenticated_route(self, mock_template_server: MockTemplateServer)
assert "POST" in test_route.methods
assert test_route.response_model == BaseResponse

def test_limited_parameter_with_rate_limiting_enabled(
self, mock_template_server_config: TemplateServerConfig
) -> None:
"""Test that limited=True applies rate limiting when limiter is enabled."""
mock_template_server_config.rate_limit.enabled = True
server = MockTemplateServer(config=mock_template_server_config)

# Get the limited routes
api_routes = [route for route in server.app.routes if isinstance(route, APIRoute)]
limited_route = next((route for route in api_routes if route.path == "/unauthenticated-endpoint"), None)
unlimited_route = next(
(route for route in api_routes if route.path == "/unlimited-unauthenticated-endpoint"), None
)

assert limited_route is not None
assert unlimited_route is not None

# Limited route should have the limiter wrapper
assert hasattr(limited_route.endpoint, "__wrapped__")
# Unlimited route should not have the limiter wrapper
assert not hasattr(unlimited_route.endpoint, "__wrapped__")

def test_authenticated_route_limited_parameter(self, mock_template_server_config: TemplateServerConfig) -> None:
"""Test that limited parameter works correctly for authenticated routes."""
mock_template_server_config.rate_limit.enabled = True
server = MockTemplateServer(config=mock_template_server_config)

# Get the authenticated routes
api_routes = [route for route in server.app.routes if isinstance(route, APIRoute)]
limited_route = next((route for route in api_routes if route.path == "/authenticated-endpoint"), None)
unlimited_route = next(
(route for route in api_routes if route.path == "/unlimited-authenticated-endpoint"), None
)

assert limited_route is not None
assert unlimited_route is not None

# Both routes should have authentication dependencies
assert len(limited_route.dependencies) > 0
assert len(unlimited_route.dependencies) > 0

# Limited route should have the limiter wrapper
assert hasattr(limited_route.endpoint, "__wrapped__")
# Unlimited route should not have the limiter wrapper
assert not hasattr(unlimited_route.endpoint, "__wrapped__")

def test_setup_routes(self, mock_template_server: MockTemplateServer) -> None:
"""Test that routes are set up correctly."""
api_routes = [route for route in mock_template_server.app.routes if isinstance(route, APIRoute)]
routes = [route.path for route in api_routes]
expected_endpoints = ["/health", "/metrics", "/unauthenticated-endpoint", "/authenticated-endpoint"]
expected_endpoints = [
"/health",
"/metrics",
"/unauthenticated-endpoint",
"/authenticated-endpoint",
"/unlimited-unauthenticated-endpoint",
"/unlimited-authenticated-endpoint",
]
for endpoint in expected_endpoints:
assert endpoint in routes

Expand Down