Skip to content

Commit 1fbecaa

Browse files
committed
Draft for the async test client (encode#440)
1 parent 60a71ef commit 1fbecaa

File tree

5 files changed

+185
-2
lines changed

5 files changed

+185
-2
lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ databases[sqlite]
1515
isort
1616
mypy
1717
pytest
18+
pytest-asyncio
1819
pytest-cov
20+
requests_async
1921

2022
# Documentation
2123
mkdocs

starlette/testclient.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from urllib.parse import unquote, urljoin, urlsplit
1111

1212
import requests
13+
import requests_async
1314

1415
from starlette.types import Message, Receive, Scope, Send
1516
from starlette.websockets import WebSocketDisconnect
@@ -472,3 +473,38 @@ async def wait_shutdown(self) -> None:
472473
self.task.result()
473474
assert message["type"] == "lifespan.shutdown.complete"
474475
await self.task
476+
477+
478+
class AsyncTestClient(requests_async.ASGISession):
479+
async def lifespan(self) -> None:
480+
scope = {"type": "lifespan"}
481+
try:
482+
await self.app(scope, self.receive_queue.get, self.send_queue.put)
483+
finally:
484+
await self.send_queue.put(None)
485+
486+
async def wait_startup(self) -> None:
487+
await self.receive_queue.put({"type": "lifespan.startup"})
488+
message = await self.send_queue.get()
489+
if message is None:
490+
self.task.result()
491+
assert message["type"] == "lifespan.startup.complete"
492+
493+
async def wait_shutdown(self) -> None:
494+
await self.receive_queue.put({"type": "lifespan.shutdown"})
495+
message = await self.send_queue.get()
496+
if message is None:
497+
self.task.result()
498+
assert message["type"] == "lifespan.shutdown.complete"
499+
await self.task
500+
501+
async def __aenter__(self) -> requests_async.ASGISession:
502+
loop = asyncio.get_event_loop()
503+
self.send_queue = asyncio.Queue() # type: asyncio.Queue
504+
self.receive_queue = asyncio.Queue() # type: asyncio.Queue
505+
self.task = loop.create_task(self.lifespan())
506+
await self.wait_startup()
507+
return self
508+
509+
async def __aexit__(self, *args: typing.Any) -> None:
510+
await self.wait_shutdown()

tests/middleware/test_lifespan.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from starlette.applications import Starlette
44
from starlette.responses import PlainTextResponse
55
from starlette.routing import Lifespan, Route, Router
6-
from starlette.testclient import TestClient
6+
from starlette.testclient import AsyncTestClient, TestClient
77

88

99
def test_routed_lifespan():
@@ -106,3 +106,52 @@ async def run_shutdown():
106106
assert not shutdown_complete
107107
assert startup_complete
108108
assert shutdown_complete
109+
110+
111+
@pytest.mark.asyncio
112+
async def test_raise_on_startup_async_test_client():
113+
def run_startup():
114+
raise RuntimeError()
115+
116+
app = Router(routes=[Lifespan(on_startup=run_startup)])
117+
118+
with pytest.raises(RuntimeError):
119+
async with AsyncTestClient(app):
120+
pass # pragma: nocover
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_raise_on_shutdown_async_test_client():
125+
def run_shutdown():
126+
raise RuntimeError()
127+
128+
app = Router(routes=[Lifespan(on_shutdown=run_shutdown)])
129+
130+
with pytest.raises(RuntimeError):
131+
async with AsyncTestClient(app):
132+
pass # pragma: nocover
133+
134+
135+
@pytest.mark.asyncio
136+
async def test_app_async_test_client_lifespan():
137+
startup_complete = False
138+
shutdown_complete = False
139+
app = Starlette()
140+
141+
@app.on_event("startup")
142+
async def run_startup():
143+
nonlocal startup_complete
144+
startup_complete = True
145+
146+
@app.on_event("shutdown")
147+
async def run_shutdown():
148+
nonlocal shutdown_complete
149+
shutdown_complete = True
150+
151+
assert not startup_complete
152+
assert not shutdown_complete
153+
async with AsyncTestClient(app):
154+
assert startup_complete
155+
assert not shutdown_complete
156+
assert startup_complete
157+
assert shutdown_complete

tests/test_asynctestclient.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import pytest
2+
3+
from starlette.applications import Starlette
4+
from starlette.responses import JSONResponse
5+
from starlette.testclient import AsyncTestClient
6+
7+
mock_service = Starlette()
8+
9+
10+
@mock_service.route("/")
11+
def mock_service_endpoint(request):
12+
return JSONResponse({"mock": "example"})
13+
14+
15+
app = Starlette()
16+
17+
18+
@pytest.mark.asyncio
19+
@app.route("/")
20+
async def homepage(request):
21+
client = AsyncTestClient(mock_service)
22+
response = await client.get("/")
23+
return JSONResponse(response.json())
24+
25+
26+
startup_error_app = Starlette()
27+
28+
29+
@startup_error_app.on_event("startup")
30+
def startup():
31+
raise RuntimeError()
32+
33+
34+
@pytest.mark.asyncio
35+
async def test_use_asynctestclient_in_endpoint():
36+
"""
37+
We should be able to use the test client within applications.
38+
39+
This is useful if we need to mock out other services,
40+
during tests or in development.
41+
"""
42+
client = AsyncTestClient(app)
43+
response = await client.get("/")
44+
assert response.json() == {"mock": "example"}
45+
46+
47+
@pytest.mark.asyncio
48+
async def test_asynctestclient_as_contextmanager():
49+
async with AsyncTestClient(app):
50+
pass
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_error_on_startup():
55+
with pytest.raises(RuntimeError):
56+
async with AsyncTestClient(startup_error_app):
57+
pass # pragma: no cover
58+
59+
60+
# TODO test_asynctestclient_asgi2
61+
# `requests_async` is ASGI 3 only as of now
62+
63+
64+
@pytest.mark.asyncio
65+
async def test_asynctestclient_asgi3():
66+
async def app(scope, receive, send):
67+
await send(
68+
{
69+
"type": "http.response.start",
70+
"status": 200,
71+
"headers": [[b"content-type", b"text/plain"]],
72+
}
73+
)
74+
await send({"type": "http.response.body", "body": b"Hello, world!"})
75+
76+
client = AsyncTestClient(app)
77+
response = await client.get("/")
78+
assert response.text == "Hello, world!"

tests/test_database.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from starlette.applications import Starlette
66
from starlette.datastructures import CommaSeparatedStrings
77
from starlette.responses import JSONResponse
8-
from starlette.testclient import TestClient
8+
from starlette.testclient import AsyncTestClient, TestClient
99

1010
DATABASE_URL = "sqlite:///test.db"
1111

@@ -165,3 +165,21 @@ def test_database_isolated_during_test_cases():
165165
response = client.get("/notes")
166166
assert response.status_code == 200
167167
assert response.json() == [{"text": "just one note", "completed": True}]
168+
169+
170+
@pytest.mark.asyncio
171+
async def test_database_with_async_test_client():
172+
"""
173+
Using AsyncTestClient
174+
"""
175+
async with AsyncTestClient(app) as client:
176+
# Post a row to the DB
177+
response = await client.post(
178+
"/notes", json={"text": "just one note", "completed": True}
179+
)
180+
assert response.status_code == 200
181+
182+
# Call the DB explicitly
183+
query = notes.select()
184+
results = await database.fetch_all(query)
185+
assert results == [(1, "just one note", True)]

0 commit comments

Comments
 (0)