Skip to content

Commit e8abb85

Browse files
authored
Add serialization of data types (#72)
* Add serialization of data types * Add AofRewrite (it might go in the future) and replication * Saving/loading tensors to/from RDB now working * Fix loading scripts * Handle encver properly * Use zero initialization for the error struct * Fix for script serialization. Add tests. * Fix length of input and output arrays * Fix arguments to calloc * Use getkeys-api * Use array_new instead of plain array * Disable AOF for now * Improve testing of unhappy paths * More unhappy path tests * Temporarily stop when AOF is activated
1 parent c2e9f10 commit e8abb85

File tree

19 files changed

+901
-124
lines changed

19 files changed

+901
-124
lines changed

docs/commands.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,17 @@ AI.MODELSET resnet18 TORCH GPU < foo.pt
6565
AI.MODELSET resnet18 TF CPU INPUTS in1 OUTPUTS linear4 < foo.pt
6666
```
6767

68+
## AI.MODELGET - get a model
69+
70+
```sql
71+
AI.MODELGET model_key
72+
```
73+
74+
* model_key - key for the model
75+
76+
> The command returns the model as serialized by the backend (i.e. a string containing a protobuf)
77+
78+
6879
## AI.MODELRUN - run a model
6980
```sql
7081
AI.MODELRUN model_key INPUTS input_key1 ... OUTPUTS output_key1 ...

examples/models/load_model.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ REDIS_CLI=../../deps/redis/src/redis-cli
22

33
echo "SET MODEL"
44
$REDIS_CLI -x AI.MODELSET foo TF GPU INPUTS a b OUTPUTS mul < graph.pb
5+
$REDIS_CLI AI.MODELGET foo
56

67
echo "SET TENSORS"
78
$REDIS_CLI AI.TENSORSET a FLOAT 2 VALUES 2 3

examples/models/load_torch_model.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ REDIS_CLI=../../deps/redis/src/redis-cli
22

33
echo "SET MODEL"
44
$REDIS_CLI -x AI.MODELSET foo TORCH CPU < pt-minimal.pt
5+
$REDIS_CLI AI.MODELGET foo
56

67
echo "SET TENSORS"
78
$REDIS_CLI AI.TENSORSET a FLOAT 2 VALUES 2 3

src/backends/tensorflow.c

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor) {
7777
size_t ndims = TF_NumDims(tensor);
7878

7979
int64_t* shape = RedisModule_Calloc(ndims, sizeof(*shape));
80+
int64_t* strides = RedisModule_Calloc(ndims, sizeof(*strides));
8081
for (long i = 0 ; i < ndims ; ++i){
8182
shape[i] = TF_Dim(tensor, i);
83+
strides[i] = 1;
8284
}
8385

8486
// FIXME: In TF, RunSession allocates memory for output tensors
@@ -106,7 +108,7 @@ RAI_Tensor* RAI_TensorCreateFromTFTensor(TF_Tensor *tensor) {
106108
.ndim = ndims,
107109
.dtype = RAI_GetDLDataTypeFromTF(TF_TensorType(tensor)),
108110
.shape = shape,
109-
.strides = NULL,
111+
.strides = strides,
110112
.byte_offset = 0
111113
},
112114
.manager_ctx = NULL,
@@ -222,6 +224,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, RAI_Device device,
222224
ret->model = model;
223225
ret->session = session;
224226
ret->backend = backend;
227+
ret->device = device;
225228
ret->inputs = inputs_;
226229
ret->outputs = outputs_;
227230
ret->refCount = 1;
@@ -282,7 +285,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) {
282285
port.oper = TF_GraphOperationByName(mctx->model->model, mctx->inputs[i].name);
283286
port.index = 0;
284287
if(port.oper == NULL){
285-
return 0;
288+
return 1;
286289
}
287290
inputs[i] = port;
288291
}
@@ -292,7 +295,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) {
292295
port.oper = TF_GraphOperationByName(mctx->model->model, mctx->outputs[i].name);
293296
port.index = 0;
294297
if(port.oper == NULL){
295-
return 0;
298+
return 1;
296299
}
297300
outputs[i] = port;
298301
}
@@ -307,7 +310,7 @@ int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) {
307310
if (TF_GetCode(status) != TF_OK) {
308311
RAI_SetError(error, RAI_EMODELRUN, RedisModule_Strdup(TF_Message(status)));
309312
TF_DeleteStatus(status);
310-
return 0;
313+
return 1;
311314
}
312315

313316
for(size_t i = 0 ; i < array_len(mctx->outputs) ; ++i) {
@@ -317,5 +320,28 @@ int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error) {
317320

318321
TF_DeleteStatus(status);
319322

320-
return 1;
323+
return 0;
324+
}
325+
326+
int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
327+
TF_Buffer *tf_buffer = TF_NewBuffer();
328+
TF_Status *status = TF_NewStatus();
329+
330+
TF_GraphToGraphDef(model->model, tf_buffer, status);
331+
332+
if (TF_GetCode(status) != TF_OK) {
333+
RAI_SetError(error, RAI_EMODELSERIALIZE, "Error serializing TF model");
334+
TF_DeleteBuffer(tf_buffer);
335+
TF_DeleteStatus(status);
336+
return 1;
337+
}
338+
339+
*buffer = RedisModule_Alloc(tf_buffer->length);
340+
memcpy(*buffer, tf_buffer->data, tf_buffer->length);
341+
*len = tf_buffer->length;
342+
343+
TF_DeleteBuffer(tf_buffer);
344+
TF_DeleteStatus(status);
345+
346+
return 0;
321347
}

src/backends/tensorflow.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, RAI_Device device,
1818
const char *modeldef, size_t modellen,
1919
RAI_Error *error);
2020

21-
void RAI_ModelFreeTF(RAI_Model* model, RAI_Error *error);
21+
void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error);
2222

23-
int RAI_ModelRunTF(RAI_ModelRunCtx* mctx, RAI_Error *error);
23+
int RAI_ModelRunTF(RAI_ModelRunCtx *mctx, RAI_Error *error);
24+
25+
int RAI_ModelSerializeTF(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error);
2426

2527
#endif /* SRC_BACKENDS_TENSORFLOW_H_ */

src/backends/torch.c

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, RAI_Device device,
99
const char *modeldef, size_t modellen,
10-
RAI_Error *err) {
10+
RAI_Error *error) {
1111
DLDeviceType dl_device;
1212
switch (device) {
1313
case RAI_DEVICE_CPU:
@@ -17,55 +17,59 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, RAI_Device device,
1717
dl_device = kDLGPU;
1818
break;
1919
default:
20-
RAI_SetError(err, RAI_EMODELCONFIGURE, "Error configuring model: unsupported device\n");
20+
RAI_SetError(error, RAI_EMODELCONFIGURE, "Error configuring model: unsupported device\n");
2121
return NULL;
2222
}
2323

24-
char* err_descr = NULL;
25-
void* model = torchLoadModel(modeldef, modellen, dl_device, &err_descr);
24+
char* error_descr = NULL;
25+
void* model = torchLoadModel(modeldef, modellen, dl_device, &error_descr);
2626

2727
if (model == NULL) {
28-
RAI_SetError(err, RAI_EMODELCREATE, err_descr);
29-
free(err_descr);
28+
RAI_SetError(error, RAI_EMODELCREATE, error_descr);
29+
free(error_descr);
3030
return NULL;
3131
}
3232

3333
RAI_Model* ret = RedisModule_Calloc(1, sizeof(*ret));
3434
ret->model = model;
3535
ret->session = NULL;
3636
ret->backend = backend;
37+
ret->device = device;
3738
ret->inputs = NULL;
3839
ret->outputs = NULL;
3940
ret->refCount = 1;
4041

4142
return ret;
4243
}
4344

44-
void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error *err) {
45+
void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error *error) {
4546
torchDeallocContext(model->model);
4647
}
4748

48-
int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error *err) {
49+
int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error *error) {
50+
51+
size_t ninputs = array_len(mctx->inputs);
52+
size_t noutputs = array_len(mctx->outputs);
4953

50-
DLManagedTensor** inputs = RedisModule_Calloc(1, sizeof(*inputs));
51-
DLManagedTensor** outputs = RedisModule_Calloc(1, sizeof(*outputs));
54+
DLManagedTensor** inputs = RedisModule_Calloc(ninputs, sizeof(*inputs));
55+
DLManagedTensor** outputs = RedisModule_Calloc(noutputs, sizeof(*outputs));
5256

53-
for (size_t i=0 ; i<array_len(mctx->inputs); ++i) {
57+
for (size_t i=0 ; i<ninputs; ++i) {
5458
inputs[i] = &mctx->inputs[i].tensor->tensor;
5559
}
5660

57-
for (size_t i=0 ; i<array_len(mctx->outputs); ++i) {
61+
for (size_t i=0 ; i<noutputs; ++i) {
5862
outputs[i] = &mctx->outputs[i].tensor->tensor;
5963
}
6064

61-
char* err_descr = NULL;
65+
char* error_descr = NULL;
6266
torchRunModel(mctx->model->model,
63-
array_len(mctx->inputs), inputs,
64-
array_len(mctx->outputs), outputs, &err_descr);
67+
ninputs, inputs,
68+
noutputs, outputs, &error_descr);
6569

66-
if (err_descr != NULL) {
67-
RAI_SetError(err, RAI_EMODELRUN, err_descr);
68-
free(err_descr);
70+
if (error_descr != NULL) {
71+
RAI_SetError(error, RAI_EMODELRUN, error_descr);
72+
free(error_descr);
6973
return 1;
7074
}
7175

@@ -77,11 +81,20 @@ int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error *err) {
7781
return 0;
7882
}
7983

80-
RAI_Script *RAI_ScriptCreateTorch(RAI_Device device, const char *scriptdef, RAI_Error *err) {
81-
size_t scriptlen = strlen(scriptdef);
82-
char* scriptdef_ = RedisModule_Calloc(scriptlen, sizeof(char));
83-
memcpy(scriptdef_, scriptdef, scriptlen);
84+
int RAI_ModelSerializeTorch(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error) {
85+
char* error_descr = NULL;
86+
torchSerializeModel(model->model, buffer, len, &error_descr);
87+
88+
if (*buffer == NULL) {
89+
RAI_SetError(error, RAI_EMODELSERIALIZE, error_descr);
90+
free(error_descr);
91+
return 1;
92+
}
93+
94+
return 0;
95+
}
8496

97+
RAI_Script *RAI_ScriptCreateTorch(RAI_Device device, const char *scriptdef, RAI_Error *error) {
8598
DLDeviceType dl_device;
8699
switch (device) {
87100
case RAI_DEVICE_CPU:
@@ -91,36 +104,36 @@ RAI_Script *RAI_ScriptCreateTorch(RAI_Device device, const char *scriptdef, RAI_
91104
dl_device = kDLGPU;
92105
break;
93106
default:
94-
RAI_SetError(err, RAI_ESCRIPTCONFIGURE, "Error configuring script: unsupported device\n");
107+
RAI_SetError(error, RAI_ESCRIPTCONFIGURE, "Error configuring script: unsupported device\n");
95108
break;
96109
}
97110

98-
char* err_descr = NULL;
99-
void* script = torchCompileScript(scriptdef, dl_device, &err_descr);
111+
char* error_descr = NULL;
112+
void* script = torchCompileScript(scriptdef, dl_device, &error_descr);
100113

101114
if (script == NULL) {
102-
RAI_SetError(err, RAI_ESCRIPTCREATE, err_descr);
103-
free(err_descr);
115+
RAI_SetError(error, RAI_ESCRIPTCREATE, error_descr);
116+
free(error_descr);
104117
return NULL;
105118
}
106119

107120
RAI_Script* ret = RedisModule_Calloc(1, sizeof(*ret));
108121
ret->script = script;
109-
ret->scriptdef = scriptdef_;
122+
ret->scriptdef = RedisModule_Strdup(scriptdef);
110123
ret->device = device;
111124
ret->refCount = 1;
112125

113126
return ret;
114127
}
115128

116-
void RAI_ScriptFreeTorch(RAI_Script* script, RAI_Error* err) {
129+
void RAI_ScriptFreeTorch(RAI_Script* script, RAI_Error* error) {
117130

118131
torchDeallocContext(script->script);
119132
RedisModule_Free(script->scriptdef);
120133
RedisModule_Free(script);
121134
}
122135

123-
int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* err) {
136+
int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* error) {
124137

125138
long nInputs = array_len(sctx->inputs);
126139
long nOutputs = array_len(sctx->outputs);
@@ -136,13 +149,13 @@ int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* err) {
136149
outputs[i] = &sctx->outputs[i].tensor->tensor;
137150
}
138151

139-
char* err_descr = NULL;
140-
torchRunScript(sctx->script->script, sctx->fnname, nInputs, inputs, nOutputs, outputs, &err_descr);
152+
char* error_descr = NULL;
153+
torchRunScript(sctx->script->script, sctx->fnname, nInputs, inputs, nOutputs, outputs, &error_descr);
141154

142-
if (err_descr) {
155+
if (error_descr) {
143156
printf("F\n");
144-
RAI_SetError(err, RAI_ESCRIPTRUN, err_descr);
145-
free(err_descr);
157+
RAI_SetError(error, RAI_ESCRIPTRUN, error_descr);
158+
free(error_descr);
146159
return 1;
147160
}
148161

src/backends/torch.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@ RAI_Model *RAI_ModelCreateTorch(RAI_Backend backend, RAI_Device device,
1313
const char *modeldef, size_t modellen,
1414
RAI_Error *err);
1515

16-
void RAI_ModelFreeTorch(RAI_Model* model, RAI_Error* err);
16+
void RAI_ModelFreeTorch(RAI_Model *model, RAI_Error *error);
1717

18-
int RAI_ModelRunTorch(RAI_ModelRunCtx* mctx, RAI_Error* err);
18+
int RAI_ModelRunTorch(RAI_ModelRunCtx *mctx, RAI_Error *error);
19+
20+
int RAI_ModelSerializeTorch(RAI_Model *model, char **buffer, size_t *len, RAI_Error *error);
1921

2022
RAI_Script *RAI_ScriptCreateTorch(RAI_Device device, const char *scriptdef,
21-
RAI_Error *err);
23+
RAI_Error *error);
2224

23-
void RAI_ScriptFreeTorch(RAI_Script* script, RAI_Error* err);
25+
void RAI_ScriptFreeTorch(RAI_Script *script, RAI_Error *error);
2426

25-
int RAI_ScriptRunTorch(RAI_ScriptRunCtx* sctx, RAI_Error* err);
27+
int RAI_ScriptRunTorch(RAI_ScriptRunCtx *sctx, RAI_Error *error);
2628

2729
#endif /* SRC_BACKENDS_TORCH_H_ */

src/config.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,12 @@ typedef enum {
1717
RAI_DEVICE_GPU,
1818
} RAI_Device;
1919

20+
#define RAI_ENC_VER 100
21+
2022
//#define RAI_COPY_RUN_INPUT
2123
#define RAI_COPY_RUN_OUTPUT
2224
#define RAI_PRINT_BACKEND_ERRORS
2325

26+
// #define RAI_OVERRIDE_AOF_CHECK
27+
2428
#endif /* SRC_CONFIG_H_ */

src/err.c

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,6 @@ void RAI_SetError(RAI_Error *err, RAI_ErrorCode code, const char *detail) {
3232
err->detail_oneline = RAI_Chomp(err->detail);
3333
}
3434

35-
RAI_Error RAI_InitError() {
36-
RAI_Error err = {.code = RAI_OK, .detail = NULL, .detail_oneline = NULL};
37-
return err;
38-
}
39-
4035
void RAI_ClearError(RAI_Error *err) {
4136
if (err->detail) {
4237
RedisModule_Free(err->detail);

src/err.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ typedef enum {
77
RAI_EMODELCONFIGURE,
88
RAI_EMODELCREATE,
99
RAI_EMODELRUN,
10+
RAI_EMODELSERIALIZE,
1011
RAI_EMODELFREE,
1112
RAI_ESCRIPTIMPORT,
1213
RAI_ESCRIPTCONFIGURE,
@@ -22,8 +23,6 @@ typedef struct RAI_Error {
2223
char* detail_oneline;
2324
} RAI_Error;
2425

25-
RAI_Error RAI_InitError();
26-
2726
void RAI_SetError(RAI_Error *err, RAI_ErrorCode code, const char *detail);
2827

2928
void RAI_ClearError(RAI_Error *err);

0 commit comments

Comments
 (0)