Skip to content

Commit 5c1f387

Browse files
author
hhsecond
committed
SavedModel support
1 parent e319220 commit 5c1f387

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

mlflow_redisai/__init__.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,32 @@
99
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST
1010

1111
from . import torchscript
12+
import mlflow.tensorflow
1213

1314

1415
_logger = logging.getLogger(__name__)
15-
SUPPORTED_DEPLOYMENT_FLAVORS = [torchscript.FLAVOR_NAME]
16+
SUPPORTED_DEPLOYMENT_FLAVORS = [torchscript.FLAVOR_NAME, mlflow.tensorflow.FLAVOR_NAME]
1617

1718

18-
_flavor2backend = {torchscript.FLAVOR_NAME: 'torch'}
19+
_flavor2backend = {
20+
torchscript.FLAVOR_NAME: 'torch',
21+
mlflow.tensorflow.FLAVOR_NAME: 'tf'}
1922

2023

2124
def _get_preferred_deployment_flavor(model_config):
2225
"""
23-
Obtains the flavor that MLflow would prefer to use when deploying the model.
26+
Obtains the flavor that MLflow would prefer to use when deploying the model on RedisAI.
2427
If the model does not contain any supported flavors for deployment, an exception
2528
will be thrown.
2629
2730
:param model_config: An MLflow model object
2831
:return: The name of the preferred deployment flavor for the specified model
2932
"""
33+
# TODO: add onnx & TFlite
3034
if torchscript.FLAVOR_NAME in model_config.flavors:
3135
return torchscript.FLAVOR_NAME
36+
elif mlflow.tensorflow.FLAVOR_NAME in model_config.flavors:
37+
return mlflow.tensorflow.FLAVOR_NAME
3238
else:
3339
raise MlflowException(
3440
message=(
@@ -70,9 +76,14 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
7076
"""
7177
Deploy an MLFlow model to RedisAI. User needs to pass the URL and credentials
7278
to connect to RedisAI server. Currently it accepts only TorchScript model, freezed
73-
Tensorflow model, Tensorflow lite model, ONNX model (any models like scikit-learn,
74-
spark which is converted to ONNX). Note: ml2rt is one of the package we have
75-
developed which can do the conversion from different frameworks to ONNX
79+
Tensorflow model and SavedModel from tensorflow through MLFlow although RedisAI
80+
can takes Tensorflow lite model, ONNX model (any models like scikit-learn, spark
81+
which is converted to ONNX).
82+
83+
Note: ml2rt is a package we have developed which can
84+
- do the conversion from different frameworks to ONNX
85+
- load SavedModel, freezed tensorflow, torchscript or ONNX models from disk
86+
- load script
7687
7788
:param model_key: Redis Key on which we deploy the model
7889
:param model_uri: The location, in URI format, of the MLflow model to deploy to RedisAI.
@@ -94,10 +105,13 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
94105
flavors. If the specified flavor is not present or not supported for deployment,
95106
an exception will be thrown.
96107
:param device: GPU or CPU
108+
:param kwargs: Parameters for RedisAI connection
109+
97110
"""
98111
model_path = _download_artifact_from_uri(model_uri)
112+
# TODO: use os.path for python2.x compatiblity
99113
path = Path(model_path)
100-
model_config = path.joinpath('MLmodel')
114+
model_config = path/'MLmodel'
101115
if not model_config.exists():
102116
raise MlflowException(
103117
message=(
@@ -112,13 +126,19 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
112126
_validate_deployment_flavor(model_config, flavor)
113127
_logger.info("Using the %s flavor for deployment!", flavor)
114128

115-
# TODO: Add mode (similar to sagemaker)
116-
117129
con = redisai.Client(**kwargs)
118-
model_path = list(path.joinpath('data').iterdir())[0]
119-
if model_path.suffix != '.pt':
120-
raise RuntimeError("Model file does not have a valid suffix. Expected .pt")
121-
model = ml2rt.load_model(model_path)
130+
if flavor == mlflow.tensorflow.FLAVOR_NAME:
131+
tags = model_config.flavors[flavor]['meta_graph_tags']
132+
signaturedef = model_config.flavors[flavor]['signature_def_key']
133+
model_dir = path/model_config.flavors[flavor]['saved_model_dir']
134+
model, inputs, outputs = ml2rt.load_model(model_dir, tags, signaturedef)
135+
else:
136+
# TODO: this assumes the torchscript is saved using mlflow-redisai
137+
model_path = list(path.joinpath('data').iterdir())[0]
138+
if model_path.suffix != '.pt':
139+
raise RuntimeError("Model file does not have a valid suffix. Expected .pt")
140+
model = ml2rt.load_model(model_path)
141+
inputs = outputs = None
122142
try:
123143
device = redisai.Device.__members__[device]
124144
except KeyError:
@@ -134,7 +154,7 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
134154
),
135155
error_code=INVALID_PARAMETER_VALUE)
136156
backend = redisai.Backend.__members__[backend]
137-
con.modelset(model_key, backend, device, model)
157+
con.modelset(model_key, backend, device, model, inputs=inputs, outputs=outputs)
138158

139159

140160
def delete(model_key, **kwargs):

mlflow_redisai/cli.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ def delete(model_key, host, port):
3737
" from the model's available flavors.".format(
3838
supported_flavors=mlflow_redisai.SUPPORTED_DEPLOYMENT_FLAVORS)))
3939
def deploy(model_key, model_uri, host, port, device, flavor):
40-
# TODO: add note about how to save the model because it doesn't accept
41-
# all MLFlow models
40+
# TODO: add note about how to save the model because it doesn't accept all MLFlow models
4241
"""
4342
Deploy MLFlow models on RedisAI
4443
"""
@@ -48,4 +47,4 @@ def deploy(model_key, model_uri, host, port, device, flavor):
4847
host=host,
4948
port=port,
5049
device=device,
51-
flavor=flavor)
50+
flavor=flavor)

0 commit comments

Comments
 (0)