Skip to content

Commit ed1acb9

Browse files
authored
fix: skip nan in sc.get.aggregate (#3906)
1 parent f75ac11 commit ed1acb9

File tree

3 files changed

+65
-83
lines changed

3 files changed

+65
-83
lines changed

docs/release-notes/3906.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix {func}`scanpy.get.aggregate` when a `by` column misses data {smaller}`P Angerer`

src/scanpy/get/_aggregated.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def __init__(
5858
mask: NDArray[np.bool_] | None = None,
5959
) -> None:
6060
self.groupby = groupby
61+
if (missing := groupby.isna()).any():
62+
mask = mask & ~missing if mask is not None else ~missing
6163
self.indicator_matrix = sparse_indicator(groupby, mask=mask)
6264
self.data = data
6365

@@ -531,9 +533,9 @@ def _combine_categories(
531533
code_array[i] = df[c].cat.codes
532534
code_array *= factors[:, None]
533535

534-
result_categorical = pd.Categorical.from_codes(
535-
code_array.sum(axis=0), categories=result_categories
536-
)
536+
codes = code_array.sum(axis=0)
537+
codes = np.where(np.any(code_array < 0, axis=0), -1, codes)
538+
result_categorical = pd.Categorical.from_codes(codes, categories=result_categories)
537539

538540
# Filter unused categories
539541
result_categorical = result_categorical.remove_unused_categories()
@@ -554,8 +556,10 @@ def sparse_indicator(
554556
weight = mask * weight
555557
elif mask is None and weight is None:
556558
weight = np.broadcast_to(1.0, len(categorical))
559+
# can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked
560+
codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0)
557561
a = sparse.coo_matrix(
558-
(weight, (categorical.codes, np.arange(len(categorical)))),
562+
(weight, (codes, np.arange(len(categorical)))),
559563
shape=(len(categorical.categories), len(categorical)),
560564
)
561565
return a

tests/test_aggregated.py

Lines changed: 56 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919

2020
if TYPE_CHECKING:
2121
from collections.abc import Callable
22+
from typing import Literal
23+
24+
from numpy.typing import NDArray
2225

2326
from scanpy._compat import CSRBase
2427

@@ -42,73 +45,14 @@ def xfail_dask_median(
4245
adata: ad.AnnData,
4346
metric: AggType,
4447
request: pytest.FixtureRequest,
45-
):
48+
) -> None:
4649
if isinstance(adata.X, DaskArray) and metric == "median":
4750
reason = "Median calculation not implemented for Dask"
4851
request.applymarker(pytest.mark.xfail(reason=reason))
4952

5053

51-
@pytest.fixture
def df_base():
    """Minimal two-row DataFrame serving as the non-grouped axis."""
    return pd.DataFrame(index=["A", "B"])
55-
56-
57-
@pytest.fixture
def df_groupby():
    """Grouping DataFrame with categorical key columns and a constant weight.

    ``key`` is the first character of each label, ``key_superset`` coarsens it,
    ``key_subset`` refines it via the second character.
    """
    labels = [
        "v0", "v1", "v2",
        "w0", "w1",
        "a1", "a2", "a3",
        "b1", "b2",
        "c1", "c2",
        "d0",
    ]

    df = pd.DataFrame(index=pd.Index(labels, name="cell"))
    first_chars = pd.Categorical([label[0] for label in labels])
    df["key"] = first_chars
    df["key_superset"] = first_chars.map({
        **{"v": "v", "w": "v"},  # noqa: PIE800
        **{"a": "a", "b": "a", "c": "a", "d": "a"},  # noqa: PIE800
    })
    df["key_subset"] = pd.Categorical([label[1] for label in labels])
    df["weight"] = 2.0
    return df
77-
78-
79-
@pytest.fixture
def x():
    """Dense float32 data matrix whose row order matches ``df_groupby``."""
    rows = [
        [0, -2], [1, 13], [2, 1],  # v
        [3, 12], [4, 2],  # w
        [5, 11], [6, 3], [7, 10],  # a
        [8, 4], [9, 9],  # b
        [10, 5], [11, 8],  # c
        [12, 6],  # d
    ]
    return np.array(rows, dtype=np.float32)
90-
91-
92-
def gen_adata(data_key, dim, df_base, df_groupby, x):
    """Build a ``(sparse, dense)`` pair of AnnData objects for the given layout.

    ``data_key`` selects where the matrix is stored (``X``/``obsm``/``varm``/…),
    ``dim`` selects which axis carries the grouping DataFrame. Skips the test
    when the combination makes no sense (e.g. ``varm`` data grouped by ``obs``).
    """
    invalid_combo = (data_key == "varm" and dim == "obs") or (
        data_key == "obsm" and dim == "var"
    )
    if invalid_combo:
        pytest.skip("invalid parameter combination")

    if dim == "obs":
        obs_df, var_df = df_groupby, df_base
    else:
        obs_df, var_df = df_base, df_groupby
    # obsm/X layouts along var need the transposed matrix; varm keeps it as-is.
    data = x.T if dim == "var" and data_key != "varm" else x
    sparse_data = sparse.csr_matrix(data)  # noqa: TID251

    if data_key == "X":
        kwargs_sparse = {data_key: sparse_data}
        kwargs_dense = {data_key: data}
    else:
        kwargs_sparse = {data_key: {"test": sparse_data}}
        kwargs_dense = {data_key: {"test": data}}

    return (
        ad.AnnData(obs=obs_df, var=var_df, **kwargs_sparse),
        ad.AnnData(obs=obs_df, var=var_df, **kwargs_dense),
    )
108-
109-
11054
@pytest.mark.parametrize("axis", [0, 1])
111-
def test_mask(axis):
55+
def test_mask(axis: Literal[0, 1]) -> None:
11256
blobs = sc.datasets.blobs()
11357
mask = blobs.obs["blobs"] == 0
11458
blobs.obs["mask_col"] = mask
@@ -125,7 +69,7 @@ def test_mask(axis):
12569
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
12670
def test_aggregate_vs_pandas(
12771
metric: AggType, array_type, request: pytest.FixtureRequest
128-
):
72+
) -> None:
12973
adata = pbmc3k_processed().raw.to_adata()
13074
adata = adata[
13175
adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000
@@ -165,7 +109,9 @@ def test_aggregate_vs_pandas(
165109

166110

167111
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
168-
def test_aggregate_axis(array_type, metric, request: pytest.FixtureRequest):
112+
def test_aggregate_axis(
113+
array_type, metric: AggType, request: pytest.FixtureRequest
114+
) -> None:
169115
adata = pbmc3k_processed().raw.to_adata()
170116
adata = adata[
171117
adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000
@@ -178,7 +124,7 @@ def test_aggregate_axis(array_type, metric, request: pytest.FixtureRequest):
178124
assert_equal(expected, actual)
179125

180126

181-
def test_aggregate_entry():
127+
def test_aggregate_entry() -> None:
182128
args = ("blobs", ["mean", "var", "count_nonzero"])
183129

184130
adata = sc.datasets.blobs()
@@ -213,14 +159,14 @@ def test_aggregate_entry():
213159
assert_equal(x_result.layers, varm_result.T.layers)
214160

215161

216-
def test_aggregate_incorrect_dim():
162+
def test_aggregate_incorrect_dim() -> None:
217163
adata = pbmc3k_processed().raw.to_adata()
218164

219165
with pytest.raises(ValueError, match="was 'foo'"):
220166
sc.get.aggregate(adata, ["louvain"], "sum", axis="foo")
221167

222168

223-
def to_bad_chunking(x: CSRBase):
169+
def to_bad_chunking(x: CSRBase) -> DaskArray:
224170
import dask.array as da
225171

226172
return da.from_array(
@@ -230,7 +176,7 @@ def to_bad_chunking(x: CSRBase):
230176
)
231177

232178

233-
def to_csc(x: CSRBase):
179+
def to_csc(x: CSRBase) -> DaskArray:
234180
import dask.array as da
235181

236182
return da.from_array(
@@ -249,15 +195,17 @@ def to_csc(x: CSRBase):
249195
),
250196
],
251197
)
252-
def test_aggregate_bad_dask_array(func: Callable[[CSRBase], DaskArray], error_msg: str):
198+
def test_aggregate_bad_dask_array(
199+
func: Callable[[CSRBase], DaskArray], error_msg: str
200+
) -> None:
253201
adata = pbmc3k_processed().raw.to_adata()
254202
adata.X = func(adata.X)
255203
with pytest.raises(ValueError, match=error_msg):
256204
sc.get.aggregate(adata, ["louvain"], "sum")
257205

258206

259207
@pytest.mark.parametrize("axis_name", ["obs", "var"])
260-
def test_aggregate_axis_specification(axis_name):
208+
def test_aggregate_axis_specification(axis_name: Literal["obs", "var"]) -> None:
261209
axis, axis_name = _resolve_axis(axis_name)
262210
by = "blobs" if axis == 0 else "labels"
263211

@@ -381,17 +329,20 @@ def test_aggregate_axis_specification(axis_name):
381329
),
382330
],
383331
)
384-
def test_aggregate_examples(matrix, df, keys, metrics, expected):
332+
def test_aggregate_examples(
333+
matrix: NDArray[np.int64],
334+
df: pd.DataFrame,
335+
keys: list[str],
336+
metrics: list[AggType],
337+
expected: ad.AnnData,
338+
) -> None:
385339
adata = ad.AnnData(
386340
X=matrix,
387341
obs=df,
388342
var=pd.DataFrame(index=[f"gene_{i}" for i in range(matrix.shape[1])]),
389343
)
390344
result = sc.get.aggregate(adata, by=keys, func=metrics)
391345

392-
print(result)
393-
print(expected)
394-
395346
assert_equal(expected, result)
396347

397348

@@ -439,7 +390,9 @@ def test_aggregate_examples(matrix, df, keys, metrics, expected):
439390
),
440391
],
441392
)
442-
def test_combine_categories(label_cols, cols, expected):
393+
def test_combine_categories(
394+
label_cols: dict[str, pd.Categorical], cols: list[str], expected: pd.Categorical
395+
) -> None:
443396
from scanpy.get._aggregated import _combine_categories
444397

445398
label_df = pd.DataFrame(label_cols)
@@ -462,7 +415,7 @@ def test_combine_categories(label_cols, cols, expected):
462415
@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES)
463416
def test_aggregate_arraytype(
464417
array_type, metric: AggType, request: pytest.FixtureRequest
465-
):
418+
) -> None:
466419
adata = pbmc3k_processed().raw.to_adata()
467420
adata = adata[
468421
adata.obs["louvain"].isin(adata.obs["louvain"].cat.categories[:5]), :1_000
@@ -476,7 +429,7 @@ def test_aggregate_arraytype(
476429
)
477430

478431

479-
def test_aggregate_obsm_varm():
432+
def test_aggregate_obsm_varm() -> None:
480433
adata_obsm = sc.datasets.blobs()
481434
adata_obsm.obs["blobs"] = adata_obsm.obs["blobs"].astype(str)
482435
adata_obsm.obsm["test"] = adata_obsm.X[:, ::2].copy()
@@ -502,7 +455,7 @@ def test_aggregate_obsm_varm():
502455
assert_equal(expected_mean.values, result_obsm.layers["mean"])
503456

504457

505-
def test_aggregate_obsm_labels():
458+
def test_aggregate_obsm_labels() -> None:
506459
from itertools import chain, repeat
507460

508461
label_counts = [("a", 5), ("b", 3), ("c", 4)]
@@ -546,13 +499,13 @@ def test_aggregate_obsm_labels():
546499
assert_equal(expected, result)
547500

548501

549-
def test_dispatch_not_implemented():
502+
def test_dispatch_not_implemented() -> None:
550503
adata = sc.datasets.blobs()
551504
with pytest.raises(NotImplementedError):
552505
sc.get.aggregate(adata.X, adata.obs["blobs"], "sum")
553506

554507

555-
def test_factors():
508+
def test_factors() -> None:
556509
from itertools import product
557510

558511
obs = pd.DataFrame(product(range(5), repeat=4), columns=list("abcd"))
@@ -564,3 +517,27 @@ def test_factors():
564517

565518
res = sc.get.aggregate(adata, by=["a", "b", "c", "d"], func="sum")
566519
np.testing.assert_equal(res.layers["sum"], adata.X)
520+
521+
522+
def test_nan() -> None:
    """Rows whose ``by`` column is NaN are dropped from the aggregation."""
    n_cells = 6
    x = np.arange(n_cells) + np.arange(n_cells)[:, None]
    obs = pd.DataFrame(
        {
            "cell_type": [np.nan, np.nan, "B", "C", "B", "B"],
            "sample_id": ["s1", "s1", "s1", "s2", "s2", "s2"],
            "patient_type": ["responder"] * 3 + ["control"] * 3,
        },
        index=[f"cell{i}" for i in range(n_cells)],
    )
    adata = ad.AnnData(x, obs=obs)

    agg = sc.get.aggregate(
        adata, by=["sample_id", "patient_type", "cell_type"], func="sum", layer=None
    )

    # The two NaN cell_type rows (cell0, cell1) must not contribute to any group.
    expected_groups = [
        "s1_responder_B",
        "s2_control_B",
        "s2_control_C",
    ]
    assert agg.obs.index.tolist() == expected_groups
    assert agg.obs["n_obs_aggregated"].tolist() == [1, 2, 1]

0 commit comments

Comments
 (0)