diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index b8290361b..088176cd8 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -172,6 +172,7 @@ def manager_and_queryset_method_hooks(self) -> dict[str, Callable[[MethodContext "alias": partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context), "annotate": partial(querysets.extract_proper_type_queryset_annotate, django_context=self.django_context), "create": partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context), + "acreate": partial(init_create.redefine_and_typecheck_model_acreate, django_context=self.django_context), "filter": typecheck_filtering_method, "get": typecheck_filtering_method, "exclude": typecheck_filtering_method, diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index f26baaf6d..3c863afa9 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -76,3 +76,23 @@ def redefine_and_typecheck_model_create(ctx: MethodContext, django_context: Djan return ctx.default_return_type return typecheck_model_method(ctx, django_context, model_cls, "create") + + +def redefine_and_typecheck_model_acreate(ctx: MethodContext, django_context: DjangoContext) -> MypyType: + default_return_type = get_proper_type(ctx.default_return_type) + + if not isinstance(default_return_type, Instance): + # only work with ctx.default_return_type = model Instance + return ctx.default_return_type + + # default_return_type at this point should be of type Coroutine[Any, Any, ] + model = default_return_type.args[-1] + if not isinstance(model, Instance): + return ctx.default_return_type + + model_fullname = model.type.fullname + model_cls = django_context.get_model_class_by_fullname(model_fullname) + if model_cls is None: + return ctx.default_return_type + + return typecheck_model_method(ctx, django_context, model_cls, "acreate") diff --git a/tests/typecheck/models/test_create.yml b/tests/typecheck/models/test_create.yml index 91d28b1ab..c7b1fe0f7 100644 --- a/tests/typecheck/models/test_create.yml +++ b/tests/typecheck/models/test_create.yml @@ -155,3 +155,21 @@ id = models.IntegerField(primary_key=True) class MyModel3(models.Model): default = models.IntegerField(default=return_int) + +- case: default_manager_acreate_is_typechecked + main: | + import asyncio + from myapp.models import User + async def amain() -> None: + await User.objects.acreate(pk=1, name='Max', age=10) + await User.objects.acreate(age=[]) # E: Incompatible type for "age" of "User" (got "List[Any]", expected "Union[float, int, str, Combinable]") [misc] + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class User(models.Model): + name = models.CharField(max_length=100) + age = models.IntegerField()