Skip to content

Commit 1adcc0f

Browse files
committed
Fix: Use correct field_meta for constrained union types when building field values for coverage
1 parent 7d67749 commit 1adcc0f

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

polyfactory/factories/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,9 @@ def get_field_value_coverage( # noqa: C901,PLR0912
902902
for unwrapped_annotation in flatten_annotation(field_meta.annotation):
903903
unwrapped_annotation = cls._resolve_forward_references(unwrapped_annotation) # noqa: PLW2901
904904

905+
unwrapped_annotation_meta = next(
906+
(meta for meta in (field_meta.children or []) if meta.annotation == unwrapped_annotation), field_meta
907+
)
905908
if unwrapped_annotation in (None, NoneType):
906909
yield None
907910

@@ -911,11 +914,11 @@ def get_field_value_coverage( # noqa: C901,PLR0912
911914
elif isinstance(unwrapped_annotation, EnumMeta):
912915
yield CoverageContainer(list(unwrapped_annotation))
913916

914-
elif field_meta.constraints:
917+
elif unwrapped_annotation_meta.constraints:
915918
yield CoverageContainerCallable(
916919
cls.get_constrained_field_value,
917920
annotation=unwrapped_annotation,
918-
field_meta=field_meta,
921+
field_meta=unwrapped_annotation_meta,
919922
)
920923

921924
elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):

tests/test_pydantic_factory.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from uuid import UUID
1313

1414
import pytest
15-
from annotated_types import Ge, Gt, Le, LowerCase, MinLen, UpperCase
15+
from annotated_types import Ge, Gt, Le, LowerCase, MaxLen, MinLen, UpperCase
1616
from typing_extensions import Annotated, TypeAlias
1717

1818
import pydantic
@@ -717,18 +717,50 @@ class C(BaseModel):
717717
assert CFactory.build()
718718

719719

720-
def test_constrained_union_types() -> None:
720+
@pytest.mark.skipif(IS_PYDANTIC_V2, reason="pydantic 1 only test")
721+
def test_constrained_union_types_pydantic_v1() -> None:
721722
class A(BaseModel):
722723
a: Union[Annotated[List[str], MinLen(100)], Annotated[int, Ge(1000)]]
723724
b: Union[List[Annotated[str, MinLen(100)]], int]
724725
c: Union[Annotated[List[int], MinLen(100)], None]
725-
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MinLen(100)]]
726-
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MinLen(10)]]]
726+
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MaxLen(99)]]
727+
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MaxLen(9)]]]
727728
f: Optional[Union[Annotated[List[int], MinLen(10)], List[str]]]
729+
g: Optional[
730+
Union[
731+
Annotated[List[int], MinLen(10)],
732+
Union[Annotated[List[str], MaxLen(9)], Annotated[Decimal, Field(max_digits=4, decimal_places=2)]],
733+
]
734+
]
735+
736+
AFactory = ModelFactory.create_factory(A, __allow_none_optionals__=False)
737+
738+
assert AFactory.build()
739+
assert list(AFactory.coverage())
740+
741+
742+
@pytest.mark.skipif(IS_PYDANTIC_V1, reason="pydantic 2 only test")
743+
def test_constrained_union_types_pydantic_v2() -> None:
744+
class A(BaseModel):
745+
a: Union[Annotated[List[str], MinLen(100)], Annotated[int, Ge(1000)]]
746+
b: Union[List[Annotated[str, MinLen(100)]], int]
747+
c: Union[Annotated[List[int], MinLen(100)], None]
748+
d: Union[Annotated[List[int], MinLen(100)], Annotated[List[str], MaxLen(99)]]
749+
e: Optional[Union[Annotated[List[int], MinLen(10)], Annotated[List[str], MaxLen(9)]]]
750+
f: Optional[Union[Annotated[List[int], MinLen(10)], List[str]]]
751+
g: Optional[
752+
Union[
753+
Annotated[List[int], MinLen(10)],
754+
Union[Annotated[List[str], MaxLen(9)], Annotated[Decimal, Field(max_digits=4, decimal_places=2)]],
755+
]
756+
]
757+
# This annotation is not allowed in pydantic 1
758+
h: Annotated[Union[List[int], List[str]], MinLen(10)]
728759

729760
AFactory = ModelFactory.create_factory(A, __allow_none_optionals__=False)
730761

731762
assert AFactory.build()
763+
assert list(AFactory.coverage())
732764

733765

734766
@pytest.mark.parametrize("allow_none", (True, False))

0 commit comments

Comments
 (0)