Skip to content

Commit c2c3c54

Browse files
committed
Add back authorization to the /revoke endpoint, simplify revoke
1 parent 66fb120 commit c2c3c54

File tree

3 files changed

+51
-33
lines changed

3 files changed

+51
-33
lines changed

src/mcp/server/auth/handlers/revoke.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
from dataclasses import dataclass
2+
from functools import partial
23
from typing import Literal
34

45
from pydantic import BaseModel, ValidationError
56
from starlette.requests import Request
67
from starlette.responses import Response
78

89
from mcp.server.auth.errors import (
10+
InvalidClientError,
911
stringify_pydantic_error,
1012
)
1113
from mcp.server.auth.json_response import PydanticJSONResponse
1214
from mcp.server.auth.middleware.client_auth import (
1315
ClientAuthenticator,
1416
)
15-
from mcp.server.auth.provider import OAuthServerProvider
17+
from mcp.server.auth.provider import AuthInfo, OAuthServerProvider, RefreshToken
1618

1719

1820
class RevocationRequest(BaseModel):
@@ -22,6 +24,8 @@ class RevocationRequest(BaseModel):
2224

2325
token: str
2426
token_type_hint: Literal["access_token", "refresh_token"] | None = None
27+
client_id: str
28+
client_secret: str | None
2529

2630

2731
class RevocationErrorResponse(BaseModel):
@@ -50,10 +54,30 @@ async def handle(self, request: Request) -> Response:
5054
),
5155
)
5256

53-
# Revoke token
54-
await self.provider.revoke_token(
55-
revocation_request.token, revocation_request.token_type_hint
56-
)
57+
# Authenticate client
58+
try:
59+
client = await self.client_authenticator.authenticate(
60+
revocation_request.client_id, revocation_request.client_secret
61+
)
62+
except InvalidClientError as e:
63+
return PydanticJSONResponse(status_code=401, content=e.error_response())
64+
65+
loaders = [
66+
self.provider.load_access_token,
67+
partial(self.provider.load_refresh_token, client),
68+
]
69+
if revocation_request.token_type_hint == "refresh_token":
70+
loaders = reversed(loaders)
71+
72+
token: None | AuthInfo | RefreshToken = None
73+
for loader in loaders:
74+
token = await loader(revocation_request.token)
75+
if token is not None:
76+
break
77+
78+
if token and token.client_id == client.client_id:
79+
# Revoke token
80+
await self.provider.revoke_token(token)
5781

5882
# Return successful empty response
5983
return Response(

src/mcp/server/auth/provider.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal, Protocol
1+
from typing import Protocol
22
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
33

44
from pydantic import AnyHttpUrl, BaseModel
@@ -172,8 +172,7 @@ async def load_access_token(self, token: str) -> AuthInfo | None:
172172

173173
async def revoke_token(
174174
self,
175-
token: str,
176-
token_type_hint: Literal["access_token", "refresh_token"] | None = None,
175+
token: AuthInfo | RefreshToken,
177176
) -> None:
178177
"""
179178
Revokes an access or refresh token.

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import secrets
99
import time
1010
import unittest.mock
11-
from typing import Literal
1211
from urllib.parse import parse_qs, urlparse
1312

1413
import anyio
@@ -164,11 +163,12 @@ async def exchange_refresh_token(
164163
new_refresh_token = f"refresh_{secrets.token_hex(32)}"
165164

166165
# Store the new tokens
167-
self.tokens[new_access_token] = {
168-
"client_id": client.client_id,
169-
"scopes": scopes or token_info.scopes,
170-
"expires_at": int(time.time()) + 3600,
171-
}
166+
self.tokens[new_access_token] = AuthInfo(
167+
token=new_access_token,
168+
client_id=client.client_id,
169+
scopes=scopes or token_info.scopes,
170+
expires_at=int(time.time()) + 3600,
171+
)
172172

173173
self.refresh_tokens[new_refresh_token] = new_access_token
174174

@@ -198,25 +198,20 @@ async def load_access_token(self, token: str) -> AuthInfo | None:
198198
expires_at=token_info.expires_at,
199199
)
200200

201-
async def revoke_token(
202-
self,
203-
token: str,
204-
token_type_hint: Literal["access_token", "refresh_token"] | None = None,
205-
) -> None:
206-
# Check if it's a refresh token
207-
if token in self.refresh_tokens:
208-
# Remove the refresh token
209-
del self.refresh_tokens[token]
210-
211-
# Check if it's an access token
212-
elif token in self.tokens:
213-
# Remove the access token
214-
del self.tokens[token]
215-
216-
# Also remove any refresh tokens that point to this access token
217-
for refresh_token, access_token in list(self.refresh_tokens.items()):
218-
if access_token == token:
219-
del self.refresh_tokens[refresh_token]
201+
async def revoke_token(self, token: OAuthToken | RefreshToken) -> None:
202+
match token:
203+
case RefreshToken():
204+
# Remove the refresh token
205+
del self.refresh_tokens[token.token]
206+
207+
case AuthInfo():
208+
# Remove the access token
209+
del self.tokens[token.token]
210+
211+
# Also remove any refresh tokens that point to this access token
212+
for refresh_token, access_token in list(self.refresh_tokens.items()):
213+
if access_token == token.token:
214+
del self.refresh_tokens[refresh_token]
220215

221216

222217
@pytest.fixture

0 commit comments

Comments
 (0)