9
9
from mlflow .protos .databricks_pb2 import INVALID_PARAMETER_VALUE , RESOURCE_DOES_NOT_EXIST
10
10
11
11
from . import torchscript
12
+ import mlflow .tensorflow
12
13
13
14
14
15
_logger = logging .getLogger (__name__ )
15
- SUPPORTED_DEPLOYMENT_FLAVORS = [torchscript .FLAVOR_NAME ]
16
+ SUPPORTED_DEPLOYMENT_FLAVORS = [torchscript .FLAVOR_NAME , mlflow . tensorflow . FLAVOR_NAME ]
16
17
17
18
18
- _flavor2backend = {torchscript .FLAVOR_NAME : 'torch' }
19
+ _flavor2backend = {
20
+ torchscript .FLAVOR_NAME : 'torch' ,
21
+ mlflow .tensorflow .FLAVOR_NAME : 'tf' }
19
22
20
23
21
24
def _get_preferred_deployment_flavor (model_config ):
22
25
"""
23
- Obtains the flavor that MLflow would prefer to use when deploying the model.
26
+ Obtains the flavor that MLflow would prefer to use when deploying the model on RedisAI.
24
27
If the model does not contain any supported flavors for deployment, an exception
25
28
will be thrown.
26
29
27
30
:param model_config: An MLflow model object
28
31
:return: The name of the preferred deployment flavor for the specified model
29
32
"""
33
+ # TODO: add onnx & TFlite
30
34
if torchscript .FLAVOR_NAME in model_config .flavors :
31
35
return torchscript .FLAVOR_NAME
36
+ elif mlflow .tensorflow .FLAVOR_NAME in model_config .flavors :
37
+ return mlflow .tensorflow .FLAVOR_NAME
32
38
else :
33
39
raise MlflowException (
34
40
message = (
@@ -70,9 +76,14 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
70
76
"""
71
77
Deploy an MLflow model to RedisAI. User needs to pass the URL and credentials
72
78
to connect to RedisAI server. Currently it accepts only TorchScript model, frozen
73
- Tensorflow model, Tensorflow lite model, ONNX model (any models like scikit-learn,
74
- spark which is converted to ONNX). Note: ml2rt is one of the package we have
75
- developed which can do the conversion from different frameworks to ONNX
79
+ Tensorflow model and SavedModel from Tensorflow through MLflow, although RedisAI
80
+ can take Tensorflow Lite model, ONNX model (any model like scikit-learn or spark
81
+ which is converted to ONNX).
82
+
83
+ Note: ml2rt is a package we have developed which can
84
+ - do the conversion from different frameworks to ONNX
85
+ - load SavedModel, freezed tensorflow, torchscript or ONNX models from disk
86
+ - load script
76
87
77
88
:param model_key: Redis Key on which we deploy the model
78
89
:param model_uri: The location, in URI format, of the MLflow model to deploy to RedisAI.
@@ -94,10 +105,13 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
94
105
flavors. If the specified flavor is not present or not supported for deployment,
95
106
an exception will be thrown.
96
107
:param device: GPU or CPU
108
+ :param kwargs: Parameters for RedisAI connection
109
+
97
110
"""
98
111
model_path = _download_artifact_from_uri (model_uri )
112
+ # TODO: use os.path for python2.x compatibility
99
113
path = Path (model_path )
100
- model_config = path . joinpath ( 'MLmodel' )
114
+ model_config = path / 'MLmodel'
101
115
if not model_config .exists ():
102
116
raise MlflowException (
103
117
message = (
@@ -112,13 +126,19 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
112
126
_validate_deployment_flavor (model_config , flavor )
113
127
_logger .info ("Using the %s flavor for deployment!" , flavor )
114
128
115
- # TODO: Add mode (similar to sagemaker)
116
-
117
129
con = redisai .Client (** kwargs )
118
- model_path = list (path .joinpath ('data' ).iterdir ())[0 ]
119
- if model_path .suffix != '.pt' :
120
- raise RuntimeError ("Model file does not have a valid suffix. Expected .pt" )
121
- model = ml2rt .load_model (model_path )
130
+ if flavor == mlflow .tensorflow .FLAVOR_NAME :
131
+ tags = model_config .flavors [flavor ]['meta_graph_tags' ]
132
+ signaturedef = model_config .flavors [flavor ]['signature_def_key' ]
133
+ model_dir = path / model_config .flavors [flavor ]['saved_model_dir' ]
134
+ model , inputs , outputs = ml2rt .load_model (model_dir , tags , signaturedef )
135
+ else :
136
+ # TODO: this assumes the torchscript is saved using mlflow-redisai
137
+ model_path = list (path .joinpath ('data' ).iterdir ())[0 ]
138
+ if model_path .suffix != '.pt' :
139
+ raise RuntimeError ("Model file does not have a valid suffix. Expected .pt" )
140
+ model = ml2rt .load_model (model_path )
141
+ inputs = outputs = None
122
142
try :
123
143
device = redisai .Device .__members__ [device ]
124
144
except KeyError :
@@ -134,7 +154,7 @@ def deploy(model_key, model_uri, flavor=None, device='cpu', **kwargs):
134
154
),
135
155
error_code = INVALID_PARAMETER_VALUE )
136
156
backend = redisai .Backend .__members__ [backend ]
137
- con .modelset (model_key , backend , device , model )
157
+ con .modelset (model_key , backend , device , model , inputs = inputs , outputs = outputs )
138
158
139
159
140
160
def delete (model_key , ** kwargs ):
0 commit comments