Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions litestar/plugins/pydantic/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,7 @@ def generate_field_definitions(
stacklevel=2,
)

if not is_pydantic_undefined(field_info.default):
default = field_info.default
elif field_definition.is_optional:
default = None
else:
default = Empty
default = field_info.default if not is_pydantic_undefined(field_info.default) else Empty

default_factory = (
field_info.default_factory
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_plugins/test_pydantic/test_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Model(base_model): # type: ignore[misc, valid-type]
dto_type = PydanticDTO[Model]
field_defs = list(dto_type.generate_field_definitions(Model))
assert len(field_defs) == 1
assert field_defs[0].default is None
assert field_defs[0].default is Empty


def test_detect_nested_field_pydantic_v1(monkeypatch: pytest.MonkeyPatch) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import replace
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from unittest.mock import ANY

import pydantic as pydantic_v2
Expand All @@ -11,6 +11,7 @@

from litestar.dto import DTOField, DTOFieldDefinition, Mark, dto_field
from litestar.plugins.pydantic import PydanticDTO
from litestar.types.builtin_types import NoneType
from litestar.typing import FieldDefinition

from . import PydanticVersion
Expand Down Expand Up @@ -91,13 +92,41 @@ def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefiniti
default_factory=int_factory,
dto_field=DTOField(),
),
default=None,
metadata=ANY,
type_wrappers=ANY,
raw=ANY,
kwarg_definition=ANY,
passthrough_constraints=False,
),
DTOFieldDefinition.from_field_definition(
field_definition=FieldDefinition.from_kwarg(
annotation=Optional[int],
name="f",
),
model_name=ANY,
default_factory=None,
dto_field=DTOField(),
passthrough_constraints=False,
),
replace(
DTOFieldDefinition.from_field_definition(
field_definition=FieldDefinition.from_kwarg(
annotation=List[str],
name="g",
),
model_name=ANY,
default_factory=list,
dto_field=DTOField(),
passthrough_constraints=False,
),
annotation=Optional[List[str]],
origin=Union,
instantiable_origin=Union,
safe_generic_origin=Union,
inner_types=ANY,
args=(List[str], NoneType),
raw=Optional[List[str]],
),
]


Expand All @@ -111,6 +140,8 @@ class TestModel(pydantic_v1.BaseModel):
c: Annotated[int, pydantic_v1.Field(gt=1)]
d: int = pydantic_v1.Field(default=1)
e: int = pydantic_v1.Field(default_factory=int_factory)
f: Optional[int] # noqa: UP007
g: List[str] = pydantic_v1.Field(default_factory=list) # noqa: UP006

field_defs = list(PydanticDTO.generate_field_definitions(TestModel))
assert field_defs[0].model_name == "TestModel"
Expand All @@ -128,6 +159,8 @@ class TestModel(pydantic_v2.BaseModel):
c: Annotated[int, pydantic_v2.Field(gt=1)]
d: int = pydantic_v2.Field(default=1)
e: int = pydantic_v2.Field(default_factory=int_factory)
f: Optional[int] # noqa: UP007
g: List[str] = pydantic_v2.Field(default_factory=list) # noqa: UP006

field_defs = list(PydanticDTO.generate_field_definitions(TestModel))
assert field_defs[0].model_name == "TestModel"
Expand Down
Loading