@@ -25,29 +25,40 @@ namespace cppflow {
25
25
std::vector<tensor> operator ()(std::vector<std::tuple<std::string, tensor>> inputs, std::vector<std::string> outputs);
26
26
tensor operator ()(const tensor& input);
27
27
28
+ ~model () = default ;
29
+ model (const model &model) = default ;
30
+ model (model &&model) = default ;
31
+ model &operator =(const model &other) = default ;
32
+ model &operator =(model &&other) = default ;
33
+
28
34
private:
29
35
30
- TF_Graph* graph;
31
- TF_Session* session;
36
+ std::shared_ptr< TF_Graph> graph;
37
+ std::shared_ptr< TF_Session> session;
32
38
};
33
39
}
34
40
41
+
35
42
namespace cppflow {
36
43
37
44
model::model (const std::string &filename) {
38
- this ->graph = TF_NewGraph ();
45
+ this ->graph = { TF_NewGraph (), TF_DeleteGraph} ;
39
46
40
47
// Create the session.
41
- TF_SessionOptions* session_options = TF_NewSessionOptions ();
42
- TF_Buffer* run_options = TF_NewBufferFromString (" " , 0 );
43
- TF_Buffer* meta_graph = TF_NewBuffer ();
48
+ std::unique_ptr<TF_SessionOptions, decltype (&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions (), TF_DeleteSessionOptions};
49
+ std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString (" " , 0 ), TF_DeleteBuffer};
50
+ std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer (), TF_DeleteBuffer};
51
+
52
+ auto session_deleter = [](TF_Session* sess) {
53
+ TF_DeleteSession (sess, context::get_status ());
54
+ status_check (context::get_status ());
55
+ };
44
56
45
57
int tag_len = 1 ;
46
58
const char * tag = " serve" ;
47
- this ->session = TF_LoadSessionFromSavedModel (session_options, run_options, filename.c_str (), &tag, tag_len, graph, meta_graph, context::get_status ());
48
- TF_DeleteSessionOptions (session_options);
49
- TF_DeleteBuffer (run_options);
50
- // TF_DeleteBuffer(meta_graph);
59
+ this ->session = {TF_LoadSessionFromSavedModel (session_options.get (), run_options.get (), filename.c_str (),
60
+ &tag, tag_len, this ->graph .get (), meta_graph.get (), context::get_status ()),
61
+ session_deleter};
51
62
52
63
status_check (context::get_status ());
53
64
}
@@ -58,7 +69,7 @@ namespace cppflow {
58
69
TF_Operation* oper;
59
70
60
71
// Iterate through the operations of a graph
61
- while ((oper = TF_GraphNextOperation (this ->graph , &pos)) != nullptr ) {
72
+ while ((oper = TF_GraphNextOperation (this ->graph . get () , &pos)) != nullptr ) {
62
73
result.emplace_back (TF_OperationName (oper));
63
74
}
64
75
return result;
@@ -77,7 +88,7 @@ namespace cppflow {
77
88
78
89
// Operations
79
90
const auto [op_name, op_idx] = parse_name (std::get<0 >(inputs[i]));
80
- inp_ops[i].oper = TF_GraphOperationByName (this ->graph , op_name.c_str ());
91
+ inp_ops[i].oper = TF_GraphOperationByName (this ->graph . get () , op_name.c_str ());
81
92
inp_ops[i].index = op_idx;
82
93
83
94
if (!inp_ops[i].oper )
@@ -94,15 +105,15 @@ namespace cppflow {
94
105
for (int i=0 ; i<outputs.size (); i++) {
95
106
96
107
const auto [op_name, op_idx] = parse_name (outputs[i]);
97
- out_ops[i].oper = TF_GraphOperationByName (this ->graph , op_name.c_str ());
108
+ out_ops[i].oper = TF_GraphOperationByName (this ->graph . get () , op_name.c_str ());
98
109
out_ops[i].index = op_idx;
99
110
100
111
if (!out_ops[i].oper )
101
112
throw std::runtime_error (" No operation named \" " + op_name + " \" exists" );
102
113
103
114
}
104
115
105
- TF_SessionRun (this ->session , NULL ,
116
+ TF_SessionRun (this ->session . get () , NULL ,
106
117
inp_ops.data (), inp_val.data (), inputs.size (),
107
118
out_ops.data (), out_val.get (), outputs.size (),
108
119
NULL , 0 ,NULL , context::get_status ());
0 commit comments