diff --git a/django-stubs/db/models/fields/__init__.pyi b/django-stubs/db/models/fields/__init__.pyi index 8d45f1c358..530fab5b53 100644 --- a/django-stubs/db/models/fields/__init__.pyi +++ b/django-stubs/db/models/fields/__init__.pyi @@ -9,12 +9,13 @@ from django import forms from django.core import validators # due to weird mypy.stubtest error from django.core.checks import CheckMessage from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.models import Model +from django.db.models import Choices, Model from django.db.models.expressions import Col, Combinable, Expression, Func from django.db.models.fields.reverse_related import ForeignObjectRel from django.db.models.query_utils import Q, RegisterLookupMixin from django.forms import Widget -from django.utils.choices import BlankChoiceIterator, _Choice, _ChoiceNamedGroup, _Choices, _ChoicesCallable +from django.utils.choices import BlankChoiceIterator, _Choice, _ChoiceNamedGroup, _ChoicesCallable, _ChoicesMapping +from django.utils.choices import _Choices as _ChoicesSequence from django.utils.datastructures import DictWrapper from django.utils.functional import _Getter, _StrOrPromise, cached_property from typing_extensions import Self, TypeAlias @@ -27,6 +28,9 @@ BLANK_CHOICE_DASH: list[tuple[str, str]] _ChoicesList: TypeAlias = Sequence[_Choice] | Sequence[_ChoiceNamedGroup] _LimitChoicesTo: TypeAlias = Q | dict[str, Any] _LimitChoicesToCallable: TypeAlias = Callable[[], _LimitChoicesTo] +_Choices: TypeAlias = ( + _ChoicesSequence | _ChoicesMapping | type[Choices] | Callable[[], _ChoicesSequence | _ChoicesMapping] +) _F = TypeVar("_F", bound=Field, covariant=True) diff --git a/django-stubs/utils/choices.pyi b/django-stubs/utils/choices.pyi index dbbca54903..d2f995939f 100644 --- a/django-stubs/utils/choices.pyi +++ b/django-stubs/utils/choices.pyi @@ -1,4 +1,4 @@ -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Mapping from typing import Any, Protocol, TypeVar, type_check_only from typing_extensions import TypeAlias @@ -6,6 +6,7 @@ from typing_extensions import TypeAlias _Choice: TypeAlias = tuple[Any, Any] _ChoiceNamedGroup: TypeAlias = tuple[str, Iterable[_Choice]] _Choices: TypeAlias = Iterable[_Choice | _ChoiceNamedGroup] +_ChoicesMapping: TypeAlias = Mapping[Any, Any] | Mapping[str, Mapping[Any, Any]] # noqa: PYI047 @type_check_only class _ChoicesCallable(Protocol): diff --git a/tests/assert_type/db/models/fields/test_choices.py b/tests/assert_type/db/models/fields/test_choices.py new file mode 100644 index 0000000000..1f9f773c8a --- /dev/null +++ b/tests/assert_type/db/models/fields/test_choices.py @@ -0,0 +1,89 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import TypeVar + +from django.db import models +from typing_extensions import assert_type + +_T = TypeVar("_T") + + +def to_named_seq(func: Callable[[], _T]) -> Callable[[], Sequence[tuple[str, _T]]]: + def inner() -> Sequence[tuple[str, _T]]: + return [("title", func())] + + return inner + + +def to_named_mapping(func: Callable[[], _T]) -> Callable[[], Mapping[str, _T]]: + def inner() -> Mapping[str, _T]: + return {"title": func()} + + return inner + + +def str_tuple() -> Sequence[tuple[str, str]]: + return (("foo", "bar"), ("fuzz", "bazz")) + + +def str_mapping() -> Mapping[str, str]: + return {"foo": "bar", "fuzz": "bazz"} + + +def int_tuple() -> Sequence[tuple[int, str]]: + return ((1, "bar"), (2, "bazz")) + + +def int_mapping() -> Mapping[int, str]: + return {3: "bar", 4: "bazz"} + + +class TestModel(models.Model): + class TextChoices(models.TextChoices): + FIRST = "foo", "bar" + SECOND = "foo2", "bar" + + class IntegerChoices(models.IntegerChoices): + FIRST = 1, "bar" + SECOND = 2, "bar" + + char1 = models.CharField[str, str](max_length=5, choices=TextChoices, default="foo") + char2 = models.CharField[str, str](max_length=5, choices=str_tuple, default="foo") + char3 = models.CharField[str, str](max_length=5, choices=str_mapping, default="foo") + char4 = models.CharField[str, str](max_length=5, choices=str_tuple(), default="foo") + char5 = models.CharField[str, str](max_length=5, choices=str_mapping(), default="foo") + char6 = models.CharField[str, str](max_length=5, choices=to_named_seq(str_tuple), default="foo") + char7 = models.CharField[str, str](max_length=5, choices=to_named_mapping(str_mapping), default="foo") + char8 = models.CharField[str, str](max_length=5, choices=to_named_seq(str_tuple)(), default="foo") + char9 = models.CharField[str, str](max_length=5, choices=to_named_mapping(str_mapping)(), default="foo") + + int1 = models.IntegerField[int, int](choices=IntegerChoices, default=1) + int2 = models.IntegerField[int, int](choices=int_tuple, default=1) + int3 = models.IntegerField[int, int](choices=int_mapping, default=1) + int4 = models.IntegerField[int, int](choices=int_tuple(), default=1) + int5 = models.IntegerField[int, int](choices=int_mapping(), default=1) + int6 = models.IntegerField[int, int](choices=to_named_seq(int_tuple), default=1) + int7 = models.IntegerField[int, int](choices=to_named_seq(int_mapping), default=1) + int8 = models.IntegerField[int, int](choices=to_named_seq(int_tuple)(), default=1) + int9 = models.IntegerField[int, int](choices=to_named_seq(int_mapping)(), default=1) + + +instance = TestModel() +assert_type(instance.char1, str) +assert_type(instance.char2, str) +assert_type(instance.char3, str) +assert_type(instance.char4, str) +assert_type(instance.char5, str) +assert_type(instance.char6, str) +assert_type(instance.char7, str) +assert_type(instance.char8, str) +assert_type(instance.char9, str) + +assert_type(instance.int1, int) +assert_type(instance.int2, int) +assert_type(instance.int3, int) +assert_type(instance.int4, int) +assert_type(instance.int5, int) +assert_type(instance.int6, int) +assert_type(instance.int7, int) +assert_type(instance.int8, int) +assert_type(instance.int9, int) diff --git a/tests/typecheck/db/models/test_fields.yml b/tests/typecheck/db/models/test_fields.yml new file mode 100644 index 0000000000..56c1b2920a --- /dev/null +++ b/tests/typecheck/db/models/test_fields.yml @@ -0,0 +1,159 @@ +- case: db_models_fields_choices + main: | + from collections.abc import Callable, Mapping, Sequence + from datetime import date, time + from decimal import Decimal + from typing import TypeVar, Tuple + from uuid import UUID + + from django.db import models + + _T = TypeVar("_T") + + + def to_named_seq(func: Callable[[], _T]) -> Callable[[], Sequence[Tuple[str, _T]]]: + def inner() -> Sequence[Tuple[str, _T]]: + return [("title", func())] + + return inner + + + def to_named_mapping(func: Callable[[], _T]) -> Callable[[], Mapping[str, _T]]: + def inner() -> Mapping[str, _T]: + return {"title": func()} + + return inner + + + def str_tuple() -> Sequence[Tuple[str, str]]: + return (("foo", "bar"), ("fuzz", "bazz")) + + + def str_mapping() -> Mapping[str, str]: + return {"foo": "bar", "fuzz": "bazz"} + + + def int_tuple() -> Sequence[Tuple[int, str]]: + return ((1, "bar"), (2, "bazz")) + + + def int_mapping() -> Mapping[int, str]: + return {3: "bar", 4: "bazz"} + + + def dec_tuple() -> Sequence[Tuple[Decimal, str]]: + return ((Decimal(1), "bar"), (Decimal(2), "bazz")) + + + def dec_mapping() -> Mapping[Decimal, str]: + return {Decimal(3): "bar", Decimal(4): "bazz"} + + + def url_tuple() -> Sequence[Tuple[str, str]]: + return (("https://python.org", "bar"), ("https://mypy-lang.org", "bazz")) + + + def url_mapping() -> Mapping[str, str]: + return {"https://python.org": "bar", "https://mypy-lang.org": "bazz"} + + + def date_tuple() -> Sequence[Tuple[date, str]]: + return ((date.today(), "bar"), (date(2024, 1, 1), "bazz")) + + + def date_mapping() -> Mapping[date, str]: + return {date.today(): "bar", date(2024, 1, 1): "bazz"} + + + def time_tuple() -> Sequence[Tuple[time, str]]: + return ((time(0, 0, 2), "bar"), (time(0, 0, 1), "bazz")) + + + def time_mapping() -> Mapping[time, str]: + return {time(0, 0, 2): "bar", time(0, 0, 1): "bazz"} + + + def uuid_tuple() -> Sequence[Tuple[UUID, str]]: + return ((UUID(), "bar"), (UUID(), "bazz")) + + + def uuid_mapping() -> Mapping[UUID, str]: + return {UUID(): "bar", UUID(): "bazz"} + + + class NewModel(models.Model): + class TextChoices(models.TextChoices): + FIRST = "foo", "bar" + SECOND = "foo2", "bar" + + class IntegerChoices(models.IntegerChoices): + FIRST = 1, "bar" + SECOND = 2, "bar" + + char1 = models.CharField[str, str](max_length=200, choices=TextChoices) + char2 = models.CharField[str, str](max_length=200, choices=str_tuple) + char3 = models.CharField[str, str](max_length=200, choices=str_mapping) + char4 = models.CharField[str, str](max_length=200, choices=str_tuple()) + char5 = models.CharField[str, str](max_length=200, choices=str_mapping()) + char6 = models.CharField[str, str](max_length=200, choices=to_named_seq(str_tuple)) + char7 = models.CharField[str, str](max_length=200, choices=to_named_seq(str_tuple)()) + char8 = models.CharField[str, str](max_length=200, choices=to_named_mapping(str_mapping)) + char9 = models.CharField[str, str](max_length=200, choices=to_named_mapping(str_mapping)()) + + int1 = models.IntegerField[int, int](choices=IntegerChoices) + int2 = models.IntegerField[int, int](choices=int_tuple) + int3 = models.IntegerField[int, int](choices=int_mapping) + int4 = models.IntegerField[int, int](choices=int_tuple()) + int5 = models.IntegerField[int, int](choices=int_mapping()) + int6 = models.IntegerField[int, int](choices=to_named_seq(str_tuple)) + int7 = models.IntegerField[int, int](choices=to_named_seq(str_tuple)()) + int8 = models.IntegerField[int, int](choices=to_named_mapping(str_mapping)) + int9 = models.IntegerField[int, int](choices=to_named_mapping(str_mapping)()) + + dec1 = models.DecimalField[Decimal, Decimal](choices=dec_tuple) + dec2 = models.DecimalField[Decimal, Decimal](choices=dec_mapping) + dec3 = models.DecimalField[Decimal, Decimal](choices=dec_tuple()) + dec4 = models.DecimalField[Decimal, Decimal](choices=dec_mapping()) + + slug1 = models.SlugField[str, str](choices=TextChoices) + slug4 = models.SlugField[str, str](choices=str_tuple) + slug5 = models.SlugField[str, str](choices=str_mapping) + slug2 = models.SlugField[str, str](choices=str_tuple()) + slug3 = models.SlugField[str, str](choices=str_mapping()) + + url1 = models.URLField[str, str](choices=str_tuple) + url2 = models.URLField[str, str](choices=str_mapping) + url3 = models.URLField[str, str](choices=str_tuple()) + url4 = models.URLField[str, str](choices=str_mapping()) + + text1 = models.TextField[str, str](choices=TextChoices) + text2 = models.TextField[str, str](choices=str_tuple) + text3 = models.TextField[str, str](choices=str_mapping) + text4 = models.TextField[str, str](choices=str_tuple()) + text5 = models.TextField[str, str](choices=str_mapping()) + + ip1 = models.GenericIPAddressField[int, int](choices=int_tuple) + ip2 = models.GenericIPAddressField[int, int](choices=int_mapping) + ip3 = models.GenericIPAddressField[int, int](choices=int_tuple()) + ip4 = models.GenericIPAddressField[int, int](choices=int_mapping()) + + date1 = models.DateField[date, date](choices=date_tuple) + date2 = models.DateField[date, date](choices=date_mapping) + date3 = models.DateField[date, date](choices=date_tuple()) + date4 = models.DateField[date, date](choices=date_mapping()) + + time1 = models.TimeField[time, time](choices=time_tuple) + time2 = models.TimeField[time, time](choices=time_mapping) + time3 = models.TimeField[time, time](choices=time_tuple()) + time4 = models.TimeField[time, time](choices=time_mapping()) + + uuid1 = models.UUIDField[UUID, UUID](choices=uuid_tuple) + uuid2 = models.UUIDField[UUID, UUID](choices=uuid_mapping) + uuid3 = models.UUIDField[UUID, UUID](choices=uuid_tuple()) + uuid4 = models.UUIDField[UUID, UUID](choices=uuid_mapping()) + + path1 = models.FilePathField[str, str](choices=TextChoices) + path2 = models.FilePathField[str, str](choices=str_tuple) + path3 = models.FilePathField[str, str](choices=str_mapping) + path4 = models.FilePathField[str, str](choices=str_tuple()) + path5 = models.FilePathField[str, str](choices=str_mapping())