Skip to content

Commit 5833902

Browse files
authored
Merge pull request #33 from satra/master
enh: add trained model saving
2 parents 4078bf7 + 68b2953 commit 5833902

File tree

4 files changed

+87
-2
lines changed

4 files changed

+87
-2
lines changed

README.md

+14-1
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,19 @@ Each model contains:
171171
number of predictions and F the different SHAP values for each feature.
172172
`shaps` is empty if `gen_shap` is set to `false` or if `permute` is set
173173
to `true`.
174+
- `model`: A pickled version of the model trained on all the input data.
175+
One can use this model to test on new data that has the exact same input
176+
shape and features as the trained model. For example:
177+
```python
178+
import pickle as pk
179+
import numpy as np
180+
with open("results-20201208T010313.229190.pkl", "rb") as fp:
181+
data = pk.load(fp)
182+
trained_model = data[0][1].output.model
183+
trained_model.predict(np.random.rand(1, 30))
184+
```
185+
Please check the value of `data[N][0]` to ensure that you are not using
186+
a permuted model.
174187
- One figure per metric with performance distribution across splits (with or
175188
without null distribution trained on permuted labels)
176189
- One figure for each metric with the word `score` in its name, reporting the results of
@@ -202,7 +215,7 @@ The actual numeric values are stored in a correspondingly named pkl file.
202215
## Debugging
203216

204217
You will need to understand a bit of pydra to know how to debug this application for
205-
now. If the process crashes, the easiest way to restart is to remove the `cache-wf`
218+
now. If the process crashes, the easiest way to restart is to remove the `cache-wf`
206219
folder first. However, if you are rerunning, you could also remove any `.lock` file
207220
in the `cache-wf` directory.
208221

pydra_ml/classifier.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
from pydra.utils.messenger import AuditFlag, FileMessenger
66
import typing as ty
77
import os
8-
from .tasks import read_file, gen_splits, train_test_kernel, calc_metric, get_shap
8+
from .tasks import (
9+
read_file,
10+
gen_splits,
11+
train_test_kernel,
12+
calc_metric,
13+
get_shap,
14+
create_model,
15+
)
916
from .report import gen_report
1017

1118
# Create pydra tasks
@@ -36,6 +43,10 @@
3643

3744
get_shap_pdt = task(annotate({"return": {"shaps": ty.Any}})(get_shap))
3845

46+
create_model_pdt = task(
47+
annotate({"return": {"output": ty.Any, "model": ty.Any}})(create_model)
48+
)
49+
3950

4051
def gen_workflow(inputs, cache_dir=None, cache_locations=None):
4152
wf = pydra.Workflow(
@@ -98,12 +109,22 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
98109
)
99110
)
100111
wf.shap.combine("fit_clf.split_index")
112+
wf.add(
113+
create_model_pdt(
114+
name="create_model",
115+
X=wf.readcsv.lzout.X,
116+
y=wf.readcsv.lzout.Y,
117+
clf_info=wf.lzin.clf_info,
118+
permute=wf.lzin.permute,
119+
)
120+
)
101121
wf.set_output(
102122
[
103123
("output", wf.metric.lzout.output),
104124
("score", wf.metric.lzout.score),
105125
("shaps", wf.shap.lzout.shaps),
106126
("feature_names", wf.readcsv.lzout.feature_names),
127+
("model", wf.create_model.lzout.model),
107128
]
108129
)
109130
return wf

pydra_ml/tasks.py

+46
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,49 @@ def get_shap(X, permute, model, gen_shap=False, nsamples="auto", l1_reg="aic"):
134134
explainer = shap.KernelExplainer(pipe.predict, shap.kmeans(X[train_index], 5))
135135
shaps = explainer.shap_values(X[test_index], nsamples=nsamples, l1_reg=l1_reg)
136136
return shaps
137+
138+
139+
def create_model(X, y, clf_info, permute):
    """Train a model with all the data

    :param X: Input features
    :param y: Target variables; flattened with ``ravel`` before fitting
    :param clf_info: how to construct the classifier. Either one spec of the
        form ``[module_name, class_name, params, param_grid]`` (the last two
        entries optional), or a list of such specs that is assembled into a
        pipeline with one step per spec
    :param permute: whether to run it in permuted mode or not. When true the
        model is fit on a random permutation of the targets
    :return: ``((y, predicted), pipe)`` -- the flattened (unpermuted) targets
        together with the fitted pipeline's predictions on ``X``, followed by
        the fitted pipeline itself
    """
    from importlib import import_module

    import numpy as np
    from sklearn.pipeline import Pipeline

    def to_instance(spec):
        """Instantiate the estimator described by a single spec."""
        estimator_cls = getattr(import_module(spec[0]), spec[1])
        params = spec[2] if len(spec) > 2 else {}
        clf = estimator_cls(**params)
        if len(spec) == 4:
            # A fourth entry is a parameter grid: wrap in a grid search.
            from sklearn.model_selection import GridSearchCV

            clf = GridSearchCV(clf, param_grid=spec[3])
        return clf

    if isinstance(clf_info[0], list):
        # Process as a pipeline constructor: each step is named after its class.
        pipe = Pipeline([(spec[1], to_instance(spec)) for spec in clf_info])
    else:
        # A single estimator is preceded by a standardization step.
        from sklearn.preprocessing import StandardScaler

        pipe = Pipeline(
            [("std", StandardScaler()), (clf_info[1], to_instance(clf_info))]
        )

    y = y.ravel()
    if permute:
        # Fit against shuffled targets; the returned y stays unpermuted.
        pipe.fit(X, y[np.random.permutation(len(y))])
    else:
        pipe.fit(X, y)
    predicted = pipe.predict(X)
    return (y, predicted), pipe

pydra_ml/tests/test_classifier.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from ..classifier import gen_workflow, run_workflow
3+
import numpy as np
34

45

56
def test_classifier(tmpdir):
@@ -32,6 +33,8 @@ def test_classifier(tmpdir):
3233
assert results[0][0]["ml_wf.clf_info"][1] == "MLPClassifier"
3334
assert results[0][0]["ml_wf.permute"]
3435
assert results[0][1].output.score[0][0] < results[1][1].output.score[0][0]
36+
assert hasattr(results[2][1].output.model, "predict")
37+
assert isinstance(results[2][1].output.model.predict(np.ones((1, 30))), np.ndarray)
3538

3639

3740
def test_regressor(tmpdir):
@@ -69,3 +72,5 @@ def test_regressor(tmpdir):
6972
assert results[0][0]["ml_wf.clf_info"][-1][1] == "MLPRegressor"
7073
assert results[0][0]["ml_wf.permute"]
7174
assert results[0][1].output.score[0][0] < results[1][1].output.score[0][0]
75+
assert hasattr(results[2][1].output.model, "predict")
76+
assert isinstance(results[2][1].output.model.predict(np.ones((1, 10))), np.ndarray)

0 commit comments

Comments
 (0)