Skip to content
This repository was archived by the owner on Dec 1, 2025. It is now read-only.

Commit dcf4746

Browse files
committed
add wrapper for sort_values
1 parent 413eb33 commit dcf4746

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed

src/nested_dask/core.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
import os
4+
from collections.abc import Callable, Mapping
5+
from typing import Any, Literal
46

57
import dask.dataframe as dd
68
import dask.dataframe.dask_expr as dx
79
import nested_pandas as npd
10+
import numpy as np
811
import pandas as pd
912
import pyarrow as pa
1013
from dask.dataframe.dask_expr._collection import new_collection
@@ -616,6 +619,103 @@ def dropna(
616619
meta=self._meta,
617620
)
618621

622+
def sort_values(
    self,
    by: str | list[str],
    npartitions: int | None = None,
    ascending: bool | list[bool] = True,
    na_position: Literal["first"] | Literal["last"] = "last",
    partition_size: float = 128e6,
    sort_function: Callable[[npd.NestedFrame], npd.NestedFrame] | None = None,
    sort_function_kwargs: Mapping[str, Any] | None = None,
    upsample: float = 1.0,
    ignore_index: bool | None = False,
    shuffle_method: str | None = None,
    **options,
) -> Self:  # type: ignore[name-defined] # noqa: F821:
    """
    Sort the dataset by one or more columns belonging to a single layer.

    All ``by`` columns must live in the same layer: either the base frame
    or one nested layer (referenced as ``"layer.column"``). Base-layer
    sorts are delegated to dask's ``sort_values``; nested-layer sorts are
    applied partition by partition through the nested-pandas API.

    Sorting a parallel dataset requires expensive shuffles and is generally
    not recommended. See ``set_index`` for implementation details.

    Parameters
    ----------
    by: str or list[str]
        Column(s) to sort by.
    npartitions: int, None, or 'auto'
        The ideal number of output partitions. If None, use the same as the
        input. If 'auto' then decide by memory use. Not used when sorting
        nested layers.
    ascending: bool or list[bool], optional
        Sort ascending vs. descending. Defaults to True.
    na_position: {'last', 'first'}, optional
        Puts NaNs at the beginning if 'first', puts NaN at the end if
        'last'. Defaults to 'last'.
    partition_size: float, optional
        Forwarded unchanged to dask's ``sort_values`` when sorting the base
        layer. Not used when sorting nested layers.
    sort_function: function, optional
        Sorting function to use when sorting underlying partitions. If
        None, defaults to M.sort_values (the partition library's
        implementation of sort_values). Not used when sorting nested
        layers.
    sort_function_kwargs: dict, optional
        Additional keyword arguments to pass to the partition sorting
        function. By default, by, ascending, and na_position are provided.
        Not used when sorting nested layers.
    upsample: float, optional
        Forwarded unchanged to dask's ``sort_values`` when sorting the base
        layer. Not used when sorting nested layers.
    ignore_index: bool, optional
        Forwarded to the underlying ``sort_values`` call for both base and
        nested sorts. Defaults to False.
    shuffle_method: str, optional
        Forwarded unchanged to dask's ``sort_values`` when sorting the base
        layer. Not used when sorting nested layers.
    **options
        Extra keyword arguments forwarded to the underlying ``sort_values``
        call for both base and nested sorts.

    Returns
    -------
    DataFrame
        DataFrame with sorted values.

    Raises
    ------
    ValueError
        If ``by`` is empty, or if the ``by`` columns span more than one
        layer.
    """

    # Normalize a single column name to a list so both forms share one path
    if isinstance(by, str):
        by = [by]

    # Resolve the layer each "by" column belongs to: the prefix before the
    # first "." for known hierarchical references, otherwise the base frame
    layers = [col.split(".")[0] if self._is_known_hierarchical_column(col) else "base" for col in by]

    # Without this guard an empty "by" would surface as a bare IndexError
    # when picking the target layer below
    if not layers:
        raise ValueError("sort_values requires at least one column in `by`")

    # Ensure one target layer, preventing multi-layer operations
    unique_layers = np.unique(layers)
    if len(unique_layers) > 1:
        raise ValueError(
            "sort_values cannot target multiple structs/layers, sort each layer with a separate call"
        )
    target = str(unique_layers[0])

    # Just use dask's sort_values if the target is the base layer
    # Drops divisions, but this is expected behavior of a sorting operation
    if target == "base":
        return super().sort_values(
            by=by,
            npartitions=npartitions,
            ascending=ascending,
            na_position=na_position,
            partition_size=partition_size,
            sort_function=sort_function,
            sort_function_kwargs=sort_function_kwargs,
            upsample=upsample,
            ignore_index=ignore_index,
            shuffle_method=shuffle_method,
            **options,
        )

    # If nested target layer, go through nested-pandas API
    # apply via map_partitions, meta is propagated
    # does preserve divisions
    return self.map_partitions(
        lambda x: npd.NestedFrame(x).sort_values(
            by=by,
            ascending=ascending,
            na_position=na_position,
            ignore_index=ignore_index,
            **options,
        ),
        meta=self._meta,
    )
718+
619719
def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
620720
"""
621721
Takes a function and applies it to each top-level row of the NestedFrame.

tests/nested_dask/test_nestedframe.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,26 @@ def test_dropna(test_dataset_with_nans):
299299
assert len(flat_nested_nan_free) == len(flat_nested) - 1
300300

301301

302+
def test_sort_values(test_dataset):
    """Exercise sort_values on a base column, a nested column, and mixed layers."""

    # Base layer: column "a" of the sorted frame must be in ascending order
    base_sorted = test_dataset.sort_values(by="a")
    observed_a = base_sorted["a"].values.compute().tolist()
    expected_a = sorted(test_dataset["a"].values.compute().tolist())
    assert observed_a == expected_a

    # Nested layer: flux values of the row labeled 0 must be in descending order
    nested_sorted = test_dataset.sort_values(by="nested.flux", ascending=False)
    observed_flux = nested_sorted.compute().loc[0]["nested"]["flux"].values.tolist()
    expected_flux = sorted(
        test_dataset.loc[0]["nested.flux"].values.compute().tolist(),
        reverse=True,
    )
    assert observed_flux == expected_flux
    assert nested_sorted.known_divisions  # Divisions should be known

    # Mixing a base column with a nested column must raise
    with pytest.raises(ValueError):
        test_dataset.sort_values(by=["a", "nested.flux"])
320+
321+
302322
def test_reduce(test_dataset):
303323
"""test the reduce function"""
304324

0 commit comments

Comments
 (0)