@@ -39,7 +39,7 @@ def _fit(
39
39
quick_scale : bool = None ,
40
40
close_session = True ,
41
41
dtype = "float64"
42
- ) -> glm .typing .InputDataBaseTyping :
42
+ ) -> glm .typing .InputDataBase :
43
43
"""
44
44
:param noise_model: str, noise model to use in model-based unit_test. Possible options:
45
45
@@ -186,7 +186,7 @@ def _fit(
186
186
187
187
188
188
def lrt (
189
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
189
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
190
190
full_formula_loc : str ,
191
191
reduced_formula_loc : str ,
192
192
full_formula_scale : str = "~1" ,
@@ -370,7 +370,7 @@ def lrt(
370
370
371
371
372
372
def wald (
373
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
373
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
374
374
factor_loc_totest : Union [str , List [str ]] = None ,
375
375
coef_to_test : Union [str , List [str ]] = None ,
376
376
formula_loc : Union [None , str ] = None ,
@@ -547,7 +547,7 @@ def wald(
547
547
if isinstance (as_numeric , str ):
548
548
as_numeric = [as_numeric ]
549
549
550
- # # Parse input data formats:
550
+ # Parse input data formats:
551
551
gene_names = parse_gene_names (data , gene_names )
552
552
if dmat_loc is None and dmat_scale is None :
553
553
sample_description = parse_sample_description (data , sample_description )
@@ -644,7 +644,7 @@ def wald(
644
644
645
645
646
646
def t_test (
647
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
647
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
648
648
grouping ,
649
649
gene_names : Union [np .ndarray , list ] = None ,
650
650
sample_description : pd .DataFrame = None ,
@@ -686,7 +686,7 @@ def t_test(
686
686
687
687
688
688
def rank_test (
689
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
689
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
690
690
grouping : Union [str , np .ndarray , list ],
691
691
gene_names : Union [np .ndarray , list ] = None ,
692
692
sample_description : pd .DataFrame = None ,
@@ -728,7 +728,7 @@ def rank_test(
728
728
729
729
730
730
def two_sample (
731
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
731
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
732
732
grouping : Union [str , np .ndarray , list ],
733
733
as_numeric : Union [List [str ], Tuple [str ], str ] = (),
734
734
test : str = "t-test" ,
@@ -819,8 +819,8 @@ def two_sample(
819
819
:param kwargs: [Debugging] Additional arguments will be passed to the _fit method.
820
820
"""
821
821
if test in ['t-test' , 'rank' ] and noise_model is not None :
822
- raise ValueError ( 'base. two_sample(): Do not specify `noise_model` if using test t-test or rank_test: ' +
823
- 'The t-test is based on a gaussian noise model and wilcoxon is model free.' )
822
+ raise Warning ( ' two_sample(): Do not specify `noise_model` if using test t-test or rank_test: ' +
823
+ 'The t-test is based on a gaussian noise model and the rank sum test is model free.' )
824
824
825
825
gene_names = parse_gene_names (data , gene_names )
826
826
grouping = parse_grouping (data , sample_description , grouping )
@@ -848,6 +848,8 @@ def two_sample(
848
848
sample_description = sample_description ,
849
849
noise_model = noise_model ,
850
850
size_factors = size_factors ,
851
+ init_a = "closed_form" ,
852
+ init_b = "closed_form" ,
851
853
batch_size = batch_size ,
852
854
training_strategy = training_strategy ,
853
855
quick_scale = quick_scale ,
@@ -872,6 +874,8 @@ def two_sample(
872
874
sample_description = sample_description ,
873
875
noise_model = noise_model ,
874
876
size_factors = size_factors ,
877
+ init_a = "closed_form" ,
878
+ init_b = "closed_form" ,
875
879
batch_size = batch_size ,
876
880
training_strategy = training_strategy ,
877
881
quick_scale = quick_scale ,
@@ -883,16 +887,14 @@ def two_sample(
883
887
data = data ,
884
888
gene_names = gene_names ,
885
889
grouping = grouping ,
886
- is_sig_zerovar = is_sig_zerovar ,
887
- dtype = dtype
890
+ is_sig_zerovar = is_sig_zerovar
888
891
)
889
892
elif test .lower () == 'rank' :
890
893
de_test = rank_test (
891
894
data = data ,
892
895
gene_names = gene_names ,
893
896
grouping = grouping ,
894
- is_sig_zerovar = is_sig_zerovar ,
895
- dtype = dtype
897
+ is_sig_zerovar = is_sig_zerovar
896
898
)
897
899
else :
898
900
raise ValueError ('two_sample(): Parameter `test="%s"` not recognized.' % test )
@@ -901,19 +903,19 @@ def two_sample(
901
903
902
904
903
905
def pairwise (
904
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
906
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
905
907
grouping : Union [str , np .ndarray , list ],
906
908
as_numeric : Union [List [str ], Tuple [str ], str ] = (),
907
- test : str = ' z-test' ,
908
- lazy : bool = False ,
909
+ test : str = " z-test" ,
910
+ lazy : bool = True ,
909
911
gene_names : Union [np .ndarray , list ] = None ,
910
912
sample_description : pd .DataFrame = None ,
911
- noise_model : str = None ,
913
+ noise_model : str = "nb" ,
912
914
size_factors : np .ndarray = None ,
913
915
batch_size : int = None ,
914
916
training_strategy : Union [str , List [Dict [str , object ]], Callable ] = "AUTO" ,
915
917
is_sig_zerovar : bool = True ,
916
- quick_scale : bool = None ,
918
+ quick_scale : bool = False ,
917
919
dtype = "float64" ,
918
920
pval_correction : str = "global" ,
919
921
keep_full_test_objs : bool = False ,
@@ -1036,6 +1038,8 @@ def pairwise(
1036
1038
design_scale = dmat ,
1037
1039
gene_names = gene_names ,
1038
1040
size_factors = size_factors ,
1041
+ init_a = "closed_form" ,
1042
+ init_b = "closed_form" ,
1039
1043
batch_size = batch_size ,
1040
1044
training_strategy = training_strategy ,
1041
1045
quick_scale = quick_scale ,
@@ -1058,6 +1062,10 @@ def pairwise(
1058
1062
correction_type = pval_correction
1059
1063
)
1060
1064
else :
1065
+ if isinstance (data , anndata .AnnData ) or isinstance (data , anndata .Raw ):
1066
+ data = data .X
1067
+ elif isinstance (data , glm .typing .InputDataBase ):
1068
+ data = data .x
1061
1069
groups = np .unique (grouping )
1062
1070
pvals = np .tile (np .NaN , [len (groups ), len (groups ), data .shape [1 ]])
1063
1071
pvals [np .eye (pvals .shape [0 ]).astype (bool )] = 0
@@ -1073,16 +1081,19 @@ def pairwise(
1073
1081
for j , g2 in enumerate (groups [(i + 1 ):]):
1074
1082
j = j + i + 1
1075
1083
1076
- sel = (grouping == g1 ) | (grouping == g2 )
1084
+ idx = np .where (np .logical_or (
1085
+ grouping == g1 ,
1086
+ grouping == g2
1087
+ ))[0 ]
1077
1088
de_test_temp = two_sample (
1078
- data = data [sel ],
1079
- grouping = grouping [sel ],
1089
+ data = data [idx , : ],
1090
+ grouping = grouping [idx ],
1080
1091
as_numeric = as_numeric ,
1081
1092
test = test ,
1082
1093
gene_names = gene_names ,
1083
- sample_description = sample_description .iloc [sel ],
1094
+ sample_description = sample_description .iloc [idx , : ],
1084
1095
noise_model = noise_model ,
1085
- size_factors = size_factors [sel ] if size_factors is not None else None ,
1096
+ size_factors = size_factors [idx ] if size_factors is not None else None ,
1086
1097
batch_size = batch_size ,
1087
1098
training_strategy = training_strategy ,
1088
1099
quick_scale = quick_scale ,
@@ -1093,7 +1104,7 @@ def pairwise(
1093
1104
pvals [i , j ] = de_test_temp .pval
1094
1105
pvals [j , i ] = pvals [i , j ]
1095
1106
logfc [i , j ] = de_test_temp .log_fold_change ()
1096
- logfc [j , i ] = - logfc [i , j ]
1107
+ logfc [j , i ] = - logfc [i , j ]
1097
1108
if keep_full_test_objs :
1098
1109
tests [i , j ] = de_test_temp
1099
1110
tests [j , i ] = de_test_temp
@@ -1112,7 +1123,7 @@ def pairwise(
1112
1123
1113
1124
1114
1125
def versus_rest (
1115
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
1126
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
1116
1127
grouping : Union [str , np .ndarray , list ],
1117
1128
as_numeric : Union [List [str ], Tuple [str ], str ] = (),
1118
1129
test : str = 'wald' ,
@@ -1274,7 +1285,7 @@ def versus_rest(
1274
1285
1275
1286
1276
1287
def partition (
1277
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
1288
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
1278
1289
parts : Union [str , np .ndarray , list ],
1279
1290
gene_names : Union [np .ndarray , list ] = None ,
1280
1291
sample_description : pd .DataFrame = None
@@ -1317,7 +1328,7 @@ class _Partition:
1317
1328
1318
1329
def __init__ (
1319
1330
self ,
1320
- data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBaseTyping ],
1331
+ data : Union [anndata .AnnData , Raw , np .ndarray , scipy .sparse .csr_matrix , glm .typing .InputDataBase ],
1321
1332
parts : Union [str , np .ndarray , list ],
1322
1333
gene_names : Union [np .ndarray , list ] = None ,
1323
1334
sample_description : pd .DataFrame = None
@@ -1332,7 +1343,7 @@ def __init__(
1332
1343
:param gene_names: optional list/array of gene names which will be used if `data` does not implicitly store these
1333
1344
:param sample_description: optional pandas.DataFrame containing sample annotations
1334
1345
"""
1335
- if isinstance (data , glm .typing .InputDataBaseTyping ):
1346
+ if isinstance (data , glm .typing .InputDataBase ):
1336
1347
self .x = data .x
1337
1348
elif isinstance (data , anndata .AnnData ) or isinstance (data , Raw ):
1338
1349
self .x = data .X
0 commit comments