@@ -208,10 +208,159 @@ async def test_discover_oauth_metadata_request(self, oauth_provider):
208
208
"""Test OAuth metadata discovery request building."""
209
209
request = await oauth_provider ._discover_oauth_metadata ()
210
210
211
+ assert request .method == "GET"
212
+ assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
213
+ assert "mcp-protocol-version" in request .headers
214
+
215
+ @pytest .mark .anyio
216
+ async def test_discover_oauth_metadata_request_no_path (self , client_metadata , mock_storage ):
217
+ """Test OAuth metadata discovery request building when server has no path."""
218
+
219
+ async def redirect_handler (url : str ) -> None :
220
+ pass
221
+
222
+ async def callback_handler () -> tuple [str , str | None ]:
223
+ return "test_auth_code" , "test_state"
224
+
225
+ provider = OAuthClientProvider (
226
+ server_url = "https://api.example.com" ,
227
+ client_metadata = client_metadata ,
228
+ storage = mock_storage ,
229
+ redirect_handler = redirect_handler ,
230
+ callback_handler = callback_handler ,
231
+ )
232
+
233
+ request = await provider ._discover_oauth_metadata ()
234
+
235
+ assert request .method == "GET"
236
+ assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server"
237
+ assert "mcp-protocol-version" in request .headers
238
+
239
+ @pytest .mark .anyio
240
+ async def test_discover_oauth_metadata_request_trailing_slash (self , client_metadata , mock_storage ):
241
+ """Test OAuth metadata discovery request building when server path has trailing slash."""
242
+
243
+ async def redirect_handler (url : str ) -> None :
244
+ pass
245
+
246
+ async def callback_handler () -> tuple [str , str | None ]:
247
+ return "test_auth_code" , "test_state"
248
+
249
+ provider = OAuthClientProvider (
250
+ server_url = "https://api.example.com/v1/mcp/" ,
251
+ client_metadata = client_metadata ,
252
+ storage = mock_storage ,
253
+ redirect_handler = redirect_handler ,
254
+ callback_handler = callback_handler ,
255
+ )
256
+
257
+ request = await provider ._discover_oauth_metadata ()
258
+
259
+ assert request .method == "GET"
260
+ assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
261
+ assert "mcp-protocol-version" in request .headers
262
+
263
+
264
+ class TestOAuthFallback :
265
+ """Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""
266
+
267
+ @pytest .mark .anyio
268
+ async def test_fallback_discovery_request (self , client_metadata , mock_storage ):
269
+ """Test fallback discovery request building."""
270
+
271
+ async def redirect_handler (url : str ) -> None :
272
+ pass
273
+
274
+ async def callback_handler () -> tuple [str , str | None ]:
275
+ return "test_auth_code" , "test_state"
276
+
277
+ provider = OAuthClientProvider (
278
+ server_url = "https://api.example.com/v1/mcp" ,
279
+ client_metadata = client_metadata ,
280
+ storage = mock_storage ,
281
+ redirect_handler = redirect_handler ,
282
+ callback_handler = callback_handler ,
283
+ )
284
+
285
+ # Set up discovery state manually as if path-aware discovery was attempted
286
+ provider .context .discovery_base_url = "https://api.example.com"
287
+ provider .context .discovery_pathname = "/v1/mcp"
288
+
289
+ # Test fallback request building
290
+ request = await provider ._discover_oauth_metadata_fallback ()
291
+
211
292
assert request .method == "GET"
212
293
assert str (request .url ) == "https://api.example.com/.well-known/oauth-authorization-server"
213
294
assert "mcp-protocol-version" in request .headers
214
295
296
+ @pytest .mark .anyio
297
+ async def test_should_attempt_fallback (self , oauth_provider ):
298
+ """Test fallback decision logic."""
299
+ # Should attempt fallback on 404 with non-root path
300
+ assert oauth_provider ._should_attempt_fallback (404 , "/v1/mcp" )
301
+
302
+ # Should NOT attempt fallback on 404 with root path
303
+ assert not oauth_provider ._should_attempt_fallback (404 , "/" )
304
+
305
+ # Should NOT attempt fallback on other status codes
306
+ assert not oauth_provider ._should_attempt_fallback (200 , "/v1/mcp" )
307
+ assert not oauth_provider ._should_attempt_fallback (500 , "/v1/mcp" )
308
+
309
+ @pytest .mark .anyio
310
+ async def test_handle_metadata_response_success (self , oauth_provider ):
311
+ """Test successful metadata response handling."""
312
+ # Create minimal valid OAuth metadata
313
+ content = b"""{
314
+ "issuer": "https://auth.example.com",
315
+ "authorization_endpoint": "https://auth.example.com/authorize",
316
+ "token_endpoint": "https://auth.example.com/token"
317
+ }"""
318
+ response = httpx .Response (200 , content = content )
319
+
320
+ # Should return True (success) and set metadata
321
+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
322
+ assert result is True
323
+ assert oauth_provider .context .oauth_metadata is not None
324
+ assert str (oauth_provider .context .oauth_metadata .issuer ) == "https://auth.example.com/"
325
+
326
+ @pytest .mark .anyio
327
+ async def test_handle_metadata_response_404_needs_fallback (self , oauth_provider ):
328
+ """Test 404 response handling that should trigger fallback."""
329
+ # Set up discovery state for non-root path
330
+ oauth_provider .context .discovery_base_url = "https://api.example.com"
331
+ oauth_provider .context .discovery_pathname = "/v1/mcp"
332
+
333
+ # Mock 404 response
334
+ response = httpx .Response (404 )
335
+
336
+ # Should return False (needs fallback)
337
+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
338
+ assert result is False
339
+
340
+ @pytest .mark .anyio
341
+ async def test_handle_metadata_response_404_no_fallback_needed (self , oauth_provider ):
342
+ """Test 404 response handling when no fallback is needed."""
343
+ # Set up discovery state for root path
344
+ oauth_provider .context .discovery_base_url = "https://api.example.com"
345
+ oauth_provider .context .discovery_pathname = "/"
346
+
347
+ # Mock 404 response
348
+ response = httpx .Response (404 )
349
+
350
+ # Should return True (no fallback needed)
351
+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = False )
352
+ assert result is True
353
+
354
+ @pytest .mark .anyio
355
+ async def test_handle_metadata_response_404_fallback_attempt (self , oauth_provider ):
356
+ """Test 404 response handling during fallback attempt."""
357
+ # Mock 404 response during fallback
358
+ response = httpx .Response (404 )
359
+
360
+ # Should return True (fallback attempt complete, no further action needed)
361
+ result = await oauth_provider ._handle_oauth_metadata_response (response , is_fallback = True )
362
+ assert result is True
363
+
215
364
@pytest .mark .anyio
216
365
async def test_register_client_request (self , oauth_provider ):
217
366
"""Test client registration request building."""
0 commit comments