Skip to content

Commit 6a06579

Browse files
authored
Update request state usage to context variable (#853)
* Update request state usage to context variable * Update context to custom ctx * Fix the exception interception in opera log * restore elapsed
1 parent dd08775 commit 6a06579

File tree

17 files changed

+135
-91
lines changed

17 files changed

+135
-91
lines changed

backend/app/admin/service/login_log_service.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from backend.app.admin.crud.crud_login_log import login_log_dao
88
from backend.app.admin.schema.login_log import CreateLoginLogParam, DeleteLoginLogParam
9+
from backend.common.context import ctx
910
from backend.common.log import log
1011
from backend.database.db import async_db_session
1112

@@ -53,14 +54,14 @@ async def create(
5354
user_uuid=user_uuid,
5455
username=username,
5556
status=status,
56-
ip=request.state.ip,
57-
country=request.state.country,
58-
region=request.state.region,
59-
city=request.state.city,
60-
user_agent=request.state.user_agent,
61-
browser=request.state.browser,
62-
os=request.state.os,
63-
device=request.state.device,
57+
ip=ctx.ip,
58+
country=ctx.country,
59+
region=ctx.region,
60+
city=ctx.city,
61+
user_agent=ctx.user_agent,
62+
browser=ctx.browser,
63+
os=ctx.os,
64+
device=ctx.device,
6465
msg=msg,
6566
login_time=login_time,
6667
)

backend/app/admin/service/user_service.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ResetPasswordParam,
1515
UpdateUserParam,
1616
)
17+
from backend.common.context import ctx
1718
from backend.common.enums import UserPermissionType
1819
from backend.common.exception import errors
1920
from backend.common.response.response_code import CustomErrorCode
@@ -249,12 +250,12 @@ async def update_email(*, request: Request, captcha: str, email: str) -> int:
249250
user = await user_dao.get(db, token_payload.id)
250251
if not user:
251252
raise errors.NotFoundError(msg='用户不存在')
252-
captcha_code = await redis_client.get(f'{settings.EMAIL_CAPTCHA_REDIS_PREFIX}:{request.state.ip}')
253+
captcha_code = await redis_client.get(f'{settings.EMAIL_CAPTCHA_REDIS_PREFIX}:{ctx.ip}')
253254
if not captcha_code:
254255
raise errors.RequestError(msg='验证码已失效,请重新获取')
255256
if captcha != captcha_code:
256257
raise errors.CustomError(error=CustomErrorCode.CAPTCHA_ERROR)
257-
await redis_client.delete(f'{settings.EMAIL_CAPTCHA_REDIS_PREFIX}:{request.state.ip}')
258+
await redis_client.delete(f'{settings.EMAIL_CAPTCHA_REDIS_PREFIX}:{ctx.ip}')
258259
count = await user_dao.update_email(db, token_payload.id, email)
259260
await redis_client.delete(f'{settings.JWT_USER_REDIS_PREFIX}:{user.id}')
260261
return count

backend/common/context.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from datetime import datetime
2+
from typing import Any, Protocol
3+
4+
from starlette_context.ctx import _Context, context
5+
6+
7+
class TypedContextProtocol(Protocol):
8+
perf_time: float
9+
start_time: datetime
10+
11+
ip: str
12+
country: str | None
13+
region: str | None
14+
city: str | None
15+
16+
user_agent: str
17+
os: str | None
18+
browser: str | None
19+
device: str | None
20+
21+
permission: str | None
22+
23+
24+
class TypedContext(TypedContextProtocol, _Context):
25+
def __getattr__(self, name: str) -> Any:
26+
return context.get(name)
27+
28+
def __setattr__(self, name: str, value: Any) -> None:
29+
context[name] = value
30+
31+
32+
ctx = TypedContext()

backend/common/exception/exception_handler.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from starlette.exceptions import HTTPException
55
from uvicorn.protocols.http.h11_impl import STATUS_PHRASES
66

7+
from backend.common.context import ctx
78
from backend.common.exception.errors import BaseExceptionError
89
from backend.common.i18n import i18n, t
910
from backend.common.response.response_code import CustomResponseCode, StandardResponseCode
@@ -72,8 +73,8 @@ async def _validation_exception_handler(request: Request, exc: RequestValidation
7273
'msg': msg,
7374
'data': data,
7475
}
75-
request.state.__request_validation_exception__ = content # 用于在中间件中获取异常信息
76-
content.update(trace_id=get_request_trace_id(request))
76+
ctx.__request_validation_exception__ = content # 用于在中间件中获取异常信息
77+
content.update(trace_id=get_request_trace_id())
7778
return MsgSpecJSONResponse(status_code=StandardResponseCode.HTTP_422, content=content)
7879

7980

@@ -96,8 +97,8 @@ async def http_exception_handler(request: Request, exc: HTTPException):
9697
else:
9798
res = response_base.fail(res=CustomResponseCode.HTTP_400)
9899
content = res.model_dump()
99-
request.state.__request_http_exception__ = content
100-
content.update(trace_id=get_request_trace_id(request))
100+
ctx.__request_http_exception__ = content
101+
content.update(trace_id=get_request_trace_id())
101102
return MsgSpecJSONResponse(
102103
status_code=_get_exception_code(exc.status_code),
103104
content=content,
@@ -144,8 +145,8 @@ async def assertion_error_handler(request: Request, exc: AssertionError):
144145
else:
145146
res = response_base.fail(res=CustomResponseCode.HTTP_500)
146147
content = res.model_dump()
147-
request.state.__request_assertion_error__ = content
148-
content.update(trace_id=get_request_trace_id(request))
148+
ctx.__request_assertion_error__ = content
149+
content.update(trace_id=get_request_trace_id())
149150
return MsgSpecJSONResponse(
150151
status_code=StandardResponseCode.HTTP_500,
151152
content=content,
@@ -165,8 +166,8 @@ async def custom_exception_handler(request: Request, exc: BaseExceptionError):
165166
'msg': str(exc.msg),
166167
'data': exc.data or None,
167168
}
168-
request.state.__request_custom_exception__ = content
169-
content.update(trace_id=get_request_trace_id(request))
169+
ctx.__request_custom_exception__ = content
170+
content.update(trace_id=get_request_trace_id())
170171
return MsgSpecJSONResponse(
171172
status_code=_get_exception_code(exc.code),
172173
content=content,
@@ -191,7 +192,7 @@ async def all_unknown_exception_handler(request: Request, exc: Exception):
191192
else:
192193
res = response_base.fail(res=CustomResponseCode.HTTP_500)
193194
content = res.model_dump()
194-
content.update(trace_id=get_request_trace_id(request))
195+
content.update(trace_id=get_request_trace_id())
195196
return MsgSpecJSONResponse(
196197
status_code=StandardResponseCode.HTTP_500,
197198
content=content,

backend/common/log.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44
import re
55
import sys
66

7-
from asgi_correlation_id import correlation_id
87
from loguru import logger
98

9+
from backend.common.context import ctx
1010
from backend.core.conf import settings
1111
from backend.core.path_conf import LOG_DIR
1212
from backend.utils.timezone import timezone
13+
from backend.utils.trace_id import get_request_trace_id
1314

1415

1516
class InterceptHandler(logging.Handler):
@@ -75,11 +76,14 @@ def setup_logging() -> None:
7576
# 移除 loguru 默认处理器
7677
logger.remove()
7778

78-
# correlation_id 过滤器
79-
# https://github.com/snok/asgi-correlation-id/issues/7
80-
def correlation_id_filter(record: logging.LogRecord) -> logging.LogRecord:
81-
cid = correlation_id.get(settings.TRACE_ID_LOG_DEFAULT_VALUE)
82-
record['correlation_id'] = cid[: settings.TRACE_ID_LOG_LENGTH]
79+
# request_id 过滤器
80+
def request_id_filter(record: logging.LogRecord) -> logging.LogRecord:
81+
if ctx.exists():
82+
rid = get_request_trace_id()
83+
record['request_id'] = rid[: settings.TRACE_ID_LOG_LENGTH]
84+
else:
85+
record['request_id'] = settings.TRACE_ID_LOG_DEFAULT_VALUE
86+
8387
return record
8488

8589
# 配置 loguru 处理器
@@ -89,7 +93,7 @@ def correlation_id_filter(record: logging.LogRecord) -> logging.LogRecord:
8993
'sink': sys.stdout,
9094
'level': settings.LOG_STD_LEVEL,
9195
'format': default_formatter,
92-
'filter': lambda record: correlation_id_filter(record),
96+
'filter': lambda record: request_id_filter(record),
9397
},
9498
],
9599
)

backend/common/security/permission.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from sqlalchemy.ext.asyncio import AsyncSession
44

55
from backend.app.admin.crud.crud_data_scope import data_scope_dao
6+
from backend.common.context import ctx
67
from backend.common.enums import RoleDataRuleExpressionType, RoleDataRuleOperatorType
78
from backend.common.exception import errors
89
from backend.core.conf import settings
@@ -38,7 +39,7 @@ async def __call__(self, request: Request) -> None:
3839
if not isinstance(self.value, str):
3940
raise errors.ServerError
4041
# 附加权限标识到请求状态
41-
request.state.permission = self.value
42+
ctx.permission = self.value
4243

4344

4445
async def filter_data_permission(db: AsyncSession, request: Request) -> ColumnElement[bool]: # noqa: C901

backend/common/security/rbac.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi import Depends, Request
22

3+
from backend.common.context import ctx
34
from backend.common.enums import MethodType, StatusType
45
from backend.common.exception import errors
56
from backend.common.log import log
@@ -49,7 +50,7 @@ async def rbac_verify(request: Request, _token: str = DependsJwtAuth) -> None:
4950

5051
# RBAC 鉴权
5152
if settings.RBAC_ROLE_MENU_MODE:
52-
path_auth_perm = getattr(request.state, 'permission', None)
53+
path_auth_perm = ctx.permission
5354

5455
# 没有菜单操作权限标识不校验
5556
if not path_auth_perm:

backend/core/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ class Settings(BaseSettings):
150150

151151
# 日志
152152
LOG_FORMAT: str = (
153-
'<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</> | <lvl>{level: <8}</> | <cyan>{correlation_id}</> | <lvl>{message}</>'
153+
'<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</> | <lvl>{level: <8}</> | <cyan>{request_id}</> | <lvl>{message}</>'
154154
)
155155

156156
# 日志(控制台)

backend/core/registrar.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@
66

77
import socketio
88

9-
from asgi_correlation_id import CorrelationIdMiddleware
109
from fastapi import Depends, FastAPI
1110
from fastapi_limiter import FastAPILimiter
1211
from fastapi_pagination import add_pagination
1312
from starlette.middleware.authentication import AuthenticationMiddleware
1413
from starlette.middleware.cors import CORSMiddleware
1514
from starlette.staticfiles import StaticFiles
1615
from starlette.types import ASGIApp
16+
from starlette_context.middleware import ContextMiddleware
17+
from starlette_context.plugins import RequestIdPlugin
1718

1819
from backend import __version__
1920
from backend.common.exception.exception_handler import register_exception
2021
from backend.common.log import set_custom_logfile, setup_logging
22+
from backend.common.response.response_code import StandardResponseCode
2123
from backend.core.conf import settings
2224
from backend.core.path_conf import STATIC_DIR, UPLOAD_DIR
2325
from backend.database.db import create_tables
@@ -154,8 +156,15 @@ def register_middleware(app: FastAPI) -> None:
154156
# Access log
155157
app.add_middleware(AccessMiddleware)
156158

157-
# Trace ID
158-
app.add_middleware(CorrelationIdMiddleware, validator=False)
159+
# ContextVar
160+
app.add_middleware(
161+
ContextMiddleware,
162+
plugins=[RequestIdPlugin(validate=True)],
163+
default_error_response=MsgSpecJSONResponse(
164+
content={'code': StandardResponseCode.HTTP_400, 'msg': 'BAD_REQUEST', 'data': None},
165+
status_code=StandardResponseCode.HTTP_400,
166+
),
167+
)
159168

160169

161170
def register_router(app: FastAPI) -> None:

backend/middleware/access_middleware.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from fastapi import Request, Response
44
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
55

6+
from backend.common.context import ctx
67
from backend.common.log import log
78
from backend.utils.timezone import timezone
89

@@ -24,21 +25,19 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
2425
log.debug(f'--> 请求开始[{path}]')
2526

2627
perf_time = time.perf_counter()
27-
request.state.perf_time = perf_time
28+
ctx.perf_time = perf_time
2829

2930
start_time = timezone.now()
30-
request.state.start_time = start_time
31+
ctx.start_time = start_time
3132

3233
response = await call_next(request)
3334

34-
elapsed = (time.perf_counter() - perf_time) * 1000
35-
3635
if request.method != 'OPTIONS':
3736
log.debug('<-- 请求结束')
3837

3938
log.info(
4039
f'{request.client.host: <15} | {request.method: <8} | {response.status_code: <6} | '
41-
f'{path} | {elapsed:.3f}ms',
40+
f'{path} | {(time.perf_counter() - perf_time) * 1000:.3f}ms',
4241
)
4342

4443
return response

0 commit comments

Comments
 (0)