Skip to content

Commit 5ac7548

Browse files
committed
added some code to test reported issues
Signed-off-by: Jules Damji <[email protected]>
1 parent 76cd84d commit 5ac7548

File tree

6 files changed

+95
-25
lines changed

6 files changed

+95
-25
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
import mlflow
3+
4+
QADict = {'What is your full name?': 'My name is King George.',
5+
'Have you taken your medication?': "I don't think so."}
6+
7+
8+
class convEngine(mlflow.pyfunc.PythonModel):
9+
10+
def __init__(self, QADict, model):
11+
self.model = model
12+
self.Qlist = list(QADict.keys())
13+
self.Qanswers = self.embed(self.Qlist)
14+
self.QADict = QADict
15+
16+
def embed(self, input):
17+
return self.model(input)
18+
19+
def predict(self, model_input_string):
20+
qasked = self.embed([model_input_string])
21+
corr = np.inner(qasked, self.Qanswers)
22+
index_max = np.argmax(corr)
23+
24+
return self.QADict[self.Qlist[index_max]]
25+
26+
27+
if __name__ == '__main__':
28+
29+
import tensorflow_hub as hub
30+
import cloudpickle
31+
32+
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
33+
model = hub.load(module_url)
34+
model_path = "/tmp/tf_model"
35+
36+
with open(model_path, "w+") as f:
37+
cloudpickle.dump(model, f)
38+
"""
39+
convModel = convEngine(QADict, model)
40+
mlflow.pyfunc.save_model(path=model_path, python_model=convModel)
41+
"""
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import mlflow.sklearn
2+
from sklearn.datasets import load_iris
3+
from sklearn import tree
4+
5+
if __name__ == '__main__':
6+
7+
iris = load_iris()
8+
sk_model = tree.DecisionTreeClassifier()
9+
sk_model = sk_model.fit(iris.data, iris.target)
10+
11+
sk_path = 'saved_models'
12+
mlflow.set_tracking_uri('sqlite:///mlruns.db')
13+
with mlflow.start_run() as run:
14+
mlflow.log_metric('m', 1.5)
15+
mlflow.sklearn.save_model(sk_model, sk_path)
16+
17+
# Register model
18+
model_uri = "runs:/{}/model".format(run.info.run_id)
19+
registered_model_name = "RegisterSavedModel"
20+
mv = mlflow.register_model(model_uri, registered_model_name)
21+
print("Name: {}".format(mv.name))
22+
print("Version: {}".format(mv.version))

pytorch/mlflow/pytorch_autolog_mlflow_example.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,36 +53,43 @@ def forward(self, x):
5353

5454
return x
5555

56+
@staticmethod
57+
def add_model_specific_args(parent_parser):
5658

57-
if __name__ == "__main__":
58-
parser = ArgumentParser(description="PyTorch Autolog Mnist Example")
59+
parser = ArgumentParser(parents=[parent_parser], add_help=False)
5960

60-
# Add trainer specific arguments
61+
# Add trainer specific arguments
62+
parser.add_argument(
63+
"--tracking_uri", type=str, default="http://localhost:5000/", help="mlflow tracking uri"
64+
)
65+
parser.add_argument(
66+
"--max_epochs", type=int, default=20, help="number of epochs to run (default: 20)"
67+
)
68+
parser.add_argument(
69+
"--gpus", type=int, default=0, help="Number of gpus - by default runs on CPU"
70+
)
71+
parser.add_argument(
72+
"--accelerator",
73+
type=str,
74+
default=None,
75+
help="accelerator - (default: None)",
76+
)
77+
return parser
6178

62-
parser.add_argument(
63-
"--tracking_uri", type=str, default="http://localhost:5000/", help="mlflow tracking uri"
64-
)
65-
parser.add_argument(
66-
"--max_epochs", type=int, default=20, help="number of epochs to run (default: 20)"
67-
)
68-
parser.add_argument(
69-
"--gpus", type=int, default=0, help="Number of gpus - by default runs on CPU"
70-
)
71-
parser.add_argument(
72-
"--accelerator",
73-
type=str,
74-
default=None,
75-
help="accelerator - (default: None)",
76-
)
77-
parser = LightningMNISTClassifier.add_model_specific_args(parent_parser=parser)
79+
80+
if __name__ == "__main__":
81+
parent_parser = ArgumentParser(description="PyTorch Autolog Mnist Example")
82+
83+
parser = LightningMNISTClassifier.add_model_specific_args(parent_parser=parent_parser)
7884

7985
mlflow.pytorch.autolog() # just add this line and your Autologging should work!
8086

8187
args = parser.parse_args()
8288

8389
args = parser.parse_args()
8490
dict_args = vars(args)
85-
#mlflow.set_tracking_uri(dict_args['tracking_uri'])
91+
92+
# mlflow.set_tracking_uri(dict_args['tracking_uri'])
8693

8794
model = LightningMNISTClassifier(**dict_args)
8895
early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=True)
@@ -93,15 +100,15 @@ def forward(self, x):
93100
verbose=True,
94101
monitor="val_loss",
95102
mode="min",
96-
prefix="",
103+
prefix=""
97104
)
98105
lr_logger = LearningRateMonitor()
99106

100107
trainer = pl.Trainer.from_argparse_args(
101108
args,
102109
callbacks=[lr_logger, early_stopping],
103110
checkpoint_callback=checkpoint_callback,
104-
train_percent_check=0.1,
111+
# train_percent_check=0.1,
105112
)
106113
trainer.fit(model)
107114
trainer.test()

pytorch/tutorials/13_feed_forward.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@
277277
"name": "python",
278278
"nbconvert_exporter": "python",
279279
"pygments_lexer": "ipython3",
280-
"version": "3.7.5"
280+
"version": "3.7.4"
281281
}
282282
},
283283
"nbformat": 4,

pytorch/tutorials/14_cnn.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@
245245
"name": "python",
246246
"nbconvert_exporter": "python",
247247
"pygments_lexer": "ipython3",
248-
"version": "3.7.5"
248+
"version": "3.7.4"
249249
}
250250
},
251251
"nbformat": 4,

pytorch/tutorials/18_custom_dataset.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@
904904
"name": "python",
905905
"nbconvert_exporter": "python",
906906
"pygments_lexer": "ipython3",
907-
"version": "3.7.5"
907+
"version": "3.7.4"
908908
}
909909
},
910910
"nbformat": 4,

0 commit comments

Comments
 (0)