diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..2155b30
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,18 @@
+cmake_minimum_required(VERSION 3.30)
+
+project(gdino.sam.lama)
+
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+set(CMAKE_CXX_STANDARD 20)
+set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR}/bin)
+
+include_directories(/opt/homebrew/include /usr/local/include)
+link_directories(/opt/homebrew/lib /usr/local/lib)
+
+find_package(OpenCV REQUIRED)
+include_directories(${OpenCV_INCLUDE_DIRS})
+
+find_package(Boost COMPONENTS program_options REQUIRED)
+include_directories(${Boost_INCLUDE_DIR})
+
+add_subdirectory(src)
\ No newline at end of file
diff --git a/images/1.jpg b/images/1.jpg
new file mode 100644
index 0000000..0df3133
Binary files /dev/null and b/images/1.jpg differ
diff --git a/images/cat_dog.jpeg b/images/cat_dog.jpeg
new file mode 100644
index 0000000..8b30a3c
Binary files /dev/null and b/images/cat_dog.jpeg differ
diff --git a/images/dog.jpg b/images/dog.jpg
new file mode 100644
index 0000000..926638c
Binary files /dev/null and b/images/dog.jpg differ
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644
index 0000000..db58372
--- /dev/null
+++ b/src/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(FILES main.cpp gdino.cpp sam2.cpp lama.cpp)
+
+add_executable(main ${FILES})
+target_link_libraries(main ${OpenCV_LIBS} ${Boost_LIBRARIES} onnx onnxruntime)
\ No newline at end of file
diff --git a/src/gdino.cpp b/src/gdino.cpp
new file mode 100644
index 0000000..bf2d76b
--- /dev/null
+++ b/src/gdino.cpp
@@ -0,0 +1,402 @@
+#include "gdino.h"
+#include <algorithm>
+#include <cctype>
+#include <cstdint>
+#include <fstream>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <opencv2/dnn.hpp>
+#include <onnxruntime/onnxruntime_cxx_api.h>
+
+CGDINOModel::CGDINOModel(const char * encode,
+                         const char * decode,
+                         const char * vocab) {
+    static Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "GDINO");
+
+    std::ifstream f(vocab);
+    if (f.is_open()) {
+        int index = 0;
+        for (std::string line; std::getline(f, line); ) {
+            token_id_table[line] = index;
+            id_token_table[index] = line;
+            index++;
+        }
+    }
+
+    special_token_ids = convert_tokens_to_ids({"[CLS]", "[SEP]", ".", "?"});
+
+    encoder = std::make_unique<Ort::Session>(env,
+                                             encode,
+                                             Ort::SessionOptions(nullptr));
+
+    decoder = std::make_unique<Ort::Session>(env,
+                                             decode,
+                                             Ort::SessionOptions(nullptr));
+}
+
+CGDINOModel::~CGDINOModel() {
+    encoder->release();
+    decoder->release();
+}
+
+std::vector<int64_t> CGDINOModel::convert_tokens_to_ids(const std::vector<std::string>& tokens) {
+    std::vector<int64_t> input_ids;
+    for (auto& token: tokens) {
+        std::string sub_token = token, proc_token = sub_token;
+        while (sub_token.size() > 0) {
+            if (token_id_table.contains(sub_token)) {
+                input_ids.emplace_back(token_id_table[sub_token]);
+                if (sub_token == proc_token) break;
+
+                sub_token = "##" + proc_token.substr(sub_token.size());
+                proc_token = sub_token;
+                continue;
+            }
+            sub_token = proc_token.substr(0, sub_token.size() - 1);
+        }
+    }
+
+    return input_ids;
+}
+
+std::string CGDINOModel::convert_id_to_token(int id) {
+    if (id >= id_token_table.size()) return "";
+    return id_token_table[id];
+}
+
+std::map<std::string, std::vector<int64_t>> CGDINOModel::tokenize(const std::string& text) {
+    std::vector<std::string> tokens;
+    std::string token;
+
+    for (auto c: text) {
+        if (c == ' ') {
+            if (token.size() > 0) {
+                tokens.emplace_back(token);
+            }
+            token = "";
+            continue;
+        }
+
+        //punctuation
+        if ((c >= 33 && c <= 47) ||
+            (c >= 58 && c <= 64) ||
+            (c >= 91 && c <= 96) ||
+            (c >= 123 && c <= 126)) {
+            tokens.emplace_back(token);
+
+            token = c;
+            tokens.emplace_back(token);
+
+            token = "";
+            continue;
+        }
+
+        token += c;
+    }
+
+//    for (auto& t: tokens) std::cout << t << std::endl;
+    tokens.insert(tokens.begin(), "[CLS]"); //101
+    tokens.emplace_back("[SEP]"); //102
+
+    std::vector<int64_t> input_ids = convert_tokens_to_ids(tokens);
+
+    int input_ids_len = input_ids.size();
+    std::vector<int64_t> attention_mask(input_ids_len);
+    std::vector<int64_t> token_type_ids(input_ids_len);
+    std::vector<int64_t> position_ids(input_ids_len);
+
+    std::vector<int> special_token_indices;
+    for (int i=0, sep=0, pos=0;
+         i<input_ids_len; ++i) {
+        // NOTE: the body of this loop did not survive in the patch; it is
+        // reconstructed from context: every token attends (no padding), token
+        // type is always 0, position ids restart after each special token, and
+        // the special-token positions are recorded for the block mask below.
+        attention_mask[i] = 1;
+        token_type_ids[i] = 0;
+
+        bool is_special = std::find(special_token_ids.begin(),
+                                    special_token_ids.end(),
+                                    input_ids[i]) != special_token_ids.end();
+        if (is_special) {
+            special_token_indices.emplace_back(i);
+            position_ids[i] = 0;
+            sep = i;
+        } else {
+            pos = i - sep;
+            position_ids[i] = pos;
+        }
+    }
+
+    std::vector<int64_t> text_self_attention_masks(input_ids_len * input_ids_len, 0);
+    int prev_special_token_index = special_token_indices[0];
+    for (auto& special_token_index: special_token_indices) {
+        if (special_token_index == 0 || special_token_index == (input_ids_len - 1)) {
+            text_self_attention_masks[special_token_index * input_ids_len + special_token_index] = 1;
+            continue;
+        }
+
+        for (int i=(prev_special_token_index + 1); i<=special_token_index; ++i) {
+            for (int j=(prev_special_token_index + 1); j<=special_token_index; ++j) {
+                text_self_attention_masks[i * input_ids_len + j] = 1;
+            }
+        }
+        prev_special_token_index = special_token_index;
+    }
+
+    std::map<std::string, std::vector<int64_t>> tokenized;
+
+    tokenized["input_ids"] = input_ids;
+    tokenized["attention_mask"] = attention_mask;
+    tokenized["token_type_ids"] = token_type_ids;
+    tokenized["position_ids"] = position_ids;
+    tokenized["text_self_attention_masks"] = text_self_attention_masks;
+
+    return tokenized;
+}
+
+std::string CGDINOModel::pre_process_caption(const std::string& caption) {
+    int begin = caption.find_first_not_of(' ');
+    int end = caption.find_last_not_of(' ');
+    std::string proc_caption = caption.substr(begin, end - begin + 1);
+    std::transform(proc_caption.cbegin(),
+                   proc_caption.cend(),
+                   proc_caption.begin(),
+                   [](char c) { return std::tolower(c); });
+    if (!proc_caption.ends_with('.')) proc_caption += '.';
+
+    return proc_caption;
+}
+
+cv::Mat CGDINOModel::pre_process_image(const cv::Mat& image) {
+    cv::Mat proc_image;
+    cv::resize(image, proc_image, gdino_image_size);
+    cv::cvtColor(proc_image, proc_image, cv::COLOR_BGR2RGB);
+    proc_image.convertTo(proc_image, CV_32F, 1.0 / 255.0);
+    proc_image = (proc_image - cv::Scalar(0.485, 0.456, 0.406)) /
+                 cv::Scalar(0.229, 0.224, 0.225);
+    return cv::dnn::blobFromImage(proc_image);
+}
+
+int CGDINOModel::encode(const std::string& caption) {
+    std::string proc_caption = pre_process_caption(caption);
+    tokenized = tokenize(proc_caption);
+
+    int input_ids_len = tokenized["input_ids"].size();
+
+#if 0
+    std::cout << "input_ids: " << std::endl;
+    for (auto& input_id: tokenized["input_ids"]) {
+        std::cout << input_id << " ";
+    }
+    std::cout << std::endl;
+
+    std::cout << "attention_mask: " << std::endl;
+    for (auto& attention_mask: tokenized["attention_mask"]) {
+        std::cout << attention_mask << " ";
+    }
+    std::cout << std::endl;
+
+    std::cout << "token_type_ids: " << std::endl;
+    for (auto& token_type_id: tokenized["token_type_ids"]) {
+        std::cout << token_type_id << " ";
+    }
+    std::cout << std::endl;
+
+    std::cout << "position_ids: " << std::endl;
+    for (auto& position_id: tokenized["position_ids"]) {
+        std::cout << position_id << " ";
+    }
+    std::cout << std::endl;
+
+    std::cout << "text_self_attention_masks: " << std::endl;
+    for (int i=0; i<input_ids_len; ++i) {
+        for (int j=0; j<input_ids_len; ++j) {
+            std::cout << tokenized["text_self_attention_masks"][i * input_ids_len + j] << " ";
+        }
+        std::cout << std::endl;
+    }
+#endif
+
+    const std::vector<int64_t> shape_input_ids = { 1, input_ids_len };
+    Ort::Value tensor_input_ids = Ort::Value::CreateTensor(memory_info,
+                                                           tokenized["input_ids"].data(),
+                                                           tokenized["input_ids"].size(),
+                                                           shape_input_ids.data(),
+                                                           shape_input_ids.size());
+
+    const std::vector<int64_t> shape_token_type_ids = { 1, input_ids_len };
+    Ort::Value tensor_token_type_ids = Ort::Value::CreateTensor(memory_info,
+                                                                tokenized["token_type_ids"].data(),
+                                                                tokenized["token_type_ids"].size(),
+                                                                shape_token_type_ids.data(),
+                                                                shape_token_type_ids.size());
+
+    const std::vector<int64_t> shape_text_self_attention_masks = { 1, input_ids_len, input_ids_len };
+    // NOTE: uint8 is an assumption for the mask dtype; change the element type
+    // here (and in decode) if the exported model expects bool/int64/float masks.
+    std::vector<uint8_t> text_self_attention_masks(tokenized["text_self_attention_masks"].size());
+    std::transform(tokenized["text_self_attention_masks"].cbegin(),
+                   tokenized["text_self_attention_masks"].cend(),
+                   text_self_attention_masks.begin(),
+                   [](int c) { return static_cast<uint8_t>(c); });
+    Ort::Value tensor_text_self_attention_masks = Ort::Value::CreateTensor(memory_info,
+                                                                           text_self_attention_masks.data(),
+                                                                           text_self_attention_masks.size(),
+                                                                           shape_text_self_attention_masks.data(),
+                                                                           shape_text_self_attention_masks.size());
+
+    const std::vector<int64_t> shape_position_ids = { 1, input_ids_len };
+    Ort::Value tensor_position_ids = Ort::Value::CreateTensor(memory_info,
+                                                              tokenized["position_ids"].data(),
+                                                              tokenized["position_ids"].size(),
+                                                              shape_position_ids.data(),
+                                                              shape_position_ids.size());
+    std::vector<Ort::Value> inputs;
+    inputs.emplace_back(std::move(tensor_input_ids));
+    inputs.emplace_back(std::move(tensor_token_type_ids));
+    inputs.emplace_back(std::move(tensor_text_self_attention_masks));
+    inputs.emplace_back(std::move(tensor_position_ids));
+
+    auto outputs = encoder->Run(Ort::RunOptions(nullptr),
+                                encode_input_names.data(),
+                                inputs.data(),
+                                inputs.size(),
+                                encode_output_names.data(),
+                                encode_output_names.size());
+
+    tensor_last_hidden_state = std::move(outputs[0]);
+#if 0
+    std::cout << "last_hidden_state shape:" << std::endl;
+    for (auto& e: tensor_last_hidden_state.GetTensorTypeAndShapeInfo().GetShape()) {
+        std::cout << e << " ";
+    }
+    std::cout << std::endl;
+#endif
+
+    return 0;
+}
+
+int CGDINOModel::decode(const cv::Mat& image, std::vector<std::tuple<float, std::string, cv::Rect>>& results) {
+    cv::Mat blob = pre_process_image(image);
+    const std::vector<int64_t> shape_image = { 1, 3, gdino_image_size.height, gdino_image_size.width };
+    Ort::Value tensor_image = Ort::Value::CreateTensor(memory_info,
+                                                       reinterpret_cast<float *>(blob.data),
+                                                       blob.total(),
+                                                       shape_image.data(),
+                                                       shape_image.size());
+
+    std::vector<int64_t> input_ids = tokenized["input_ids"];
+    int input_ids_len = input_ids.size();
+
+    const std::vector<int64_t> shape_attention_mask = { 1, input_ids_len };
+    std::vector<uint8_t> attention_mask(tokenized["attention_mask"].size());
+    std::transform(tokenized["attention_mask"].cbegin(),
+                   tokenized["attention_mask"].cend(),
+                   attention_mask.begin(),
+                   [](int c) { return static_cast<uint8_t>(c); });
+    Ort::Value tensor_attention_mask = Ort::Value::CreateTensor(memory_info,
+                                                                attention_mask.data(),
+                                                                attention_mask.size(),
+                                                                shape_attention_mask.data(),
+                                                                shape_attention_mask.size());
+
+    const std::vector<int64_t> shape_position_ids = { 1, input_ids_len };
+    Ort::Value tensor_position_ids = Ort::Value::CreateTensor(memory_info,
+                                                              tokenized["position_ids"].data(),
+                                                              tokenized["position_ids"].size(),
+                                                              shape_position_ids.data(),
+                                                              shape_position_ids.size());
+
+    const std::vector<int64_t> shape_text_self_attention_masks = { 1, input_ids_len, input_ids_len };
+    std::vector<uint8_t> text_self_attention_masks(tokenized["text_self_attention_masks"].size());
+    std::transform(tokenized["text_self_attention_masks"].cbegin(),
+                   tokenized["text_self_attention_masks"].cend(),
+                   text_self_attention_masks.begin(),
+                   [](int c) { return static_cast<uint8_t>(c); });
+    Ort::Value tensor_text_self_attention_masks = Ort::Value::CreateTensor(memory_info,
+                                                                           text_self_attention_masks.data(),
+                                                                           text_self_attention_masks.size(),
+                                                                           shape_text_self_attention_masks.data(),
+                                                                           shape_text_self_attention_masks.size());
+
+    const std::vector<int64_t> shape_box_threshold = { 1 };
+    Ort::Value tensor_box_threshold = Ort::Value::CreateTensor(memory_info,
+                                                               const_cast<float *>(&box_threshold),
+                                                               1,
+                                                               shape_box_threshold.data(),
+                                                               shape_box_threshold.size());
+    const std::vector<int64_t> shape_text_threshold = { 1 };
+    Ort::Value tensor_text_threshold = Ort::Value::CreateTensor(memory_info,
+                                                                const_cast<float *>(&text_threshold),
+                                                                1,
+                                                                shape_text_threshold.data(),
+                                                                shape_text_threshold.size());
+
+    std::vector<Ort::Value> inputs;
+    inputs.emplace_back(std::move(tensor_image));
+    inputs.emplace_back(std::move(tensor_last_hidden_state));
+    inputs.emplace_back(std::move(tensor_attention_mask));
+    inputs.emplace_back(std::move(tensor_position_ids));
+    inputs.emplace_back(std::move(tensor_text_self_attention_masks));
+    inputs.emplace_back(std::move(tensor_box_threshold));
+    inputs.emplace_back(std::move(tensor_text_threshold));
+
+    auto outputs = decoder->Run(Ort::RunOptions(nullptr),
+                                decode_input_names.data(),
+                                inputs.data(),
+                                inputs.size(),
+                                decode_output_names.data(),
+                                decode_output_names.size());
+    Ort::Value predict_logits = std::move(outputs[0]);
+    Ort::Value predict_boxes = std::move(outputs[1]);
+    Ort::Value predict_masks = std::move(outputs[2]);
+
+    int predict_num = predict_logits.GetTensorTypeAndShapeInfo().GetShape()[1];
+    int predict_token_num = predict_logits.GetTensorTypeAndShapeInfo().GetShape()[2];
+
+    const float * predict_logits_data = predict_logits.GetTensorData<float>();
+    const float * predict_boxes_data = predict_boxes.GetTensorData<float>();
+    const bool * predict_masks_data = predict_masks.GetTensorData<bool>();
+
+    for (int i=0; i<predict_num; ++i) {
+        // NOTE: the body of this loop was lost from the patch; this is a
+        // reconstruction of the usual Grounding DINO post-processing: keep the
+        // queries flagged by the decoder, take the best token logit as the
+        // score, assemble the phrase from tokens above text_threshold, and
+        // convert the normalized cx/cy/w/h box to pixel coordinates.
+        if (!predict_masks_data[i]) continue;
+
+        float score = 0.0f;
+        std::string phrase;
+        for (int j=0; j<input_ids_len; ++j) {
+            float logit = predict_logits_data[i * predict_token_num + j];
+            if (logit < text_threshold) continue;
+            if (logit > score) score = logit;
+
+            std::string token = convert_id_to_token(input_ids[j]);
+            if (token.starts_with("##")) {
+                phrase += token.substr(2);
+            } else if (!token.starts_with("[")) {
+                if (!phrase.empty()) phrase += ' ';
+                phrase += token;
+            }
+        }
+
+        float cx = predict_boxes_data[i * 4 + 0] * image.cols;
+        float cy = predict_boxes_data[i * 4 + 1] * image.rows;
+        float w = predict_boxes_data[i * 4 + 2] * image.cols;
+        float h = predict_boxes_data[i * 4 + 3] * image.rows;
+        cv::Rect box(cx - w / 2, cy - h / 2, w, h);
+
+        results.emplace_back(score, phrase, box);
+    }
+
+    return 0;
+}
+
+int CGDINOModel::process(const cv::Mat& image,
+                         const std::string& caption,
+                         std::vector<std::tuple<float, std::string, cv::Rect>>& results) {
+    encode(caption);
+    decode(image, results);
+    return 0;
+}
diff --git a/src/gdino.h b/src/gdino.h
new file mode 100644
index 0000000..4a7a40f
--- /dev/null
+++ b/src/gdino.h
@@ -0,0 +1,63 @@
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <tuple>
+#include <vector>
+#include <opencv2/opencv.hpp>
+// header layout as installed by Homebrew's onnxruntime package
+#include <onnxruntime/onnxruntime_cxx_api.h>
+
+class CGDINOModel {
+    const cv::Size gdino_image_size = { 800, 800 };
+    const float box_threshold = 0.35f;
+    const float text_threshold = 0.25f;
+
+    public:
+        CGDINOModel(const char * encode,
+                    const char * decode,
+                    const char * vocab);
+        ~CGDINOModel();
+
+    private:
+        std::vector<int64_t> convert_tokens_to_ids(const std::vector<std::string>& tokens);
+        std::string convert_id_to_token(int id);
+        std::map<std::string, std::vector<int64_t>> tokenize(const std::string& text);
+        std::string pre_process_caption(const std::string& caption);
+        cv::Mat pre_process_image(const cv::Mat& image);
+
+        int encode(const std::string& caption);
+        int decode(const cv::Mat& image, std::vector<std::tuple<float, std::string, cv::Rect>>& results);
+
+    public:
+        int process(const cv::Mat& image,
+                    const std::string& caption,
+                    std::vector<std::tuple<float, std::string, cv::Rect>>& results);
+
+    private:
+        std::map<std::string, int64_t> token_id_table;
+        std::map<int, std::string> id_token_table;
+
+        std::vector<int64_t> special_token_ids;
+        std::map<std::string, std::vector<int64_t>> tokenized;
+
+        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+        std::unique_ptr<Ort::Session> encoder;
+        const std::vector<const char *> encode_input_names = {
+            "input_ids", "token_type_ids", "text_self_attention_masks", "position_ids"
+        };
+        const std::vector<const char *> encode_output_names = {
+            "last_hidden_state"
+        };
+        Ort::Value tensor_last_hidden_state = Ort::Value(nullptr);
+
+        std::unique_ptr<Ort::Session> decoder;
+        const std::vector<const char *> decode_input_names = {
+            "image", "last_hidden_state", "attention_mask",
+            "position_ids", "text_self_attention_masks",
+            "box_threshold", "text_threshold"
+        };
+        const std::vector<const char *> decode_output_names = {
+            "logits", "boxes", "masks"
+        };
+};
\ No newline at end of file
diff --git a/src/lama.cpp b/src/lama.cpp
new file mode 100644
index 0000000..ade701a
--- /dev/null
+++ b/src/lama.cpp
@@ -0,0 +1,85 @@
+#include "lama.h"
+#include <cstdint>
+#include <memory>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <opencv2/dnn.hpp>
+#include <onnxruntime/onnxruntime_cxx_api.h>
+
+CLama::CLama(const char * model, int kernel /* = 9 */) :
+    dilate_kernel(kernel, kernel) {
+    static Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "LAMA");
+
+    session = std::make_unique<Ort::Session>(env,
+                                             model,
+                                             Ort::SessionOptions(nullptr));
+}
+
+CLama::~CLama() {
+    session->release();
+}
+
+int CLama::inpainting(const cv::Mat& image, const cv::Mat& mask, cv::Mat& result) {
+    cv::Mat proc_image, proc_mask, blob_image, blob_mask;
+    cv::resize(image, proc_image, lama_image_size);
+    cv::resize(mask, proc_mask, lama_image_size);
+
+    cv::cvtColor(proc_image, proc_image, cv::COLOR_BGR2RGB);
+
+    cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, dilate_kernel);
+    cv::dilate(proc_mask, proc_mask, kernel);
+    cv::threshold(proc_mask, proc_mask, 127, 255, cv::THRESH_BINARY);
+
+    proc_image.convertTo(proc_image, CV_32F, 1.0 / 255.0);
+    proc_mask.convertTo(proc_mask, CV_32FC1, 1.0 / 255.0);
+
+    blob_image = cv::dnn::blobFromImage(proc_image);
+    blob_mask = cv::dnn::blobFromImage(proc_mask);
+
+    std::vector<int64_t> image_shape = { 1, 3, lama_image_size.height, lama_image_size.width };
+    Ort::Value tensor_image = Ort::Value::CreateTensor(memory_info,
+                                                       reinterpret_cast<float *>(blob_image.data),
+                                                       blob_image.total(),
+                                                       image_shape.data(),
+                                                       image_shape.size());
+    std::vector<int64_t> mask_shape = { 1, 1, lama_image_size.height, lama_image_size.width };
+    Ort::Value tensor_mask = Ort::Value::CreateTensor(memory_info,
+                                                      reinterpret_cast<float *>(blob_mask.data),
+                                                      blob_mask.total(),
+                                                      mask_shape.data(),
+                                                      mask_shape.size());
+
+    std::vector<Ort::Value> inputs;
+    inputs.emplace_back(std::move(tensor_image));
+    inputs.emplace_back(std::move(tensor_mask));
+
+    auto outputs = session->Run(Ort::RunOptions(nullptr),
+                                input_names.data(),
+                                inputs.data(),
+                                inputs.size(),
+                                output_names.data(),
+                                output_names.size());
+
+    auto output_shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape();
+    const uint8_t * output_data = outputs[0].GetTensorData<uint8_t>();
+    int output_data_area = output_shape[2] * output_shape[3];
+    std::vector<cv::Mat> channels(output_shape[1]);
+    for (int i=0; i<output_shape[1]; ++i) {
+        // NOTE: the tail of this function did not survive in the patch; it is
+        // reconstructed here: wrap each CHW plane of the uint8 output, merge
+        // back to an interleaved image, restore BGR order and the original size.
+        channels[i] = cv::Mat(output_shape[2],
+                              output_shape[3],
+                              CV_8UC1,
+                              const_cast<uint8_t *>(output_data) + i * output_data_area);
+    }
+
+    cv::merge(channels, result);
+    cv::cvtColor(result, result, cv::COLOR_RGB2BGR);
+    cv::resize(result, result, image.size());
+
+    return 0;
+}
diff --git a/src/lama.h b/src/lama.h
new file mode 100644
--- /dev/null
+++ b/src/lama.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <onnxruntime/onnxruntime_cxx_api.h>
+
+class CLama {
+    const cv::Size lama_image_size = { 512, 512 };
+
+    public:
+        CLama(const char * model, int kernel = 9);
+        ~CLama();
+
+    public:
+        int inpainting(const cv::Mat& image, const cv::Mat& mask, cv::Mat& result);
+
+    private:
+        std::unique_ptr<Ort::Session> session;
+
+        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator,
+                                                                 OrtMemTypeCPU);
+
+        const std::vector<const char *> input_names = { "image", "mask" };
+        const std::vector<const char *> output_names = { "output" };
+
+        cv::Size dilate_kernel;
+};
\ No newline at end of file
diff --git a/src/main.cpp b/src/main.cpp
new file mode 100644
index 0000000..cd66fd3
--- /dev/null
+++ b/src/main.cpp
@@ -0,0 +1,87 @@
+#include <iostream>
+#include <string>
+#include <tuple>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <boost/program_options.hpp>
+
+#include "gdino.h"
+#include "sam2.h"
+#include "lama.h"
+
+namespace po = boost::program_options;
+
+int main(int argc, char * argv[]) {
+    po::options_description opts("Allowed options");
+    opts.add_options()
+        ("help", "help message")
+        ("image,i", po::value<std::string>(), "input image")
+        ("prompt,p", po::value<std::string>(), "prompt")
+        ("kernel,k", po::value<int>()->default_value(9), "kernel size")
+        ;
+
+    po::variables_map vm;
+    po::store(po::parse_command_line(argc, argv, opts), vm);
+    if (vm.count("help") > 0 ||
+        vm.count("image") == 0 ||
+        vm.count("prompt") == 0) {
+        std::cout << opts << std::endl;
+        return 0;
+    }
+
+    cv::Mat image = cv::imread(vm["image"].as<std::string>());
+    std::string caption = vm["prompt"].as<std::string>();
+
+    CGDINOModel gdino("models/gdino.encoder.onnx",
+                      "models/gdino.decoder.onnx",
+                      "models/vocab.txt");
+    CSam2 sam2("models/sam2_1.encoder.onnx",
+               "models/sam2_1.decoder.box.onnx");
+    CLama lama("models/lama.onnx", vm["kernel"].as<int>());
+
+    std::vector<std::tuple<float, std::string, cv::Rect>> results;
+    gdino.process(image, caption, results);
+
+    std::vector<cv::Rect> boxes;
+    cv::Mat result_image = image.clone();
+    for (auto& result: results) {
+        float score;
+        std::string caption;
+        cv::Rect box;
+
+        tie(score, caption, box) = result;
+
+        boxes.emplace_back(box);
+
+        cv::rectangle(result_image, box, cv::Scalar(0, 0, 255), 2);
+    }
+
+    cv::imshow("image", result_image);
+    cv::waitKey();
+
+    cv::Mat mask;
+    sam2.process(image, boxes, mask);
+
+    cv::cvtColor(result_image, result_image, cv::COLOR_BGR2BGRA);
+
+    cv::Mat result_mask;
+    cv::cvtColor(mask, result_mask, cv::COLOR_GRAY2BGRA);
+    result_mask.setTo(cv::Scalar(0, 203, 255, static_cast<int>(255 * 0.73)),
+                      (result_mask > 128));
+    cv::addWeighted(result_image, 1.0, result_mask, 0.3, 0.0, result_image);
+    cv::imshow("image", result_image);
+    cv::waitKey();
+
+    lama.inpainting(image, mask, result_image);
+
+    cv::imshow("image", result_image);
+    cv::waitKey();
+
+    cv::destroyAllWindows();
+
+    return 0;
+}
\ No newline at end of file
diff --git a/src/sam2.cpp b/src/sam2.cpp
new file mode 100644
index 0000000..90b0318
--- /dev/null
+++ b/src/sam2.cpp
@@ -0,0 +1,135 @@
+#include "sam2.h"
+#include <cstdint>
+#include <memory>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <opencv2/dnn.hpp>
+#include <onnxruntime/onnxruntime_cxx_api.h>
+
+CSam2::CSam2(const char * encode, const char * decode) {
+    static Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "SAM2");
+
+    encoder = std::make_unique<Ort::Session>(env,
+                                             encode,
+                                             Ort::SessionOptions(nullptr));
+
+    decoder = std::make_unique<Ort::Session>(env,
+                                             decode,
+                                             Ort::SessionOptions(nullptr));
+}
+
+CSam2::~CSam2() {
+    encoder->release();
+    decoder->release();
+}
+
+int CSam2::encode(const cv::Mat& image) {
+    cv::Mat proc_image;
+    cv::resize(image, proc_image, sam2_image_size);
+    proc_image.convertTo(proc_image, CV_32F, 1.0 / 255.0);
+    proc_image = (proc_image - cv::Scalar(0.485, 0.456, 0.406)) /
+                 cv::Scalar(0.229, 0.224, 0.225);
+    cv::Mat blob = cv::dnn::blobFromImage(proc_image,
+                                          1.0,
+                                          proc_image.size(),
+                                          cv::Scalar(),
+                                          true);
+
+    const std::vector<int64_t> shape_input = { 1,
+                                               3,
+                                               sam2_image_size.height,
+                                               sam2_image_size.width };
+    Ort::Value input = Ort::Value::CreateTensor(memory_info,
+                                                reinterpret_cast<float *>(blob.data),
+                                                blob.total(),
+                                                shape_input.data(),
+                                                shape_input.size());
+
+    auto outputs = encoder->Run(Ort::RunOptions(nullptr),
+                                encode_input_names.data(),
+                                &input,
+                                1,
+                                encode_output_names.data(),
+                                encode_output_names.size());
+
+    tensor_image_embeddings = std::move(outputs[0]);
+    tensor_high_res_features1 = std::move(outputs[1]);
+    tensor_high_res_features2 = std::move(outputs[2]);
+
+    return 0;
+}
+
+int CSam2::decode(const cv::Mat& image,
+                  const std::vector<cv::Rect> boxes,
+                  cv::Mat& mask) {
+    std::vector<float> ratios = { 1.0f * sam2_image_size.width / image.cols,
+                                  1.0f * sam2_image_size.height / image.rows };
+
+    std::vector<float> boxes_data;
+    for (auto& box: boxes) {
+        int x1 = box.x * ratios[0];
+        int y1 = box.y * ratios[1];
+        int x2 = (box.x + box.width) * ratios[0];
+        int y2 = (box.y + box.height) * ratios[1];
+
+        boxes_data.emplace_back(x1);
+        boxes_data.emplace_back(y1);
+        boxes_data.emplace_back(x2);
+        boxes_data.emplace_back(y2);
+    }
+    const std::vector<int64_t> boxes_shape = { static_cast<int64_t>(boxes.size()), 2, 2 };
+    Ort::Value tensor_boxes = Ort::Value::CreateTensor(memory_info,
+                                                       boxes_data.data(),
+                                                       boxes_data.size(),
+                                                       boxes_shape.data(),
+                                                       boxes_shape.size());
+
+    std::vector<Ort::Value> inputs;
+    inputs.emplace_back(std::move(tensor_image_embeddings));
+    inputs.emplace_back(std::move(tensor_high_res_features1));
+    inputs.emplace_back(std::move(tensor_high_res_features2));
+    inputs.emplace_back(std::move(tensor_boxes));
+
+    auto outputs = decoder->Run(Ort::RunOptions(nullptr),
+                                decode_input_names.data(),
+                                inputs.data(),
+                                inputs.size(),
+                                decode_output_names.data(),
+                                decode_output_names.size());
+
+    std::vector<int64_t> masks_shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape();
+    const uint8_t * masks_data = outputs[0].GetTensorData<uint8_t>();
+    int mask_data_len = masks_shape[1] * masks_shape[2] * masks_shape[3];
+    cv::Mat predict_mask_combine = cv::Mat::zeros(masks_shape[2],
+                                                  masks_shape[3],
+                                                  CV_8UC1);
+    for (int i=0; i<masks_shape[0]; ++i) {
+        // NOTE: the rest of this function was lost from the patch; the usual
+        // combination step is reconstructed here: OR every per-box mask into a
+        // single binary mask, then scale it back to the source resolution.
+        cv::Mat predict_mask(masks_shape[2],
+                             masks_shape[3],
+                             CV_8UC1,
+                             const_cast<uint8_t *>(masks_data) + i * mask_data_len);
+        cv::bitwise_or(predict_mask_combine, predict_mask, predict_mask_combine);
+    }
+
+    predict_mask_combine.setTo(255, predict_mask_combine > 0);
+    cv::resize(predict_mask_combine, mask, image.size(), 0, 0, cv::INTER_NEAREST);
+
+    return 0;
+}
+
+int CSam2::process(const cv::Mat& image,
+                   const std::vector<cv::Rect> boxes,
+                   cv::Mat& mask) {
+    encode(image);
+    decode(image, boxes, mask);
+    return 0;
+}
diff --git a/src/sam2.h b/src/sam2.h
new file mode 100644
index 0000000..49cf97c
--- /dev/null
+++ b/src/sam2.h
@@ -0,0 +1,51 @@
+#pragma once
+
+#include <memory>
+#include <vector>
+#include <opencv2/opencv.hpp>
+#include <onnxruntime/onnxruntime_cxx_api.h>
+
+class CSam2 {
+    const cv::Size sam2_image_size = { 1024, 1024 };
+
+    public:
+        CSam2(const char * encode, const char * decode);
+        ~CSam2();
+
+    private:
+        int encode(const cv::Mat& image);
+        int decode(const cv::Mat& image,
+                   const std::vector<cv::Rect> boxes,
+                   cv::Mat& mask);
+
+    public:
+        int process(const cv::Mat& image,
+                    const std::vector<cv::Rect> boxes,
+                    cv::Mat& mask);
+
+    private:
+        std::unique_ptr<Ort::Session> encoder;
+        std::unique_ptr<Ort::Session> decoder;
+
+        Ort::Value tensor_image_embeddings = Ort::Value(nullptr);
+        Ort::Value tensor_high_res_features1 = Ort::Value(nullptr);
+        Ort::Value tensor_high_res_features2 = Ort::Value(nullptr);
+
+        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator,
+                                                                 OrtMemTypeCPU);
+        const std::vector<const char *> encode_input_names = {
+            "image"
+        };
+        const std::vector<const char *> encode_output_names = {
+            "image_embeddings", "high_res_feats1", "high_res_feats2"
+        };
+
+        const std::vector<const char *> decode_input_names = {
+            "image_embeddings", "high_res_features1", "high_res_features2",
+            "boxes"
+        };
+        const std::vector<const char *> decode_output_names = {
+            "masks", "iou_predictions"
+        };
+};
\ No newline at end of file
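
Usage sketch (not part of the patch): src/main.cpp above already wires the three stages together; condensed, and assuming the ONNX files sit under models/ exactly as referenced there (the image path and the "dog" prompt below are placeholders), the pipeline boils down to:

    #include <string>
    #include <tuple>
    #include <vector>
    #include <opencv2/opencv.hpp>
    #include "gdino.h"
    #include "sam2.h"
    #include "lama.h"

    int main() {
        cv::Mat image = cv::imread("images/dog.jpg");

        // Grounding DINO: text prompt -> (score, phrase, box) tuples.
        CGDINOModel gdino("models/gdino.encoder.onnx",
                          "models/gdino.decoder.onnx",
                          "models/vocab.txt");
        std::vector<std::tuple<float, std::string, cv::Rect>> results;
        gdino.process(image, "dog", results);

        // SAM2: detected boxes -> one combined binary mask.
        std::vector<cv::Rect> boxes;
        for (auto& r: results) boxes.emplace_back(std::get<2>(r));
        CSam2 sam2("models/sam2_1.encoder.onnx", "models/sam2_1.decoder.box.onnx");
        cv::Mat mask;
        sam2.process(image, boxes, mask);

        // LaMa: inpaint the masked objects away.
        CLama lama("models/lama.onnx");
        cv::Mat inpainted;
        lama.inpainting(image, mask, inpainted);
        cv::imwrite("inpainted.jpg", inpainted);

        return 0;
    }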