Merged
158 changes: 134 additions & 24 deletions treenode/admin/exporter.py
@@ -13,6 +13,7 @@
import csv
import json
import yaml
from asgiref.sync import sync_to_async
from django.core.serializers.json import DjangoJSONEncoder
from django.http import StreamingHttpResponse
from io import BytesIO, StringIO
@@ -58,6 +59,10 @@ def get_obj(self):
for obj in queryset.iterator():
yield obj

def _get_sync_iterator_item(self, iterator):
"""Return next item from a sync iterator or None when exhausted."""
return next(iterator, None)

def get_serializable_row(self, obj):
"""Get serialized object."""
fields = self.fields
@@ -91,6 +96,38 @@ def csv_stream_data(self, delimiter=","):
buffer.seek(0)
buffer.truncate(0)

async def async_csv_stream_data(self, delimiter=","):
"""Stream CSV or TSV data asynchronously."""
yield "\ufeff"
buffer = StringIO()
writer = csv.DictWriter(
buffer,
fieldnames=self.fields,
delimiter=delimiter
)
writer.writeheader()
yield buffer.getvalue()
buffer.seek(0)
buffer.truncate(0)

iterator = self.get_obj()
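        # Pull each row through sync_to_async so ORM iteration never runs on the event loop.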
while True:
obj = await sync_to_async(
self._get_sync_iterator_item,
thread_sensitive=True,
)(iterator)
if obj is None:
break

row = await sync_to_async(
self.get_serializable_row,
thread_sensitive=True,
)(obj)
writer.writerow(row)
yield buffer.getvalue()
buffer.seek(0)
buffer.truncate(0)

def json_stream_data(self):
"""Stream JSON data."""
yield "[\n"
@@ -104,6 +141,32 @@ def json_stream_data(self):
yield json.dumps(row, ensure_ascii=False)
yield "\n]"

async def async_json_stream_data(self):
"""Stream JSON data asynchronously."""
yield "[\n"
first = True
iterator = self.get_obj()

while True:
obj = await sync_to_async(
self._get_sync_iterator_item,
thread_sensitive=True,
)(iterator)
if obj is None:
break

row = await sync_to_async(
self.get_serializable_row,
thread_sensitive=True,
)(obj)
if not first:
yield ",\n"
else:
first = False
yield json.dumps(row, ensure_ascii=False)

yield "\n]"

def tsv_stream_data(self, chunk_size=1000):
"""Stream TSV (tab-separated values) data."""
yield from self.csv_stream_data(delimiter="\t")
@@ -115,8 +178,27 @@ def yaml_stream_data(self):
row = self.get_serializable_row(obj)
yield yaml.safe_dump([row], allow_unicode=True)

def xlsx_stream_data(self):
"""Stream XLSX data."""
async def async_yaml_stream_data(self):
"""Stream YAML data asynchronously."""
yield "---\n"
iterator = self.get_obj()

while True:
obj = await sync_to_async(
self._get_sync_iterator_item,
thread_sensitive=True,
)(iterator)
if obj is None:
break

row = await sync_to_async(
self.get_serializable_row,
thread_sensitive=True,
)(obj)
yield yaml.safe_dump([row], allow_unicode=True)

def _build_xlsx_payload(self):
"""Build XLSX bytes synchronously in a thread-safe block."""
wb = Workbook()
ws = wb.active
ws.append(self.fields)
@@ -128,40 +210,68 @@ def xlsx_stream_data(self):
output = BytesIO()
wb.save(output)
output.seek(0)
yield output.getvalue()
return output.getvalue()

def process_record(self):
"""
Create a StreamingHttpResponse based on selected format.
def xlsx_stream_data(self):
"""Stream XLSX data."""
yield self._build_xlsx_payload()

:param chunk_size: Batch size for iteration.
:return: StreamingHttpResponse object.
"""
async def async_xlsx_stream_data(self):
"""Stream XLSX data asynchronously."""
payload = await sync_to_async(
self._build_xlsx_payload,
thread_sensitive=True,
)()
yield payload

def _resolve_stream_data(self, is_async=False):
"""Resolve stream iterator for sync or async request handler."""
if self.format == 'xlsx':
response = StreamingHttpResponse(
streaming_content=self.xlsx_stream_data(),
content_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet; charset=utf-8" # noqa: D501
stream = (
self.async_xlsx_stream_data()
if is_async else self.xlsx_stream_data()
)
content_type = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet; charset=utf-8" # noqa: D501
elif self.format == 'tsv':
response = StreamingHttpResponse(
streaming_content=self.csv_stream_data(delimiter="\t"),
content_type="text/tab-separated-values; charset=utf-8"
stream = (
self.async_csv_stream_data(delimiter="\t")
if is_async else self.csv_stream_data(delimiter="\t")
)
content_type = "text/tab-separated-values; charset=utf-8"
elif self.format == 'csv':
response = StreamingHttpResponse(
streaming_content=self.csv_stream_data(delimiter=","),
content_type="text/csv; charset=utf-8"
stream = (
self.async_csv_stream_data(delimiter=",")
if is_async else self.csv_stream_data(delimiter=",")
)
content_type = "text/csv; charset=utf-8"
elif self.format == 'yaml':
response = StreamingHttpResponse(
streaming_content=self.yaml_stream_data(),
content_type=f"application/{self.format}; charset=utf-8"
stream = (
self.async_yaml_stream_data()
if is_async else self.yaml_stream_data()
)
content_type = f"application/{self.format}; charset=utf-8"
else:
response = StreamingHttpResponse(
streaming_content=self.json_stream_data(),
content_type=f"application/{self.format}; charset=utf-8"
stream = (
self.async_json_stream_data()
if is_async else self.json_stream_data()
)
content_type = f"application/{self.format}; charset=utf-8"

return stream, content_type

def process_record(self, request=None):
"""
Create a StreamingHttpResponse based on selected format.

:param request: Optional request; an ASGI request selects the async streams.
:return: StreamingHttpResponse object.
"""
is_async = bool(request and hasattr(request, "scope"))
stream_data, content_type = self._resolve_stream_data(is_async=is_async)
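        # On Django 4.2+, StreamingHttpResponse serves an async iterator natively under ASGI, avoiding the sync-iterator warning.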
response = StreamingHttpResponse(
streaming_content=stream_data,
content_type=content_type,
)

response['Content-Disposition'] = f'attachment; filename="{self.filename}"' # noqa: D501
return response
2 changes: 1 addition & 1 deletion treenode/admin/mixin.py
@@ -336,7 +336,7 @@ def export_view(self, request):
filename=filename,
fileformat=fmt
)
return exporter.process_record()
return exporter.process_record(request=request)

return render(
request,
78 changes: 43 additions & 35 deletions treenode/tests.py
@@ -1,5 +1,6 @@
import json
import unittest
import warnings
from types import SimpleNamespace
from unittest.mock import patch

@@ -8,10 +9,12 @@
from django.db import DatabaseError
from django.template import Context
from django.template.loader import render_to_string
from django.test import AsyncClient, Client, RequestFactory, TestCase, override_settings
from django.test import Client, RequestFactory, TestCase, override_settings
from django.core.files.uploadedfile import SimpleUploadedFile

from tests.models import TestModel
from treenode.admin.exporter import TreeNodeExporter
from treenode.admin.mixin import AdminMixin

from treenode.admin.importer import TreeNodeImporter
@@ -21,6 +24,8 @@
class TestAdminMixin(AdminMixin):
"""Admin class for testing row rendering helpers."""

TreeNodeExporter = TreeNodeExporter




@@ -236,48 +241,51 @@ def test_move_endpoint_returns_with_consistent_tree_fields(self):
self.assertEqual(moved_leaf.priority, right_children.index(moved_leaf))


class TreeNodeImporterTests(TestCase):
"""Tests for import upsert behavior with explicit identifiers."""
class AdminExportAsyncTests(TestCase):
"""Regression tests for async export streaming in ASGI mode."""
Review comment on lines +244 to +245 (P2): Preserve importer upsert regression coverage

This change replaces the previous TreeNodeImporterTests block with AdminExportAsyncTests, which removes the only test that verified TreeNodeImporter.import_tree() updates an existing row when the incoming payload contains an existing primary key. Losing that assertion means regressions in importer upsert behavior (e.g., duplicate-PK create attempts) can now ship unnoticed through CI; the async export test should be added in addition to, not instead of, the importer regression test.



@classmethod
def setUpTestData(cls):
"""Create initial records for importer update scenario."""
cls.root = TestModel.objects.create(name="root-import", priority=0)
"""Prepare nodes to validate exported async payload."""
cls.root = TestModel.objects.create(name="root-export", priority=0)
cls.child = TestModel.objects.create(
id=33,
name="child-old",
parent=cls.root,
priority=1,
)

def test_importer_updates_existing_object_when_id_present(self):
"""Ensure importer updates existing row instead of trying to create duplicate PK."""
import_payload = [
{
"id": self.child.pk,
"name": "child-new",
"parent": self.root.pk,
"priority": 5,
}
]
import_file = SimpleUploadedFile(
"tree.json",
json.dumps(import_payload).encode("utf-8"),
content_type="application/json",
name="child-export", parent=cls.root, priority=1
)

importer = TreeNodeImporter(TestModel, import_file, "json")
importer.parse()
result = importer.import_tree()

self.child.refresh_from_db()
@override_settings(ROOT_URLCONF="treenode.tests")
async def test_export_endpoint_uses_async_stream_without_warning(self):
"""Ensure ASGI export has no sync-stream warning and valid content."""
client = AsyncClient()

with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
response = await client.get(
"/admin/tests/testmodel/export/?download=1&format=json"
)

chunks = []
stream = response.streaming_content
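        # streaming_content is an async iterator under ASGI and a sync one under WSGI.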
if hasattr(stream, "__aiter__"):
async for chunk in stream:
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
chunks.append(chunk)
else:
for chunk in stream:
if isinstance(chunk, bytes):
chunk = chunk.decode("utf-8")
chunks.append(chunk)

body = "".join(chunks)
warning_messages = [str(item.message) for item in caught]

self.assertEqual(result["created"], 0)
self.assertEqual(result["updated"], 1)
self.assertEqual(result["errors"], [])
self.assertEqual(self.child.name, "child-new")
self.assertEqual(self.child.priority, 5)
self.assertEqual(response.status_code, 200)
self.assertTrue(body.startswith("["))
self.assertIn("root-export", body)
self.assertIn("child-export", body)
self.assertFalse(
any("synchronous iterators" in message for message in warning_messages)
)


# The End
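
As the review comment above notes, the importer upsert regression test should be restored alongside the new async export coverage rather than replaced by it. The following sketch is reassembled from the deleted lines in this diff and assumes TreeNodeImporter.parse() and import_tree() keep the signatures shown there:

class TreeNodeImporterTests(TestCase):
    """Tests for import upsert behavior with explicit identifiers."""

    @classmethod
    def setUpTestData(cls):
        """Create initial records for importer update scenario."""
        cls.root = TestModel.objects.create(name="root-import", priority=0)
        cls.child = TestModel.objects.create(
            id=33,
            name="child-old",
            parent=cls.root,
            priority=1,
        )

    def test_importer_updates_existing_object_when_id_present(self):
        """Importer updates the existing row instead of creating a duplicate PK."""
        import_payload = [
            {
                "id": self.child.pk,
                "name": "child-new",
                "parent": self.root.pk,
                "priority": 5,
            }
        ]
        import_file = SimpleUploadedFile(
            "tree.json",
            json.dumps(import_payload).encode("utf-8"),
            content_type="application/json",
        )

        importer = TreeNodeImporter(TestModel, import_file, "json")
        importer.parse()
        result = importer.import_tree()

        # The existing child (pk=33) must be updated in place, not duplicated.
        self.child.refresh_from_db()
        self.assertEqual(result["created"], 0)
        self.assertEqual(result["updated"], 1)
        self.assertEqual(result["errors"], [])
        self.assertEqual(self.child.name, "child-new")
        self.assertEqual(self.child.priority, 5)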