Skip to content

Commit 239fc42

Browse files
danielmlow and satra authored
cleaner v03: empirical p-value, performance table, README, feature_importance and permutation_importance (#43)
* empirical p-value, performance table, improved README, feature_importance, permutation_importance * added example images, removed example output directories * added example images, removed example output directories * Update README.md Co-authored-by: Satrajit Ghosh <[email protected]> * empirical p-value, performance table, README, feature_importance and permutation_importance * edited with pre-commit * removed clear_locks * added new feature importance arguments to tests/test_classifier.py. * Delete .Rhistory Co-authored-by: danielmlow <danielmlow@> Co-authored-by: Satrajit Ghosh <[email protected]>
1 parent 149b819 commit 239fc42

12 files changed

+536
-87
lines changed

README.md

+158-52
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,41 @@
1+
{"filename": "breast_cancer.csv",
2+
"x_indices": ["radius_mean", "texture_mean","perimeter_mean", "area_mean", "smoothness_mean",
3+
"compactness_mean", "concavity_mean", "concave points_mean",
4+
"symmetry_mean", "fractal_dimension_mean", "radius_se",
5+
"texture_se", "perimeter_se", "area_se", "smoothness_se",
6+
"compactness_se", "concavity_se", "concave points_se",
7+
"symmetry_se", "fractal_dimension_se", "radius_worst",
8+
"texture_worst", "perimeter_worst", "area_worst",
9+
"smoothness_worst", "compactness_worst", "concavity_worst",
10+
"concave points_worst", "symmetry_worst", "fractal_dimension_worst"],
11+
"target_vars": ["target"],
12+
"group_var": null,
13+
"n_splits": 100,
14+
"test_size": 0.2,
15+
"clf_info": [
16+
["sklearn.ensemble", "AdaBoostClassifier"],
17+
["sklearn.naive_bayes", "GaussianNB"],
18+
[ ["sklearn.impute", "SimpleImputer"],
19+
["sklearn.preprocessing", "StandardScaler"],
20+
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}]],
21+
["sklearn.ensemble", "RandomForestClassifier", {"n_estimators": 100}],
22+
["sklearn.ensemble", "ExtraTreesClassifier", {"n_estimators": 100, "class_weight": "balanced"}],
23+
["sklearn.linear_model", "LogisticRegressionCV", {"solver": "liblinear", "penalty": "l1"}],
24+
["sklearn.neural_network", "MLPClassifier", {"alpha": 1, "max_iter": 1000}],
25+
["sklearn.svm", "SVC", {"probability": true},
26+
[{"kernel": ["rbf", "linear"], "C": [1, 10, 100, 1000]}]],
27+
["sklearn.neighbors", "KNeighborsClassifier", {},
28+
[{"n_neighbors": [3, 5, 7, 9, 11, 13, 15, 17, 19],
29+
"weights": ["uniform", "distance"]}]]
30+
],
31+
"permute": [true, false],
32+
"gen_feature_importance": false,
33+
"gen_permutation_importance": false,
34+
"permutation_importance_n_repeats": 5,
35+
"permutation_importance_scoring": "accuracy",
36+
"gen_shap": true,
37+
"nsamples": "auto",
38+
"l1_reg": "aic",
39+
"plot_top_n_shap": 16,
40+
"metrics": ["roc_auc_score", "f1_score", "precision_score", "recall_score"]
41+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,25 @@
1+
{"filename": "breast_cancer.csv",
2+
"x_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
3+
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
4+
"target_vars": ["target"],
5+
"group_var": null,
6+
"n_splits": 3,
7+
"test_size": 0.2,
8+
"clf_info": [
9+
["sklearn.neural_network", "MLPClassifier", {"alpha": 1, "max_iter": 1000}],
10+
[ ["sklearn.impute", "SimpleImputer"],
11+
["sklearn.preprocessing", "StandardScaler"],
12+
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}]
13+
]
14+
],
15+
"permute": [false, true],
16+
"gen_feature_importance": false,
17+
"gen_permutation_importance": false,
18+
"permutation_importance_n_repeats": 5,
19+
"permutation_importance_scoring": "accuracy",
20+
"gen_shap": true,
21+
"nsamples": 100,
22+
"l1_reg": "aic",
23+
"plot_top_n_shap": 16,
24+
"metrics": ["roc_auc_score", "accuracy_score"]
25+
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,22 @@
1+
{"filename": "breast_cancer.csv",
2+
"x_indices": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
3+
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
4+
"target_vars": ["target"],
5+
"group_var": null,
6+
"n_splits": 3,
7+
"test_size": 0.2,
8+
"clf_info": [
9+
["sklearn.svm", "SVC", {"kernel": "linear"}],
10+
["sklearn.linear_model", "LogisticRegression", {"penalty": "l1", "solver":"liblinear"}]
11+
],
12+
"permute": [false, true],
13+
"gen_feature_importance": true,
14+
"gen_permutation_importance": true,
15+
"permutation_importance_n_repeats": 5,
16+
"permutation_importance_scoring": "accuracy",
17+
"gen_shap": true,
18+
"nsamples": 100,
19+
"l1_reg": "aic",
20+
"plot_top_n_shap": 16,
21+
"metrics": ["roc_auc_score", "precision_score", "recall_score"]
22+
}
+22
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,22 @@
1+
{"filename": "diabetes_table.csv",
2+
"x_indices": [0,1,2,3,4,5,6,7,8,9],
3+
"target_vars": ["target"],
4+
"group_var": null,
5+
"n_splits": 4,
6+
"test_size": 0.2,
7+
"clf_info": [
8+
["sklearn.linear_model","RidgeCV",{"fit_intercept": true,"normalize": true}],
9+
["sklearn.linear_model","LassoCV",{"fit_intercept": true,"normalize": true}],
10+
["sklearn.linear_model","ElasticNetCV",{"fit_intercept": true,"normalize": true}]
11+
],
12+
"permute": [true,false],
13+
"gen_feature_importance": false,
14+
"gen_permutation_importance": false,
15+
"permutation_importance_n_repeats": 5,
16+
"permutation_importance_scoring": null,
17+
"gen_shap": true,
18+
"nsamples": 100,
19+
"l1_reg": "aic",
20+
"plot_top_n_shap": 10,
21+
"metrics":["explained_variance_score","mean_squared_error","mean_absolute_error"]
22+
}

examples/shap_example.png

192 KB
Loading
56.6 KB
Loading
Loading

pydra_ml/classifier.py

+37
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,8 @@
1212
calc_metric,
1313
create_model,
1414
gen_splits,
15+
get_feature_importance,
16+
get_permutation_importance,
1517
get_shap,
1618
read_file,
1719
train_test_kernel,
@@ -43,6 +45,14 @@
4345
annotate({"return": {"score": ty.Any, "output": ty.Any}})(calc_metric)
4446
)
4547

48+
get_feature_importance_pdt = task(
49+
annotate({"return": {"feature_importance": ty.Any}})(get_feature_importance)
50+
)
51+
52+
get_permutation_importance_pdt = task(
53+
annotate({"return": {"permutation_importance": ty.Any}})(get_permutation_importance)
54+
)
55+
4656
get_shap_pdt = task(annotate({"return": {"shaps": ty.Any}})(get_shap))
4757

4858
create_model_pdt = task(
@@ -99,6 +109,28 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
99109
)
100110
)
101111
wf.metric.combine("fit_clf.split_index")
112+
wf.add(
113+
get_feature_importance_pdt(
114+
name="feature_importance",
115+
permute=wf.lzin.permute,
116+
model=wf.fit_clf.lzout.model,
117+
gen_feature_importance=wf.lzin.gen_feature_importance,
118+
)
119+
)
120+
wf.feature_importance.combine("fit_clf.split_index")
121+
wf.add(
122+
get_permutation_importance_pdt(
123+
name="permutation_importance",
124+
X=wf.readcsv.lzout.X,
125+
y=wf.readcsv.lzout.Y,
126+
permute=wf.lzin.permute,
127+
model=wf.fit_clf.lzout.model,
128+
permutation_importance_n_repeats=wf.lzin.permutation_importance_n_repeats,
129+
permutation_importance_scoring=wf.lzin.permutation_importance_scoring,
130+
gen_permutation_importance=wf.lzin.gen_permutation_importance,
131+
)
132+
)
133+
wf.permutation_importance.combine("fit_clf.split_index")
102134
wf.add(
103135
get_shap_pdt(
104136
name="shap",
@@ -124,6 +156,11 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
124156
[
125157
("output", wf.metric.lzout.output),
126158
("score", wf.metric.lzout.score),
159+
("feature_importance", wf.feature_importance.lzout.feature_importance),
160+
(
161+
"permutation_importance",
162+
wf.permutation_importance.lzout.permutation_importance,
163+
),
127164
("shaps", wf.shap.lzout.shaps),
128165
("feature_names", wf.readcsv.lzout.feature_names),
129166
("model", wf.create_model.lzout.model),

0 commit comments

Comments
 (0)