Skip to content

Commit ff04d4e

Browse files
Make route limiting optional for unauthenticated and authenticated routes (#14)
1 parent f15aac6 commit ff04d4e

File tree

2 files changed

+97
-6
lines changed

2 files changed

+97
-6
lines changed

python_template_server/template_server.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,35 +268,47 @@ def run(self) -> None:
268268
sys.exit(1)
269269

270270
def add_unauthenticated_route(
271-
self, endpoint: str, handler_function: Callable, response_model: type[BaseModel], methods: list[str]
271+
self,
272+
endpoint: str,
273+
handler_function: Callable,
274+
response_model: type[BaseModel],
275+
methods: list[str],
276+
limited: bool = True, # noqa: FBT001, FBT002
272277
) -> None:
273278
"""Add an unauthenticated API route.
274279
275280
:param str endpoint: The API endpoint path
276281
:param Callable handler_function: The handler function for the endpoint
277282
:param BaseModel response_model: The Pydantic model for the response
278283
:param list[str] methods: The HTTP methods for the endpoint
284+
:param bool limited: Whether to apply rate limiting to this route
279285
"""
280286
self.app.add_api_route(
281287
endpoint,
282-
self._limit_route(handler_function),
288+
self._limit_route(handler_function) if limited else handler_function,
283289
methods=methods,
284290
response_model=response_model,
285291
)
286292

287293
def add_authenticated_route(
288-
self, endpoint: str, handler_function: Callable, response_model: type[BaseModel], methods: list[str]
294+
self,
295+
endpoint: str,
296+
handler_function: Callable,
297+
response_model: type[BaseModel],
298+
methods: list[str],
299+
limited: bool = True, # noqa: FBT001, FBT002
289300
) -> None:
290301
"""Add an authenticated API route.
291302
292303
:param str endpoint: The API endpoint path
293304
:param Callable handler_function: The handler function for the endpoint
294305
:param BaseModel response_model: The Pydantic model for the response
295306
:param list[str] methods: The HTTP methods for the endpoint
307+
:param bool limited: Whether to apply rate limiting to this route
296308
"""
297309
self.app.add_api_route(
298310
endpoint,
299-
self._limit_route(handler_function),
311+
self._limit_route(handler_function) if limited else handler_function,
300312
methods=methods,
301313
response_model=response_model,
302314
dependencies=[Security(self._verify_api_key)],
@@ -315,7 +327,7 @@ def setup_routes(self) -> None:
315327
```
316328
317329
"""
318-
self.add_unauthenticated_route("/health", self.get_health, GetHealthResponse, ["GET"])
330+
self.add_unauthenticated_route("/health", self.get_health, GetHealthResponse, ["GET"], limited=False)
319331

320332
async def get_health(self, request: Request) -> GetHealthResponse:
321333
"""Get server health.

tests/test_template_server.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,18 @@ def mock_protected_method(self, request: Request) -> BaseResponse:
9191
code=ResponseCode.OK, message="protected endpoint", timestamp=BaseResponse.current_timestamp()
9292
)
9393

94+
def mock_unlimited_unprotected_method(self, request: Request) -> BaseResponse:
95+
"""Mock unlimited unprotected method."""
96+
return BaseResponse(
97+
code=ResponseCode.OK, message="unlimited unprotected endpoint", timestamp=BaseResponse.current_timestamp()
98+
)
99+
100+
def mock_unlimited_protected_method(self, request: Request) -> BaseResponse:
101+
"""Mock unlimited protected method."""
102+
return BaseResponse(
103+
code=ResponseCode.OK, message="unlimited protected endpoint", timestamp=BaseResponse.current_timestamp()
104+
)
105+
94106
def validate_config(self, config_data: dict[str, Any]) -> TemplateServerConfig:
95107
"""Validate configuration from the config.json file.
96108
@@ -104,6 +116,20 @@ def setup_routes(self) -> None:
104116
super().setup_routes()
105117
self.add_unauthenticated_route("/unauthenticated-endpoint", self.mock_unprotected_method, BaseResponse, ["GET"])
106118
self.add_authenticated_route("/authenticated-endpoint", self.mock_protected_method, BaseResponse, ["POST"])
119+
self.add_unauthenticated_route(
120+
"/unlimited-unauthenticated-endpoint",
121+
self.mock_unlimited_unprotected_method,
122+
BaseResponse,
123+
["GET"],
124+
limited=False,
125+
)
126+
self.add_authenticated_route(
127+
"/unlimited-authenticated-endpoint",
128+
self.mock_unlimited_protected_method,
129+
BaseResponse,
130+
["POST"],
131+
limited=False,
132+
)
107133

108134

109135
class TestTemplateServer:
@@ -485,11 +511,64 @@ def test_add_authenticated_route(self, mock_template_server: MockTemplateServer)
485511
assert "POST" in test_route.methods
486512
assert test_route.response_model == BaseResponse
487513

514+
def test_limited_parameter_with_rate_limiting_enabled(
515+
self, mock_template_server_config: TemplateServerConfig
516+
) -> None:
517+
"""Test that limited=True applies rate limiting when limiter is enabled."""
518+
mock_template_server_config.rate_limit.enabled = True
519+
server = MockTemplateServer(config=mock_template_server_config)
520+
521+
# Get the limited routes
522+
api_routes = [route for route in server.app.routes if isinstance(route, APIRoute)]
523+
limited_route = next((route for route in api_routes if route.path == "/unauthenticated-endpoint"), None)
524+
unlimited_route = next(
525+
(route for route in api_routes if route.path == "/unlimited-unauthenticated-endpoint"), None
526+
)
527+
528+
assert limited_route is not None
529+
assert unlimited_route is not None
530+
531+
# Limited route should have the limiter wrapper
532+
assert hasattr(limited_route.endpoint, "__wrapped__")
533+
# Unlimited route should not have the limiter wrapper
534+
assert not hasattr(unlimited_route.endpoint, "__wrapped__")
535+
536+
def test_authenticated_route_limited_parameter(self, mock_template_server_config: TemplateServerConfig) -> None:
537+
"""Test that limited parameter works correctly for authenticated routes."""
538+
mock_template_server_config.rate_limit.enabled = True
539+
server = MockTemplateServer(config=mock_template_server_config)
540+
541+
# Get the authenticated routes
542+
api_routes = [route for route in server.app.routes if isinstance(route, APIRoute)]
543+
limited_route = next((route for route in api_routes if route.path == "/authenticated-endpoint"), None)
544+
unlimited_route = next(
545+
(route for route in api_routes if route.path == "/unlimited-authenticated-endpoint"), None
546+
)
547+
548+
assert limited_route is not None
549+
assert unlimited_route is not None
550+
551+
# Both routes should have authentication dependencies
552+
assert len(limited_route.dependencies) > 0
553+
assert len(unlimited_route.dependencies) > 0
554+
555+
# Limited route should have the limiter wrapper
556+
assert hasattr(limited_route.endpoint, "__wrapped__")
557+
# Unlimited route should not have the limiter wrapper
558+
assert not hasattr(unlimited_route.endpoint, "__wrapped__")
559+
488560
def test_setup_routes(self, mock_template_server: MockTemplateServer) -> None:
489561
"""Test that routes are set up correctly."""
490562
api_routes = [route for route in mock_template_server.app.routes if isinstance(route, APIRoute)]
491563
routes = [route.path for route in api_routes]
492-
expected_endpoints = ["/health", "/metrics", "/unauthenticated-endpoint", "/authenticated-endpoint"]
564+
expected_endpoints = [
565+
"/health",
566+
"/metrics",
567+
"/unauthenticated-endpoint",
568+
"/authenticated-endpoint",
569+
"/unlimited-unauthenticated-endpoint",
570+
"/unlimited-authenticated-endpoint",
571+
]
493572
for endpoint in expected_endpoints:
494573
assert endpoint in routes
495574

0 commit comments

Comments
 (0)