119 changes: 117 additions & 2 deletions src/nested_dask/core.py
@@ -1,17 +1,20 @@
from __future__ import annotations

import os
from collections.abc import Callable, Mapping
from typing import Any, Literal

import dask.dataframe as dd
import dask.dataframe.dask_expr as dx
import nested_pandas as npd
import numpy as np
import pandas as pd
import pyarrow as pa
from dask.dataframe.dask_expr._collection import new_collection
from nested_pandas.series.dtype import NestedDtype
from nested_pandas.series.packer import pack, pack_flat, pack_lists
from pandas._libs import lib
-from pandas._typing import AnyAll, Axis, IndexLabel
+from pandas._typing import Axis, IndexLabel
from pandas.api.extensions import no_default

# need this for the base _Frame class
@@ -540,7 +543,7 @@ def dropna(
        self,
        *,
        axis: Axis = 0,
-        how: AnyAll | lib.NoDefault = no_default,
+        how: str | lib.NoDefault = no_default,
        thresh: int | lib.NoDefault = no_default,
        on_nested: bool = False,
        subset: IndexLabel | None = None,
@@ -616,6 +619,118 @@ def dropna(
            meta=self._meta,
        )
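A side note on the dropna annotation change above: AnyAll is a private pandas._typing alias for Literal["any", "all"], so widening the parameter to plain str leaves the accepted values unchanged. A minimal sketch of the two modes, using a hypothetical collection ndf (not part of the diff):

# "how" keeps its usual pandas semantics under either annotation.
ndf.dropna(how="any")  # drop a row if any value in it is NA
ndf.dropna(how="all")  # drop a row only if every value in it is NA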

    def sort_values(
        self,
        by: str | list[str],
        npartitions: int | None = None,
        ascending: bool | list[bool] = True,
        na_position: Literal["first"] | Literal["last"] = "last",
        partition_size: float = 128e6,
        sort_function: Callable[[pd.DataFrame], pd.DataFrame] | None = None,
        sort_function_kwargs: Mapping[str, Any] | None = None,
        upsample: float = 1.0,
        ignore_index: bool | None = False,
        shuffle_method: str | None = None,
        **options,
    ) -> Self:  # type: ignore[name-defined] # noqa: F821
"""
Sort the dataset by a single column.

Sorting a parallel dataset requires expensive shuffles and is generally
not recommended. See ‘set_index‘ for implementation details.

Parameters:
-----------
by: str or list[str]
Column(s) to sort by.
npartitions: int, None, or ‘auto’
The ideal number of output partitions. If None, use the same as the
input. If ‘auto’ then decide by memory use. Not used when sorting
nested layers.
ascending: bool or list[bool], optional
Sort ascending vs. descending. Defaults to True. Specify list for
multiple sort orders. If this is a list of bools, must match the
length of the by.
na_position: {‘last’, ‘first’}, optional
Puts NaNs at the beginning if ‘first’, puts NaN at the end if
‘last’. Defaults to ‘last’.
partition_size: float, optional
The desired size of each partition in bytes. Defaults to 128e6
(128 MB). Not used in nested sorting.
sort_function: function, optional
Sorting function to use when sorting underlying partitions. If
None, defaults to M.sort_values (the partition library’s
implementation of sort_values). Not used when sorting nested
layers.
sort_function_kwargs: dict, optional
Additional keyword arguments to pass to the partition sorting
function. By default, by, ascending, and na_position are provided.
upsample: float, optional
Used to increase the number of samples for quantiles. Not used
in nested sorting
ignore_index: bool, optional
If True, the resulting axis will be labeled 0, 1, …, n - 1.
Defaults to False.
shuffle_method: str, optional
The method to use for shuffling data. Defaults to None. Not used
in nested sorting
**options: keyword arguments, optional
Additional options to pass to the sorting function.
Returns:
--------
DataFrame
DataFrame with sorted values.

"""

        # Resolve the target layer
        targets = []
        if isinstance(by, str):
            by = [by]
        # Check "by" columns for hierarchical references
        for col in by:
            if self._is_known_hierarchical_column(col):
                targets.append(col.split(".")[0])
            else:
                targets.append("base")

        # Ensure a single target layer, preventing multi-layer operations
        unq_targets = np.unique(targets).tolist()
        if len(unq_targets) > 1:
            raise ValueError(
                "Sorting cannot target multiple structs/layers; write a separate sort_values call for each"
            )
        target_layer = unq_targets[0]
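        # For example (illustrative values, not part of the diff):
        #   by=["nested.flux", "nested.t"] resolves targets to ["nested", "nested"]: one layer, OK
        #   by=["a", "nested.flux"] resolves targets to ["base", "nested"]: raises the ValueError above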

        # Just use dask's sort_values if the target is the base layer.
        # Drops divisions, but this is expected behavior of a sorting operation.
        if target_layer == "base":
            return super().sort_values(
                by=by,
                npartitions=npartitions,
                ascending=ascending,
                na_position=na_position,
                partition_size=partition_size,
                sort_function=sort_function,
                sort_function_kwargs=sort_function_kwargs,
                upsample=upsample,
                ignore_index=ignore_index,
                shuffle_method=shuffle_method,
                **options,
            )

        # If the target is a nested layer, go through the nested-pandas API.
        # Applied via map_partitions: meta is propagated and divisions are preserved.
        return self.map_partitions(
            lambda x: npd.NestedFrame(x).sort_values(
                by=by,
                ascending=ascending,
                na_position=na_position,
                ignore_index=ignore_index,
                **options,
            ),
            meta=self._meta,
        )

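As a sketch of what the nested branch does (not part of the diff), these two calls reduce to the same work for a hypothetical collection ndf:

ndf.sort_values(by="nested.flux")
ndf.map_partitions(
    lambda x: npd.NestedFrame(x).sort_values(by="nested.flux"),
    meta=ndf._meta,
)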
    def reduce(self, func, *args, meta=None, **kwargs) -> NestedFrame:
        """
        Takes a function and applies it to each top-level row of the NestedFrame.
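Beyond the diff itself, a minimal end-to-end usage sketch of the new method. Assumptions flagged: the nested_dask.NestedFrame.from_pandas constructor and nested-pandas' add_nested are used to build the data, and the column names are hypothetical.

import nested_dask as nd
import nested_pandas as npd
import pandas as pd

# Two base rows; the flat frame's index maps nested rows onto base rows.
base = pd.DataFrame({"a": [2, 1]}, index=[0, 1])
flat = pd.DataFrame({"flux": [1.0, 3.0, 2.0]}, index=[0, 0, 1])
nf = npd.NestedFrame(base).add_nested(flat, "nested")
ndf = nd.NestedFrame.from_pandas(nf, npartitions=2)

# Base-layer sort: delegated to dask's sort_values (shuffle; divisions dropped).
ndf.sort_values(by="a").compute()

# Nested-layer sort: runs per partition via nested-pandas; divisions preserved.
ndf.sort_values(by="nested.flux", ascending=False).compute()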
20 changes: 20 additions & 0 deletions tests/nested_dask/test_nestedframe.py
@@ -299,6 +299,26 @@ def test_dropna(test_dataset_with_nans):
    assert len(flat_nested_nan_free) == len(flat_nested) - 1


def test_sort_values(test_dataset):
    """test the sort_values function"""

    # test sorting on base columns
    sorted_base = test_dataset.sort_values(by="a")
    assert sorted_base["a"].values.compute().tolist() == sorted(test_dataset["a"].values.compute().tolist())

    # test sorting on nested columns
    sorted_nested = test_dataset.sort_values(by="nested.flux", ascending=False)
    assert sorted_nested.compute().iloc[0]["nested"]["flux"].values.tolist() == sorted(
        test_dataset.compute().iloc[0]["nested"]["flux"].values.tolist(),
        reverse=True,
    )
    assert sorted_nested.known_divisions  # Divisions should be known

    # Make sure we trigger the multi-target exception
    with pytest.raises(ValueError):
        test_dataset.sort_values(by=["a", "nested.flux"])


def test_reduce(test_dataset):
    """test the reduce function"""

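For reference, the within-row semantics that the nested-column assertions rely on can be reproduced at the nested-pandas level; a minimal sketch with hypothetical toy data, assuming nested-pandas' add_nested API:

import nested_pandas as npd
import pandas as pd

base = pd.DataFrame({"a": [1, 2]}, index=[0, 1])
flat = pd.DataFrame({"flux": [1.0, 3.0, 2.0]}, index=[0, 0, 1])
nf = npd.NestedFrame(base).add_nested(flat, "nested")

# Each row's packed lists are sorted independently; the base rows stay put.
sorted_nf = nf.sort_values(by="nested.flux", ascending=False)
assert sorted_nf.iloc[0]["nested"]["flux"].tolist() == [3.0, 1.0]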