Skip to content

Commit 9d9bd05

Browse files
authored
Merge pull request #1690 from vitalik/uniont-type
Optional Union fix
2 parents 7d549d5 + a6b0d58 commit 9d9bd05

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

ninja/signature/details.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
208208
return flatten_map
209209

210210
def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
211+
model = _unwrap_union_model(model)
211212
field: FieldInfo
212213
for attr, field in model.model_fields.items():
213214
field_name = field.alias or attr
@@ -306,6 +307,15 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
306307
)
307308

308309

310+
def _unwrap_union_model(annotation: Any) -> Any:
311+
"""If annotation is a Union containing a pydantic model, return that model class."""
312+
if get_origin(annotation) in UNION_TYPES:
313+
for arg in get_args(annotation):
314+
if arg is not type(None) and is_pydantic_model(arg):
315+
return arg
316+
return annotation
317+
318+
309319
def is_pydantic_model(cls: Any) -> bool:
310320
try:
311321
origin = get_origin(cls)
@@ -360,6 +370,7 @@ def detect_collection_fields(
360370
for attr in path[1:]:
361371
if hasattr(annotation_or_field, "annotation"):
362372
annotation_or_field = annotation_or_field.annotation
373+
annotation_or_field = _unwrap_union_model(annotation_or_field)
363374
annotation_or_field = next(
364375
(
365376
a

tests/test_query_schema.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
import sys
12
from datetime import datetime
23
from enum import IntEnum
4+
from typing import Optional
35

6+
import pytest
47
from pydantic import BaseModel, Field
58

69
from ninja import NinjaAPI, Query, Schema
710
from ninja.testing.client import TestClient
811

12+
PY_310 = sys.version_info >= (3, 10)
13+
914

1015
class Range(IntEnum):
1116
TWENTY = 20
@@ -163,3 +168,87 @@ def test_schema_all_of_no_ref():
163168
"default": 1,
164169
"allOf": [{"title": "Best Type Ever!"}, {"no-ref-here": "xyzzy"}],
165170
}
171+
172+
173+
@pytest.mark.skipif(not PY_310, reason="requires Python 3.10+ pipe syntax")
174+
def test_unwrap_union_model_no_model():
175+
"""_unwrap_union_model returns input as-is when union has no pydantic model."""
176+
from ninja.signature.details import _unwrap_union_model
177+
178+
annotation = int | None
179+
assert _unwrap_union_model(annotation) is annotation
180+
assert _unwrap_union_model(str) is str
181+
182+
183+
def test_optional_query_schema():
184+
"""Optional[Model] in Query should not crash (issue #1634)."""
185+
186+
class MyFilter(Schema):
187+
name: str = "default"
188+
189+
temp_api = NinjaAPI()
190+
191+
@temp_api.get("/opt")
192+
def view(request, f: Optional[MyFilter] = Query(None)):
193+
if f:
194+
return f.model_dump()
195+
return {}
196+
197+
client = TestClient(temp_api)
198+
199+
resp = client.get("/opt?name=hello")
200+
assert resp.status_code == 200
201+
assert resp.json() == {"name": "hello"}
202+
203+
resp = client.get("/opt")
204+
assert resp.status_code == 200
205+
206+
207+
@pytest.mark.skipif(not PY_310, reason="requires Python 3.10+ pipe syntax")
208+
def test_union_pipe_syntax_query_schema():
209+
"""Model | None in Query should not crash (issue #1634, Python 3.10+ pipe syntax)."""
210+
211+
class MyFilter(Schema):
212+
name: str = "default"
213+
214+
temp_api = NinjaAPI()
215+
216+
@temp_api.get("/pipe")
217+
def view(request, f: MyFilter | None = Query(None)):
218+
if f:
219+
return f.model_dump()
220+
return {}
221+
222+
client = TestClient(temp_api)
223+
224+
resp = client.get("/pipe?name=world")
225+
assert resp.status_code == 200
226+
assert resp.json() == {"name": "world"}
227+
228+
resp = client.get("/pipe")
229+
assert resp.status_code == 200
230+
231+
232+
@pytest.mark.skipif(not PY_310, reason="requires Python 3.10+ pipe syntax")
233+
def test_nested_optional_query_schema():
234+
"""Nested optional model fields should also work."""
235+
236+
class Inner(Schema):
237+
value: Optional[int] = 0
238+
items: list[str] = []
239+
240+
class Outer(Schema):
241+
inner: Inner | None = None
242+
label: str = "x"
243+
244+
temp_api = NinjaAPI()
245+
246+
@temp_api.get("/nested")
247+
def view(request, f: Outer = Query(...)):
248+
return f.model_dump()
249+
250+
client = TestClient(temp_api)
251+
252+
resp = client.get("/nested?label=test")
253+
assert resp.status_code == 200
254+
assert resp.json()["label"] == "test"

0 commit comments

Comments
 (0)