1919
2020if TYPE_CHECKING :
2121 from collections .abc import Callable
22+ from typing import Literal
23+
24+ from numpy .typing import NDArray
2225
2326 from scanpy ._compat import CSRBase
2427
@@ -42,73 +45,14 @@ def xfail_dask_median(
4245 adata : ad .AnnData ,
4346 metric : AggType ,
4447 request : pytest .FixtureRequest ,
45- ):
48+ ) -> None :
4649 if isinstance (adata .X , DaskArray ) and metric == "median" :
4750 reason = "Median calculation not implemented for Dask"
4851 request .applymarker (pytest .mark .xfail (reason = reason ))
4952
5053
51- @pytest .fixture
52- def df_base ():
53- ax_base = ["A" , "B" ]
54- return pd .DataFrame (index = ax_base )
55-
56-
57- @pytest .fixture
58- def df_groupby ():
59- ax_groupby = [
60- * ["v0" , "v1" , "v2" ],
61- * ["w0" , "w1" ],
62- * ["a1" , "a2" , "a3" ],
63- * ["b1" , "b2" ],
64- * ["c1" , "c2" ],
65- "d0" ,
66- ]
67-
68- df_groupby = pd .DataFrame (index = pd .Index (ax_groupby , name = "cell" ))
69- df_groupby ["key" ] = pd .Categorical ([c [0 ] for c in ax_groupby ])
70- df_groupby ["key_superset" ] = pd .Categorical ([c [0 ] for c in ax_groupby ]).map ({
71- ** {"v" : "v" , "w" : "v" }, # noqa: PIE800
72- ** {"a" : "a" , "b" : "a" , "c" : "a" , "d" : "a" }, # noqa: PIE800
73- })
74- df_groupby ["key_subset" ] = pd .Categorical ([c [1 ] for c in ax_groupby ])
75- df_groupby ["weight" ] = 2.0
76- return df_groupby
77-
78-
79- @pytest .fixture
80- def x ():
81- data = [
82- * [[0 , - 2 ], [1 , 13 ], [2 , 1 ]], # v
83- * [[3 , 12 ], [4 , 2 ]], # w
84- * [[5 , 11 ], [6 , 3 ], [7 , 10 ]], # a
85- * [[8 , 4 ], [9 , 9 ]], # b
86- * [[10 , 5 ], [11 , 8 ]], # c
87- [12 , 6 ], # d
88- ]
89- return np .array (data , dtype = np .float32 )
90-
91-
92- def gen_adata (data_key , dim , df_base , df_groupby , x ):
93- if (data_key == "varm" and dim == "obs" ) or (data_key == "obsm" and dim == "var" ):
94- pytest .skip ("invalid parameter combination" )
95-
96- obs_df , var_df = (df_groupby , df_base ) if dim == "obs" else (df_base , df_groupby )
97- data = x .T if dim == "var" and data_key != "varm" else x
98- if data_key != "X" :
99- data_dict_sparse = {data_key : {"test" : sparse .csr_matrix (data )}} # noqa: TID251
100- data_dict_dense = {data_key : {"test" : data }}
101- else :
102- data_dict_sparse = {data_key : sparse .csr_matrix (data )} # noqa: TID251
103- data_dict_dense = {data_key : data }
104-
105- adata_sparse = ad .AnnData (obs = obs_df , var = var_df , ** data_dict_sparse )
106- adata_dense = ad .AnnData (obs = obs_df , var = var_df , ** data_dict_dense )
107- return adata_sparse , adata_dense
108-
109-
11054@pytest .mark .parametrize ("axis" , [0 , 1 ])
111- def test_mask (axis ) :
55+ def test_mask (axis : Literal [ 0 , 1 ]) -> None :
11256 blobs = sc .datasets .blobs ()
11357 mask = blobs .obs ["blobs" ] == 0
11458 blobs .obs ["mask_col" ] = mask
@@ -125,7 +69,7 @@ def test_mask(axis):
12569@pytest .mark .parametrize ("array_type" , VALID_ARRAY_TYPES )
12670def test_aggregate_vs_pandas (
12771 metric : AggType , array_type , request : pytest .FixtureRequest
128- ):
72+ ) -> None :
12973 adata = pbmc3k_processed ().raw .to_adata ()
13074 adata = adata [
13175 adata .obs ["louvain" ].isin (adata .obs ["louvain" ].cat .categories [:5 ]), :1_000
@@ -165,7 +109,9 @@ def test_aggregate_vs_pandas(
165109
166110
167111@pytest .mark .parametrize ("array_type" , VALID_ARRAY_TYPES )
168- def test_aggregate_axis (array_type , metric , request : pytest .FixtureRequest ):
112+ def test_aggregate_axis (
113+ array_type , metric : AggType , request : pytest .FixtureRequest
114+ ) -> None :
169115 adata = pbmc3k_processed ().raw .to_adata ()
170116 adata = adata [
171117 adata .obs ["louvain" ].isin (adata .obs ["louvain" ].cat .categories [:5 ]), :1_000
@@ -178,7 +124,7 @@ def test_aggregate_axis(array_type, metric, request: pytest.FixtureRequest):
178124 assert_equal (expected , actual )
179125
180126
181- def test_aggregate_entry ():
127+ def test_aggregate_entry () -> None :
182128 args = ("blobs" , ["mean" , "var" , "count_nonzero" ])
183129
184130 adata = sc .datasets .blobs ()
@@ -213,14 +159,14 @@ def test_aggregate_entry():
213159 assert_equal (x_result .layers , varm_result .T .layers )
214160
215161
def test_aggregate_incorrect_dim() -> None:
    """Check that ``sc.get.aggregate`` rejects an unrecognised ``axis``.

    Calling it with ``axis="foo"`` must raise a ``ValueError`` whose
    message matches ``"was 'foo'"``.
    """
    pbmc = pbmc3k_processed().raw.to_adata()
    group_keys = ["louvain"]

    with pytest.raises(ValueError, match="was 'foo'"):
        sc.get.aggregate(pbmc, group_keys, "sum", axis="foo")
221167
222168
223- def to_bad_chunking (x : CSRBase ):
169+ def to_bad_chunking (x : CSRBase ) -> DaskArray :
224170 import dask .array as da
225171
226172 return da .from_array (
@@ -230,7 +176,7 @@ def to_bad_chunking(x: CSRBase):
230176 )
231177
232178
233- def to_csc (x : CSRBase ):
179+ def to_csc (x : CSRBase ) -> DaskArray :
234180 import dask .array as da
235181
236182 return da .from_array (
@@ -249,15 +195,17 @@ def to_csc(x: CSRBase):
249195 ),
250196 ],
251197)
def test_aggregate_bad_dask_array(
    func: Callable[[CSRBase], DaskArray], error_msg: str
) -> None:
    """Check that aggregating an unsupported Dask-backed ``X`` fails loudly.

    Parameters
    ----------
    func
        Converter applied to ``adata.X``; its result replaces ``adata.X``
        before aggregation.
    error_msg
        Pattern the raised ``ValueError`` message must match.
    """
    pbmc = pbmc3k_processed().raw.to_adata()
    # Swap the matrix for the converted array produced by `func`.
    converted = func(pbmc.X)
    pbmc.X = converted

    with pytest.raises(ValueError, match=error_msg):
        sc.get.aggregate(pbmc, ["louvain"], "sum")
257205
258206
259207@pytest .mark .parametrize ("axis_name" , ["obs" , "var" ])
260- def test_aggregate_axis_specification (axis_name ) :
208+ def test_aggregate_axis_specification (axis_name : Literal [ "obs" , "var" ]) -> None :
261209 axis , axis_name = _resolve_axis (axis_name )
262210 by = "blobs" if axis == 0 else "labels"
263211
@@ -381,17 +329,20 @@ def test_aggregate_axis_specification(axis_name):
381329 ),
382330 ],
383331)
def test_aggregate_examples(
    matrix: NDArray[np.int64],
    df: pd.DataFrame,
    keys: list[str],
    metrics: list[AggType],
    expected: ad.AnnData,
) -> None:
    """Aggregate a hand-written example and compare it with ``expected``.

    An ``AnnData`` is built with ``matrix`` as ``X`` and ``df`` as ``obs``;
    its variables are named ``gene_<column index>``. The object is then
    aggregated by ``keys`` using ``metrics`` and the outcome must equal
    ``expected`` according to ``assert_equal``.
    """
    n_vars = matrix.shape[1]
    var_names = [f"gene_{i}" for i in range(n_vars)]
    adata = ad.AnnData(X=matrix, obs=df, var=pd.DataFrame(index=var_names))

    aggregated = sc.get.aggregate(adata, by=keys, func=metrics)

    assert_equal(expected, aggregated)
396347
397348
@@ -439,7 +390,9 @@ def test_aggregate_examples(matrix, df, keys, metrics, expected):
439390 ),
440391 ],
441392)
442- def test_combine_categories (label_cols , cols , expected ):
393+ def test_combine_categories (
394+ label_cols : dict [str , pd .Categorical ], cols : list [str ], expected : pd .Categorical
395+ ) -> None :
443396 from scanpy .get ._aggregated import _combine_categories
444397
445398 label_df = pd .DataFrame (label_cols )
@@ -462,7 +415,7 @@ def test_combine_categories(label_cols, cols, expected):
462415@pytest .mark .parametrize ("array_type" , VALID_ARRAY_TYPES )
463416def test_aggregate_arraytype (
464417 array_type , metric : AggType , request : pytest .FixtureRequest
465- ):
418+ ) -> None :
466419 adata = pbmc3k_processed ().raw .to_adata ()
467420 adata = adata [
468421 adata .obs ["louvain" ].isin (adata .obs ["louvain" ].cat .categories [:5 ]), :1_000
@@ -476,7 +429,7 @@ def test_aggregate_arraytype(
476429 )
477430
478431
479- def test_aggregate_obsm_varm ():
432+ def test_aggregate_obsm_varm () -> None :
480433 adata_obsm = sc .datasets .blobs ()
481434 adata_obsm .obs ["blobs" ] = adata_obsm .obs ["blobs" ].astype (str )
482435 adata_obsm .obsm ["test" ] = adata_obsm .X [:, ::2 ].copy ()
@@ -502,7 +455,7 @@ def test_aggregate_obsm_varm():
502455 assert_equal (expected_mean .values , result_obsm .layers ["mean" ])
503456
504457
505- def test_aggregate_obsm_labels ():
458+ def test_aggregate_obsm_labels () -> None :
506459 from itertools import chain , repeat
507460
508461 label_counts = [("a" , 5 ), ("b" , 3 ), ("c" , 4 )]
@@ -546,13 +499,13 @@ def test_aggregate_obsm_labels():
546499 assert_equal (expected , result )
547500
548501
def test_dispatch_not_implemented() -> None:
    """Check that aggregating ``adata.X`` directly is not supported.

    Passing ``adata.X`` and a label column (instead of the ``AnnData``
    object and a key) must raise ``NotImplementedError``.
    """
    blobs = sc.datasets.blobs()
    matrix = blobs.X
    labels = blobs.obs["blobs"]

    with pytest.raises(NotImplementedError):
        sc.get.aggregate(matrix, labels, "sum")
553506
554507
555- def test_factors ():
508+ def test_factors () -> None :
556509 from itertools import product
557510
558511 obs = pd .DataFrame (product (range (5 ), repeat = 4 ), columns = list ("abcd" ))
@@ -564,3 +517,27 @@ def test_factors():
564517
565518 res = sc .get .aggregate (adata , by = ["a" , "b" , "c" , "d" ], func = "sum" )
566519 np .testing .assert_equal (res .layers ["sum" ], adata .X )
520+
521+
def test_nan() -> None:
    """Check that cells with a NaN grouping value are left out of the result.

    Two of the six cells have ``cell_type`` set to NaN. After aggregating by
    ``sample_id``, ``patient_type`` and ``cell_type``, only the three groups
    formed by the remaining four cells are expected, with 1, 2 and 1 cells
    respectively.
    """
    # 6x6 matrix with x[i, j] == i + j.
    x = np.arange(6) + np.arange(6)[:, None]
    cell_names = [f"cell{i}" for i in range(x.shape[0])]
    obs = pd.DataFrame(
        {
            "cell_type": [np.nan, np.nan, "B", "C", "B", "B"],
            "sample_id": ["s1", "s1", "s1", "s2", "s2", "s2"],
            "patient_type": ["responder"] * 3 + ["control"] * 3,
        },
        index=cell_names,
    )
    adata = ad.AnnData(x, obs=obs)

    group_keys = ["sample_id", "patient_type", "cell_type"]
    adata_agg = sc.get.aggregate(adata, by=group_keys, func="sum", layer=None)

    expected_groups = ["s1_responder_B", "s2_control_B", "s2_control_C"]
    assert adata_agg.obs.index.tolist() == expected_groups
    assert adata_agg.obs["n_obs_aggregated"].tolist() == [1, 2, 1]
0 commit comments