diff --git a/include/cppflow/model.h b/include/cppflow/model.h index d3a8d3e..95dc9a7 100644 --- a/include/cppflow/model.h +++ b/include/cppflow/model.h @@ -28,6 +28,7 @@ namespace cppflow { explicit model(const std::string& filename, const TYPE type=TYPE::SAVED_MODEL); std::vector get_operations() const; + std::string get_tensor_mapping(const std::string& operation) const; std::vector get_operation_shape(const std::string& operation) const; std::vector operator()(std::vector> inputs, std::vector outputs); @@ -45,6 +46,8 @@ namespace cppflow { std::shared_ptr status; std::shared_ptr graph; std::shared_ptr session; + + std::vector tensor_mapping; }; } @@ -72,6 +75,23 @@ namespace cppflow { this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(), &tag, tag_len, this->graph.get(), meta_graph.get(), this->status.get()), session_deleter}; + + std::string_view sv((char*)meta_graph.get()->data, meta_graph.get()->length); + + std::string output_name = "StatefulPartitionedCall:0"; + std::size_t tensor_mapping_end = sv.npos; + while ((tensor_mapping_end = sv.find(output_name)) != sv.npos) + { + // found at position tensor_mapping_end, now reverse search for beginning + auto tensor_mapping_start = sv.rfind("\x0A", tensor_mapping_end - 4); + if (tensor_mapping_start != sv.npos) + { + tensor_mapping.emplace_back(sv.substr(tensor_mapping_start + 2 /* bytes used for identifying start */, + tensor_mapping_end - tensor_mapping_start - 4 /* byte spacing between the tensor name and output mapping */ - 2)); + + output_name.back()++; // Increment output index; + } + } } else if (type == TYPE::FROZEN_GRAPH) { this->session = {TF_NewSession(this->graph.get(), session_options.get(), this->status.get()), session_deleter}; @@ -106,11 +126,29 @@ namespace cppflow { return result; } + inline std::tuple parse_name(const std::string& name) { + auto idx = name.find(':'); + return (idx == std::string::npos ? std::make_tuple(name, 0) : std::make_tuple(name.substr(0, idx), std::stoi(name.substr(idx + 1)))); + } + + inline std::string model::get_tensor_mapping(const std::string& operation) const { + std::string output_name; + auto it = std::find(tensor_mapping.begin(), tensor_mapping.end(), operation); + if (it != tensor_mapping.end()) + { + output_name = "StatefulPartitionedCall:0"; + output_name.back() += std::distance(tensor_mapping.begin(), it); + } + + return output_name; + } + inline std::vector model::get_operation_shape(const std::string& operation) const { // Get operation by the name TF_Output out_op; - out_op.oper = TF_GraphOperationByName(this->graph.get(), operation.c_str()); - out_op.index = 0; + const auto [op_name, op_idx] = parse_name(operation); + out_op.oper = TF_GraphOperationByName(this->graph.get(), op_name.c_str()); + out_op.index = op_idx; std::vector shape; @@ -143,11 +181,6 @@ namespace cppflow { return shape; } - inline std::tuple parse_name(const std::string& name) { - auto idx = name.find(':'); - return (idx == std::string::npos ? std::make_tuple(name, 0) : std::make_tuple(name.substr(0, idx), std::stoi(name.substr(idx + 1)))); - } - inline std::vector model::operator()(std::vector> inputs, std::vector outputs) { std::vector inp_ops(inputs.size());