diff --git a/litestar/plugins/pydantic/dto.py b/litestar/plugins/pydantic/dto.py index b75cdc5a2e..02921d6c55 100644 --- a/litestar/plugins/pydantic/dto.py +++ b/litestar/plugins/pydantic/dto.py @@ -137,12 +137,7 @@ def generate_field_definitions( stacklevel=2, ) - if not is_pydantic_undefined(field_info.default): - default = field_info.default - elif field_definition.is_optional: - default = None - else: - default = Empty + default = field_info.default if not is_pydantic_undefined(field_info.default) else Empty default_factory = ( field_info.default_factory diff --git a/tests/unit/test_plugins/test_pydantic/test_dto.py b/tests/unit/test_plugins/test_pydantic/test_dto.py index fe7cbab991..a013f419c0 100644 --- a/tests/unit/test_plugins/test_pydantic/test_dto.py +++ b/tests/unit/test_plugins/test_pydantic/test_dto.py @@ -57,7 +57,7 @@ class Model(base_model): # type: ignore[misc, valid-type] dto_type = PydanticDTO[Model] field_defs = list(dto_type.generate_field_definitions(Model)) assert len(field_defs) == 1 - assert field_defs[0].default is None + assert field_defs[0].default is Empty def test_detect_nested_field_pydantic_v1(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/unit/test_plugins/test_pydantic/test_pydantic_dto_factory.py b/tests/unit/test_plugins/test_pydantic/test_pydantic_dto_factory.py index e609bbaddb..15cfceaad3 100644 --- a/tests/unit/test_plugins/test_pydantic/test_pydantic_dto_factory.py +++ b/tests/unit/test_plugins/test_pydantic/test_pydantic_dto_factory.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import replace -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Union from unittest.mock import ANY import pydantic as pydantic_v2 @@ -11,6 +11,7 @@ from litestar.dto import DTOField, DTOFieldDefinition, Mark, dto_field from litestar.plugins.pydantic import PydanticDTO +from litestar.types.builtin_types import NoneType from litestar.typing import FieldDefinition from . import PydanticVersion @@ -91,13 +92,41 @@ def expected_field_defs(int_factory: Callable[[], int]) -> list[DTOFieldDefiniti default_factory=int_factory, dto_field=DTOField(), ), - default=None, metadata=ANY, type_wrappers=ANY, raw=ANY, kwarg_definition=ANY, passthrough_constraints=False, ), + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=Optional[int], + name="f", + ), + model_name=ANY, + default_factory=None, + dto_field=DTOField(), + passthrough_constraints=False, + ), + replace( + DTOFieldDefinition.from_field_definition( + field_definition=FieldDefinition.from_kwarg( + annotation=List[str], + name="g", + ), + model_name=ANY, + default_factory=list, + dto_field=DTOField(), + passthrough_constraints=False, + ), + annotation=Optional[List[str]], + origin=Union, + instantiable_origin=Union, + safe_generic_origin=Union, + inner_types=ANY, + args=(List[str], NoneType), + raw=Optional[List[str]], + ), ] @@ -111,6 +140,8 @@ class TestModel(pydantic_v1.BaseModel): c: Annotated[int, pydantic_v1.Field(gt=1)] d: int = pydantic_v1.Field(default=1) e: int = pydantic_v1.Field(default_factory=int_factory) + f: Optional[int] # noqa: UP007 + g: List[str] = pydantic_v1.Field(default_factory=list) # noqa: UP006 field_defs = list(PydanticDTO.generate_field_definitions(TestModel)) assert field_defs[0].model_name == "TestModel" @@ -128,6 +159,8 @@ class TestModel(pydantic_v2.BaseModel): c: Annotated[int, pydantic_v2.Field(gt=1)] d: int = pydantic_v2.Field(default=1) e: int = pydantic_v2.Field(default_factory=int_factory) + f: Optional[int] # noqa: UP007 + g: List[str] = pydantic_v2.Field(default_factory=list) # noqa: UP006 field_defs = list(PydanticDTO.generate_field_definitions(TestModel)) assert field_defs[0].model_name == "TestModel"