Skip to content

Commit 8db5b32

Browse files
committed
Wrap model attributes in smart ptrs to avoid memory leaks
Remove memory leaks from #57 and #61 by using shared ptrs on: model::graph model::session And unique ptrs on vars from model constructor: session_options run_options meta_graph
1 parent 306a5c0 commit 8db5b32

File tree

1 file changed

+25
-14
lines changed

1 file changed

+25
-14
lines changed

include/cppflow/model.h

+25-14
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,40 @@ namespace cppflow {
2525
std::vector<tensor> operator()(std::vector<std::tuple<std::string, tensor>> inputs, std::vector<std::string> outputs);
2626
tensor operator()(const tensor& input);
2727

28+
~model() = default;
29+
model(const model &model) = default;
30+
model(model &&model) = default;
31+
model &operator=(const model &other) = default;
32+
model &operator=(model &&other) = default;
33+
2834
private:
2935

30-
TF_Graph* graph;
31-
TF_Session* session;
36+
std::shared_ptr<TF_Graph> graph;
37+
std::shared_ptr<TF_Session> session;
3238
};
3339
}
3440

41+
3542
namespace cppflow {
3643

3744
model::model(const std::string &filename) {
38-
this->graph = TF_NewGraph();
45+
this->graph = {TF_NewGraph(), TF_DeleteGraph};
3946

4047
// Create the session.
41-
TF_SessionOptions* session_options = TF_NewSessionOptions();
42-
TF_Buffer* run_options = TF_NewBufferFromString("", 0);
43-
TF_Buffer* meta_graph = TF_NewBuffer();
48+
std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};
49+
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString("", 0), TF_DeleteBuffer};
50+
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};
51+
52+
auto session_deleter = [](TF_Session* sess) {
53+
TF_DeleteSession(sess, context::get_status());
54+
status_check(context::get_status());
55+
};
4456

4557
int tag_len = 1;
4658
const char* tag = "serve";
47-
this->session = TF_LoadSessionFromSavedModel(session_options, run_options, filename.c_str(), &tag, tag_len, graph, meta_graph, context::get_status());
48-
TF_DeleteSessionOptions(session_options);
49-
TF_DeleteBuffer(run_options);
50-
//TF_DeleteBuffer(meta_graph);
59+
this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(),
60+
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
61+
session_deleter};
5162

5263
status_check(context::get_status());
5364
}
@@ -58,7 +69,7 @@ namespace cppflow {
5869
TF_Operation* oper;
5970

6071
// Iterate through the operations of a graph
61-
while ((oper = TF_GraphNextOperation(this->graph, &pos)) != nullptr) {
72+
while ((oper = TF_GraphNextOperation(this->graph.get(), &pos)) != nullptr) {
6273
result.emplace_back(TF_OperationName(oper));
6374
}
6475
return result;
@@ -77,7 +88,7 @@ namespace cppflow {
7788

7889
// Operations
7990
const auto[op_name, op_idx] = parse_name(std::get<0>(inputs[i]));
80-
inp_ops[i].oper = TF_GraphOperationByName(this->graph, op_name.c_str());
91+
inp_ops[i].oper = TF_GraphOperationByName(this->graph.get(), op_name.c_str());
8192
inp_ops[i].index = op_idx;
8293

8394
if (!inp_ops[i].oper)
@@ -94,15 +105,15 @@ namespace cppflow {
94105
for (int i=0; i<outputs.size(); i++) {
95106

96107
const auto[op_name, op_idx] = parse_name(outputs[i]);
97-
out_ops[i].oper = TF_GraphOperationByName(this->graph, op_name.c_str());
108+
out_ops[i].oper = TF_GraphOperationByName(this->graph.get(), op_name.c_str());
98109
out_ops[i].index = op_idx;
99110

100111
if (!out_ops[i].oper)
101112
throw std::runtime_error("No operation named \"" + op_name + "\" exists");
102113

103114
}
104115

105-
TF_SessionRun(this->session, NULL,
116+
TF_SessionRun(this->session.get(), NULL,
106117
inp_ops.data(), inp_val.data(), inputs.size(),
107118
out_ops.data(), out_val.get(), outputs.size(),
108119
NULL, 0,NULL , context::get_status());

0 commit comments

Comments
 (0)