Skip to content

Commit 35e245f

Browse files
Merge pull request #128 from saritasa-nest/feature/interal-refactor
Minor internal refactoring
2 parents aedc48c + 6d7de17 commit 35e245f

File tree

3 files changed

+113
-12
lines changed

3 files changed

+113
-12
lines changed

HISTORY.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Unreleased
66
------------------
77

88
* Add support using admin page filters for export
9+
* Minor refactor of CeleryResourceMixin for easier overriding of export/import methods
10+
* Add ability to pass additional args to `BaseFormat.export_data` on export
911

1012
1.6.0 (2025-04-29)
1113
------------------

import_export_extensions/models/export_job.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,19 @@ def cancel_export(self) -> None:
237237
self.export_status = self.ExportStatus.CANCELLED
238238
self.save(update_fields=["export_status"])
239239

240-
def _export_data_inner(self):
240+
def _export_data_inner(self) -> None:
241241
"""Run export process with saving to file."""
242242
self.result = self.resource.export()
243243
self.save(update_fields=["result"])
244244

245245
# `export_data` may be bytes (base formats such as xlsx, csv, etc.) or
246246
# file object (formats inherited from `BaseZipExport`)
247-
export_data = self.file_format.export_data(dataset=self.result)
247+
export_data = self.file_format.export_data(
248+
dataset=self.result,
249+
**self.resource.get_export_data_format_kwargs(
250+
file_format=self.file_format,
251+
),
252+
)
248253
# create file if `export_data` is not file
249254
if not hasattr(export_data, "read"):
250255
export_data = django_files.base.ContentFile(

import_export_extensions/resources.py

Lines changed: 104 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ def __init__(
4848
# _admin_filter differences from _filter_kwargs
4949
# because it isn't used in filterset_class
5050
# and it always comes from admin panel export page
51-
self._admin_filters: dict[str, str] = kwargs.pop("admin_filters", {})
51+
self._admin_filters: dict[str, typing.Any] = kwargs.pop(
52+
"admin_filters",
53+
{},
54+
)
5255
self._ordering = ordering
5356
self._created_by = created_by
5457
self.resource_init_kwargs: dict[str, typing.Any] = kwargs
@@ -77,23 +80,48 @@ def get_model_queryset(cls) -> QuerySet:
7780
"""
7881
return cls._meta.model.objects.all()
7982

80-
def get_queryset(self):
81-
"""Filter export queryset via filterset class."""
82-
queryset = self.get_model_queryset().filter(
83+
def get_queryset(self) -> QuerySet:
84+
"""Filter export queryset via filterset class and order it."""
85+
return self.filter_queryset(
86+
self.order_queryset(
87+
self.filter_queryset_via_admin(
88+
self.get_model_queryset(),
89+
),
90+
),
91+
)
92+
93+
def filter_queryset_via_admin(
94+
self,
95+
queryset: QuerySet,
96+
) -> QuerySet:
97+
"""Filter queryset via admin filters."""
98+
return queryset.filter(
8399
self._get_admin_search_filter(
84100
self._admin_filters.pop("search", {}),
85101
),
86102
**self._admin_filters,
87103
)
104+
105+
def order_queryset(
106+
self,
107+
queryset: QuerySet,
108+
) -> QuerySet:
109+
"""Order queryset for export."""
88110
try:
89-
queryset = queryset.order_by(*(self._ordering or ()))
111+
return queryset.order_by(*(self._ordering or ()))
90112
except FieldError as error:
91113
raise ValidationError(
92114
{
93115
# Split error text not to expose all fields to api clients.
94116
"ordering": str(error).split("Choices are:")[0].strip(),
95117
},
96118
) from error
119+
120+
def filter_queryset(
121+
self,
122+
queryset: QuerySet,
123+
) -> QuerySet:
124+
"""Filter queryset for export."""
97125
if not self._filter_kwargs:
98126
return queryset
99127
filter_instance = self.filterset_class(
@@ -143,7 +171,7 @@ def import_data(
143171
rollback_on_validation_errors: bool = False,
144172
force_import: bool = False,
145173
**kwargs,
146-
):
174+
) -> typing.Any:
147175
"""Init task state before importing.
148176
149177
If `force_import=True`, then rows with errors will be skipped.
@@ -157,6 +185,29 @@ def import_data(
157185
),
158186
queryset=dataset,
159187
)
188+
return self._import_data(
189+
dataset=dataset,
190+
dry_run=dry_run,
191+
raise_errors=raise_errors,
192+
use_transactions=use_transactions,
193+
collect_failed_rows=collect_failed_rows,
194+
rollback_on_validation_errors=rollback_on_validation_errors,
195+
force_import=force_import,
196+
**kwargs,
197+
)
198+
199+
def _import_data(
200+
self,
201+
dataset: tablib.Dataset,
202+
dry_run: bool = False,
203+
raise_errors: bool = False,
204+
use_transactions: bool | None = None,
205+
collect_failed_rows: bool = False,
206+
rollback_on_validation_errors: bool = False,
207+
force_import: bool = False,
208+
**kwargs,
209+
) -> typing.Any:
210+
"""Override if you need custom import logic."""
160211
return super().import_data( # type: ignore
161212
dataset=dataset,
162213
dry_run=dry_run,
@@ -177,14 +228,14 @@ def import_row(
177228
raise_errors=False,
178229
force_import=False,
179230
**kwargs,
180-
):
231+
) -> RowResult:
181232
"""Update task status as we import rows.
182233
183234
If `force_import=True`, then row errors will be stored in
184235
`field_skipped_errors` or `non_field_skipped_errors`.
185236
186237
"""
187-
imported_row: RowResult = super().import_row(
238+
imported_row = self._import_row(
188239
row=row,
189240
instance_loader=instance_loader,
190241
using_transactions=using_transactions,
@@ -203,6 +254,25 @@ def import_row(
203254
imported_row = self._skip_row_with_errors(imported_row, row)
204255
return imported_row
205256

257+
def _import_row(
258+
self,
259+
row,
260+
instance_loader,
261+
using_transactions=True,
262+
dry_run=False,
263+
raise_errors=False,
264+
**kwargs,
265+
) -> RowResult:
266+
"""Override if you need custom import row logic."""
267+
return super().import_row( # type: ignore
268+
row=row,
269+
instance_loader=instance_loader,
270+
using_transactions=using_transactions,
271+
dry_run=dry_run,
272+
raise_errors=raise_errors,
273+
**kwargs,
274+
)
275+
206276
def _skip_row_with_errors(
207277
self,
208278
row_result: RowResult,
@@ -258,6 +328,14 @@ def export(
258328
state=TaskState.EXPORTING.name,
259329
queryset=queryset,
260330
)
331+
return self._export(queryset=queryset, **kwargs)
332+
333+
def _export(
334+
self,
335+
queryset: QuerySet,
336+
**kwargs,
337+
) -> tablib.Dataset:
338+
"""Override if you need custom export logic."""
261339
return super().export( # type: ignore
262340
queryset=queryset,
263341
**kwargs,
@@ -268,12 +346,28 @@ def export_resource(
268346
obj,
269347
selected_fields: list[fields.Field] | None = None,
270348
**kwargs,
271-
):
349+
) -> typing.Any:
272350
"""Update task status as we export rows."""
273-
resource = super().export_resource(obj, selected_fields, **kwargs) # type: ignore
351+
resource = self._export_resource(obj, selected_fields, **kwargs)
274352
self.update_task_state(state=TaskState.EXPORTING.name)
275353
return resource
276354

355+
def _export_resource(
356+
self,
357+
obj,
358+
selected_fields: list[fields.Field] | None = None,
359+
**kwargs,
360+
) -> typing.Any:
361+
"""Override if you need custom export resource logic."""
362+
return super().export_resource(obj, selected_fields, **kwargs) # type: ignore
363+
364+
def get_export_data_format_kwargs(
365+
self,
366+
file_format: base_formats.Format,
367+
) -> dict[str, typing.Any]:
368+
"""Get additional params for export format."""
369+
return {}
370+
277371
def initialize_task_state(
278372
self,
279373
state: str,

0 commit comments

Comments
 (0)