Skip to content

Commit 6747688

Browse files
authored
Fix /.well-known/oauth-authorization-server dropping path (#1014)
1 parent ce007de commit 6747688

File tree

2 files changed

+219
-10
lines changed

2 files changed

+219
-10
lines changed

src/mcp/client/auth.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ class OAuthContext:
106106
# State
107107
lock: anyio.Lock = field(default_factory=anyio.Lock)
108108

109+
# Discovery state for fallback support
110+
discovery_base_url: str | None = None
111+
discovery_pathname: str | None = None
112+
109113
def get_authorization_base_url(self, server_url: str) -> str:
110114
"""Extract base URL by removing path component."""
111115
parsed = urlparse(server_url)
@@ -197,18 +201,53 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
197201
except ValidationError:
198202
pass
199203

204+
def _build_well_known_path(self, pathname: str) -> str:
205+
"""Construct well-known path for OAuth metadata discovery."""
206+
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
207+
if pathname.endswith("/"):
208+
# Strip trailing slash from pathname to avoid double slashes
209+
well_known_path = well_known_path[:-1]
210+
return well_known_path
211+
212+
def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
213+
"""Determine if fallback to root discovery should be attempted."""
214+
return response_status == 404 and pathname != "/"
215+
216+
async def _try_metadata_discovery(self, url: str) -> httpx.Request:
217+
"""Build metadata discovery request for a specific URL."""
218+
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
219+
200220
async def _discover_oauth_metadata(self) -> httpx.Request:
201-
"""Build OAuth metadata discovery request."""
221+
"""Build OAuth metadata discovery request with fallback support."""
202222
if self.context.auth_server_url:
203-
base_url = self.context.get_authorization_base_url(self.context.auth_server_url)
223+
auth_server_url = self.context.auth_server_url
204224
else:
205-
base_url = self.context.get_authorization_base_url(self.context.server_url)
225+
auth_server_url = self.context.server_url
226+
227+
# Per RFC 8414, try path-aware discovery first
228+
parsed = urlparse(auth_server_url)
229+
well_known_path = self._build_well_known_path(parsed.path)
230+
base_url = f"{parsed.scheme}://{parsed.netloc}"
231+
url = urljoin(base_url, well_known_path)
232+
233+
# Store fallback info for use in response handler
234+
self.context.discovery_base_url = base_url
235+
self.context.discovery_pathname = parsed.path
206236

237+
return await self._try_metadata_discovery(url)
238+
239+
async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
240+
"""Build fallback OAuth metadata discovery request for legacy servers."""
241+
base_url = getattr(self.context, "discovery_base_url", "")
242+
if not base_url:
243+
raise OAuthFlowError("No base URL available for fallback discovery")
244+
245+
# Fallback to root discovery for legacy servers
207246
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
208-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
247+
return await self._try_metadata_discovery(url)
209248

210-
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
211-
"""Handle OAuth metadata response."""
249+
async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
250+
"""Handle OAuth metadata response. Returns True if handled successfully."""
212251
if response.status_code == 200:
213252
try:
214253
content = await response.aread()
@@ -217,9 +256,18 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
217256
# Apply default scope if none specified
218257
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
219258
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
259+
return True
220260
except ValidationError:
221261
pass
222262

263+
# Check if we should attempt fallback (404 on path-aware discovery)
264+
if not is_fallback and self._should_attempt_fallback(
265+
response.status_code, getattr(self.context, "discovery_pathname", "/")
266+
):
267+
return False # Signal that fallback should be attempted
268+
269+
return True # Signal no fallback needed (either success or non-404 error)
270+
223271
async def _register_client(self) -> httpx.Request | None:
224272
"""Build registration request or skip if already registered."""
225273
if self.context.client_info:
@@ -418,10 +466,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
418466
discovery_response = yield discovery_request
419467
await self._handle_protected_resource_response(discovery_response)
420468

421-
# Step 2: Discover OAuth metadata
469+
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
422470
oauth_request = await self._discover_oauth_metadata()
423471
oauth_response = yield oauth_request
424-
await self._handle_oauth_metadata_response(oauth_response)
472+
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
473+
474+
# If path-aware discovery failed with 404, try fallback to root
475+
if not handled:
476+
fallback_request = await self._discover_oauth_metadata_fallback()
477+
fallback_response = yield fallback_request
478+
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
425479

426480
# Step 3: Register client if needed
427481
registration_request = await self._register_client()
@@ -464,10 +518,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
464518
discovery_response = yield discovery_request
465519
await self._handle_protected_resource_response(discovery_response)
466520

467-
# Step 2: Discover OAuth metadata
521+
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
468522
oauth_request = await self._discover_oauth_metadata()
469523
oauth_response = yield oauth_request
470-
await self._handle_oauth_metadata_response(oauth_response)
524+
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
525+
526+
# If path-aware discovery failed with 404, try fallback to root
527+
if not handled:
528+
fallback_request = await self._discover_oauth_metadata_fallback()
529+
fallback_response = yield fallback_request
530+
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
471531

472532
# Step 3: Register client if needed
473533
registration_request = await self._register_client()

tests/client/test_auth.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,159 @@ async def test_discover_oauth_metadata_request(self, oauth_provider):
208208
"""Test OAuth metadata discovery request building."""
209209
request = await oauth_provider._discover_oauth_metadata()
210210

211+
assert request.method == "GET"
212+
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
213+
assert "mcp-protocol-version" in request.headers
214+
215+
@pytest.mark.anyio
216+
async def test_discover_oauth_metadata_request_no_path(self, client_metadata, mock_storage):
217+
"""Test OAuth metadata discovery request building when server has no path."""
218+
219+
async def redirect_handler(url: str) -> None:
220+
pass
221+
222+
async def callback_handler() -> tuple[str, str | None]:
223+
return "test_auth_code", "test_state"
224+
225+
provider = OAuthClientProvider(
226+
server_url="https://api.example.com",
227+
client_metadata=client_metadata,
228+
storage=mock_storage,
229+
redirect_handler=redirect_handler,
230+
callback_handler=callback_handler,
231+
)
232+
233+
request = await provider._discover_oauth_metadata()
234+
235+
assert request.method == "GET"
236+
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
237+
assert "mcp-protocol-version" in request.headers
238+
239+
@pytest.mark.anyio
240+
async def test_discover_oauth_metadata_request_trailing_slash(self, client_metadata, mock_storage):
241+
"""Test OAuth metadata discovery request building when server path has trailing slash."""
242+
243+
async def redirect_handler(url: str) -> None:
244+
pass
245+
246+
async def callback_handler() -> tuple[str, str | None]:
247+
return "test_auth_code", "test_state"
248+
249+
provider = OAuthClientProvider(
250+
server_url="https://api.example.com/v1/mcp/",
251+
client_metadata=client_metadata,
252+
storage=mock_storage,
253+
redirect_handler=redirect_handler,
254+
callback_handler=callback_handler,
255+
)
256+
257+
request = await provider._discover_oauth_metadata()
258+
259+
assert request.method == "GET"
260+
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
261+
assert "mcp-protocol-version" in request.headers
262+
263+
264+
class TestOAuthFallback:
265+
"""Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""
266+
267+
@pytest.mark.anyio
268+
async def test_fallback_discovery_request(self, client_metadata, mock_storage):
269+
"""Test fallback discovery request building."""
270+
271+
async def redirect_handler(url: str) -> None:
272+
pass
273+
274+
async def callback_handler() -> tuple[str, str | None]:
275+
return "test_auth_code", "test_state"
276+
277+
provider = OAuthClientProvider(
278+
server_url="https://api.example.com/v1/mcp",
279+
client_metadata=client_metadata,
280+
storage=mock_storage,
281+
redirect_handler=redirect_handler,
282+
callback_handler=callback_handler,
283+
)
284+
285+
# Set up discovery state manually as if path-aware discovery was attempted
286+
provider.context.discovery_base_url = "https://api.example.com"
287+
provider.context.discovery_pathname = "/v1/mcp"
288+
289+
# Test fallback request building
290+
request = await provider._discover_oauth_metadata_fallback()
291+
211292
assert request.method == "GET"
212293
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
213294
assert "mcp-protocol-version" in request.headers
214295

296+
@pytest.mark.anyio
297+
async def test_should_attempt_fallback(self, oauth_provider):
298+
"""Test fallback decision logic."""
299+
# Should attempt fallback on 404 with non-root path
300+
assert oauth_provider._should_attempt_fallback(404, "/v1/mcp")
301+
302+
# Should NOT attempt fallback on 404 with root path
303+
assert not oauth_provider._should_attempt_fallback(404, "/")
304+
305+
# Should NOT attempt fallback on other status codes
306+
assert not oauth_provider._should_attempt_fallback(200, "/v1/mcp")
307+
assert not oauth_provider._should_attempt_fallback(500, "/v1/mcp")
308+
309+
@pytest.mark.anyio
310+
async def test_handle_metadata_response_success(self, oauth_provider):
311+
"""Test successful metadata response handling."""
312+
# Create minimal valid OAuth metadata
313+
content = b"""{
314+
"issuer": "https://auth.example.com",
315+
"authorization_endpoint": "https://auth.example.com/authorize",
316+
"token_endpoint": "https://auth.example.com/token"
317+
}"""
318+
response = httpx.Response(200, content=content)
319+
320+
# Should return True (success) and set metadata
321+
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
322+
assert result is True
323+
assert oauth_provider.context.oauth_metadata is not None
324+
assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/"
325+
326+
@pytest.mark.anyio
327+
async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider):
328+
"""Test 404 response handling that should trigger fallback."""
329+
# Set up discovery state for non-root path
330+
oauth_provider.context.discovery_base_url = "https://api.example.com"
331+
oauth_provider.context.discovery_pathname = "/v1/mcp"
332+
333+
# Mock 404 response
334+
response = httpx.Response(404)
335+
336+
# Should return False (needs fallback)
337+
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
338+
assert result is False
339+
340+
@pytest.mark.anyio
341+
async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provider):
342+
"""Test 404 response handling when no fallback is needed."""
343+
# Set up discovery state for root path
344+
oauth_provider.context.discovery_base_url = "https://api.example.com"
345+
oauth_provider.context.discovery_pathname = "/"
346+
347+
# Mock 404 response
348+
response = httpx.Response(404)
349+
350+
# Should return True (no fallback needed)
351+
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
352+
assert result is True
353+
354+
@pytest.mark.anyio
355+
async def test_handle_metadata_response_404_fallback_attempt(self, oauth_provider):
356+
"""Test 404 response handling during fallback attempt."""
357+
# Mock 404 response during fallback
358+
response = httpx.Response(404)
359+
360+
# Should return True (fallback attempt complete, no further action needed)
361+
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=True)
362+
assert result is True
363+
215364
@pytest.mark.anyio
216365
async def test_register_client_request(self, oauth_provider):
217366
"""Test client registration request building."""

0 commit comments

Comments
 (0)