Skip to content

Commit

Permalink
Get rid of cast() & simplify Generator
Browse files Browse the repository at this point in the history
  • Loading branch information
intgr committed May 9, 2024
1 parent 71b99ed commit a6e1a5f
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions rest_framework_nested/viewsets.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from __future__ import annotations

import contextlib
from typing import Any, Generator, Generic, TypeVar, cast
from typing import Any, Generic, TypeVar, Iterator

from django.core.exceptions import ImproperlyConfigured
from django.db.models import Model, QuerySet
from django.http import HttpRequest, QueryDict
from rest_framework.generics import GenericAPIView
from rest_framework.request import Request
from rest_framework.viewsets import ViewSetMixin
from rest_framework.serializers import BaseSerializer

T_Model = TypeVar('T_Model', bound=Model)


@contextlib.contextmanager
def _force_mutable(querydict: QueryDict | dict[str, Any]) -> Generator[QueryDict | dict[str, Any], None, None]:
def _force_mutable(querydict: QueryDict | dict[str, Any]) -> Iterator[QueryDict | dict[str, Any]]:
"""
Takes a HttpRequest querydict from Django and forces it to be mutable.
Reverts the initial state back on exit, if any.
Expand All @@ -39,7 +38,7 @@ def _get_parent_lookup_kwargs(self) -> dict[str, str]:
parent_lookup_kwargs = getattr(self, 'parent_lookup_kwargs', None)

if not parent_lookup_kwargs:
serializer_class = cast(GenericAPIView, self).get_serializer_class()
serializer_class: type[BaseSerializer[T_Model]] = self.get_serializer_class() # type: ignore[attr-defined]
parent_lookup_kwargs = getattr(serializer_class, 'parent_lookup_kwargs', None)

if not parent_lookup_kwargs:
Expand All @@ -59,25 +58,25 @@ def get_queryset(self) -> QuerySet[T_Model]:
if getattr(self, 'swagger_fake_view', False):
return queryset

orm_filters = {}
orm_filters: dict[str, Any] = {}
parent_lookup_kwargs = self._get_parent_lookup_kwargs()
for query_param, field_name in parent_lookup_kwargs.items():
orm_filters[field_name] = cast(ViewSetMixin, self).kwargs[query_param]
orm_filters[field_name] = self.kwargs[query_param] # type: ignore[attr-defined]
return queryset.filter(**orm_filters)

def initialize_request(self, request: HttpRequest, *args: Any, **kwargs: Any) -> Request:
"""
Adds the parent params from URL inside the children data available
"""
request = cast(ViewSetMixin, super()).initialize_request(request, *args, **kwargs)
drf_request: Request = super().initialize_request(request, *args, **kwargs) # type: ignore[misc]

if getattr(self, 'swagger_fake_view', False):
return request
return drf_request

for url_kwarg, fk_filter in self._get_parent_lookup_kwargs().items():
# fk_filter is alike 'grandparent__parent__pk'
parent_arg = fk_filter.partition('__')[0]
for querydict in [request.data, request.query_params]:
for querydict in [drf_request.data, drf_request.query_params]:
with _force_mutable(querydict):
querydict[parent_arg] = kwargs[url_kwarg]
return request
return drf_request

0 comments on commit a6e1a5f

Please sign in to comment.