@@ -272,8 +272,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
272
272
options = NULL ;
273
273
TF_DeleteBuffer (tfbuffer );
274
274
tfbuffer = NULL ;
275
- TF_DeleteStatus (status );
276
- status = NULL ;
277
275
278
276
TF_Output tf_inputs [ninputs ];
279
277
TF_Output tf_outputs [noutputs ];
@@ -306,37 +304,37 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
306
304
noutputs , tf_outputs , // noutputs, outputs
307
305
outputs , // output_names,
308
306
NULL , // opts
309
- "" , // description
307
+ NULL , // description
310
308
status // status
311
309
);
312
- // TODO EAGER
313
- // check status and return error
310
+
311
+ if (TF_GetCode (status ) != TF_OK ) {
312
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
313
+ goto cleanup ;
314
+ }
314
315
315
316
TFE_ContextOptions * context_opts = TFE_NewContextOptions ();
316
317
// TFE_ContextOptionsSetConfig(context_opts, proto, proto_len, status);
317
318
// TFE_ContextOptionsSetAsync(context_opts, 0);
318
- TFE_ContextOptionsSetDevicePlacementPolicy (context_opts , TFE_DEVICE_PLACEMENT_EXPLICIT );
319
+ // TFE_ContextOptionsSetDevicePlacementPolicy(context_opts, TFE_DEVICE_PLACEMENT_EXPLICIT);
319
320
320
321
TFE_Context * context = TFE_NewContext (context_opts , status );
321
- // TODO EAGER
322
- // check status and return error
322
+ if (TF_GetCode (status ) != TF_OK ) {
323
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
324
+ goto cleanup ;
325
+ }
323
326
324
327
TFE_ContextAddFunction (context , function , status );
325
- // TODO EAGER
326
- // check status and return error
328
+ if (TF_GetCode (status ) != TF_OK ) {
329
+ RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (status )));
330
+ goto cleanup ;
331
+ }
327
332
328
333
TFE_DeleteContextOptions (context_opts );
329
- TFE_DeleteContext (context );
330
334
331
- #if 0
332
- TF_Status * optionsStatus = NULL ;
333
- TF_SessionOptions * sessionOptions = NULL ;
334
- TF_Status * sessionStatus = NULL ;
335
- TF_Session * session = NULL ;
336
-
337
- optionsStatus = TF_NewStatus ();
338
- sessionOptions = TF_NewSessionOptions ();
335
+ TF_DeleteStatus (status );
339
336
337
+ #if 0
340
338
// For setting config options in session from the C API see:
341
339
// https://github.com/tensorflow/tensorflow/issues/13853
342
340
// import tensorflow as tf
@@ -391,16 +389,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
391
389
}
392
390
}
393
391
394
- if (TF_GetCode (optionsStatus ) != TF_OK ) {
395
- RAI_SetError (error , RAI_EMODELCONFIGURE , RedisModule_Strdup (TF_Message (optionsStatus )));
396
- goto cleanup ;
397
- }
398
- TF_DeleteStatus (optionsStatus );
399
- optionsStatus = NULL ;
400
-
401
- sessionStatus = TF_NewStatus ();
402
- session = TF_NewSession (graph , sessionOptions , sessionStatus );
403
-
404
392
TF_Status * deviceListStatus = TF_NewStatus ();
405
393
TF_DeviceList * deviceList = TF_SessionListDevices (session , deviceListStatus );
406
394
const int num_devices = TF_DeviceListCount (deviceList );
@@ -426,9 +414,6 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
426
414
RAI_SetError (error , RAI_EMODELCREATE , RedisModule_Strdup (TF_Message (status )));
427
415
goto cleanup ;
428
416
}
429
-
430
- TF_DeleteSessionOptions (sessionOptions );
431
- TF_DeleteStatus (sessionStatus );
432
417
#endif
433
418
434
419
char * * inputs_ = array_new (char * , ninputs );
@@ -468,33 +453,13 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char *devicestr, RAI_Mod
468
453
TF_DeleteBuffer (tfbuffer );
469
454
if (status )
470
455
TF_DeleteStatus (status );
471
- // if (sessionOptions)
472
- // TF_DeleteSessionOptions(sessionOptions);
473
- // if (sessionStatus)
474
- // TF_DeleteStatus(sessionStatus);
475
456
return NULL ;
476
457
}
477
458
478
459
void RAI_ModelFreeTF (RAI_Model * model , RAI_Error * error ) {
479
- TF_Status * status = TF_NewStatus ();
480
- #if 0
481
- TF_CloseSession (model -> session , status );
482
-
483
- if (TF_GetCode (status ) != TF_OK ) {
484
- RAI_SetError (error , RAI_EMODELFREE , RedisModule_Strdup (TF_Message (status )));
485
- return ;
486
- }
487
-
488
- TF_DeleteSession (model -> session , status );
489
- #endif
490
460
TFE_DeleteContext (model -> session );
491
461
model -> session = NULL ;
492
462
493
- // if (TF_GetCode(status) != TF_OK) {
494
- // RAI_SetError(error, RAI_EMODELFREE, RedisModule_Strdup(TF_Message(status)));
495
- // return;
496
- // }
497
-
498
463
TF_DeleteGraph (model -> model );
499
464
model -> model = NULL ;
500
465
@@ -519,10 +484,6 @@ void RAI_ModelFreeTF(RAI_Model *model, RAI_Error *error) {
519
484
if (model -> data ) {
520
485
RedisModule_Free (model -> data );
521
486
}
522
-
523
- #if 0
524
- TF_DeleteStatus (status );
525
- #endif
526
487
}
527
488
528
489
int RAI_ModelRunTF (RAI_ModelRunCtx * * mctxs , RAI_Error * error ) {
@@ -563,24 +524,44 @@ int RAI_ModelRunTF(RAI_ModelRunCtx **mctxs, RAI_Error *error) {
563
524
}
564
525
inputTensorsValues [i ] = RAI_TFTensorFromTensors (batched_input_tensors , nbatches );
565
526
inputTensorsHandles [i ] = TFE_NewTensorHandle (inputTensorsValues [i ], status );
566
- // TODO EAGER
567
- // check status and return error
527
+ if (TF_GetCode (status ) != TF_OK ) {
528
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
529
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
530
+ TF_DeleteStatus (status );
531
+ RedisModule_Free (errorMessage );
532
+ return 1 ;
533
+ }
568
534
}
569
535
570
536
TFE_Op * fn_op = TFE_NewOp (mctxs [0 ]-> model -> session , RAI_TF_FN_NAME , status );
571
- // TODO EAGER
572
- // check status and return error
537
+ if (TF_GetCode (status ) != TF_OK ) {
538
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
539
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
540
+ TF_DeleteStatus (status );
541
+ RedisModule_Free (errorMessage );
542
+ return 1 ;
543
+ }
573
544
574
545
TFE_OpAddInputList (fn_op , inputTensorsHandles , ninputs , status );
575
- // TODO EAGER
576
- // check status and return error
546
+ if (TF_GetCode (status ) != TF_OK ) {
547
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
548
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
549
+ TF_DeleteStatus (status );
550
+ RedisModule_Free (errorMessage );
551
+ return 1 ;
552
+ }
577
553
578
554
// TODO EAGER: send tensors to device (as long as we keep device allocation EXPLICIT)
579
555
580
556
int noutputs_ = noutputs ;
581
557
TFE_Execute (fn_op , outputTensorsHandles , & noutputs_ , status );
582
- // TODO EAGER
583
- // check status and return error
558
+ if (TF_GetCode (status ) != TF_OK ) {
559
+ char * errorMessage = RedisModule_Strdup (TF_Message (status ));
560
+ RAI_SetError (error , RAI_EMODELRUN , errorMessage );
561
+ TF_DeleteStatus (status );
562
+ RedisModule_Free (errorMessage );
563
+ return 1 ;
564
+ }
584
565
585
566
for (size_t i = 0 ; i < ninputs ; ++ i ) {
586
567
TFE_DeleteTensorHandle (inputTensorsHandles [i ]);
0 commit comments