-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
12 changed files
with
873 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
cmake_minimum_required(VERSION 3.30) | ||
|
||
project(gdino.sam.lama) | ||
|
||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) | ||
set(CMAKE_CXX_STANDARD 20) | ||
set(EXECUTABLE_OUTPUT_PATH ${CMAKE_BINARY_DIR}/bin) | ||
|
||
include_directories(/opt/homebrew/include /usr/local/include) | ||
link_directories(/opt/homebrew/lib /usr/local/lib) | ||
|
||
find_package(OpenCV REQUIRED) | ||
include_directories(${OpenCV_INCLUDE_DIRS}) | ||
|
||
find_package(Boost COMPONENTS program_options REQUIRED) | ||
include_directories(${Boost_INCLUDE_DIR}) | ||
|
||
add_subdirectory(src) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
set(FILES main.cpp gdino.cpp sam2.cpp lama.cpp) | ||
|
||
add_executable(main ${FILES}) | ||
target_link_libraries(main ${OpenCV_LIBS} ${Boost_LIBRARIES} onnx onnxruntime) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#pragma once | ||
|
||
#include <map> | ||
#include <memory> | ||
#include <opencv2/opencv.hpp> | ||
#include <string> | ||
#include <onnxruntime/onnxruntime_cxx_api.h> | ||
#include <tuple> | ||
#include <vector> | ||
|
||
class CGDINOModel { | ||
const cv::Size gdino_image_size = { 800, 800 }; | ||
const float box_threshold = 0.35f; | ||
const float text_threshold = 0.25f; | ||
|
||
public: | ||
CGDINOModel(const char * encode, | ||
const char * decode, | ||
const char * vocab); | ||
~CGDINOModel(); | ||
|
||
private: | ||
std::vector<int> convert_tokens_to_ids(const std::vector<std::string>& tokens); | ||
std::string convert_id_to_token(int id); | ||
std::map<std::string, std::vector<int>> tokenize(const std::string& text); | ||
std::string pre_process_caption(const std::string& caption); | ||
cv::Mat pre_process_image(const cv::Mat& image); | ||
|
||
int encode(const std::string& caption); | ||
int decode(const cv::Mat& image, std::vector<std::tuple<float, std::string, cv::Rect>>& results); | ||
|
||
public: | ||
int process(const cv::Mat& image, | ||
const std::string& caption, | ||
std::vector<std::tuple<float, std::string, cv::Rect>>& results); | ||
|
||
private: | ||
std::map<std::string, int> token_id_table; | ||
std::map<int, std::string> id_token_table; | ||
|
||
std::vector<int> special_token_ids; | ||
std::map<std::string, std::vector<int>> tokenized; | ||
|
||
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); | ||
std::unique_ptr<Ort::Session> encoder; | ||
const std::vector<const char *> encode_input_names = { | ||
"input_ids", "token_type_ids", "text_self_attention_masks", "position_ids" | ||
}; | ||
const std::vector<const char *> encode_output_names = { | ||
"last_hidden_state" | ||
}; | ||
Ort::Value tensor_last_hidden_state = Ort::Value(nullptr); | ||
|
||
std::unique_ptr<Ort::Session> decoder; | ||
const std::vector<const char *> decode_input_names = { | ||
"image", "last_hidden_state", "attention_mask", | ||
"position_ids", "text_self_attention_masks", | ||
"box_threshold", "text_threshold" | ||
}; | ||
const std::vector<const char *> decode_output_names = { | ||
"logits", "boxes", "masks" | ||
}; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
#include "lama.h" | ||
#include <cstdint> | ||
#include <memory> | ||
#include <opencv2/core.hpp> | ||
#include <opencv2/core/base.hpp> | ||
#include <opencv2/core/hal/interface.h> | ||
#include <opencv2/core/types.hpp> | ||
#include <opencv2/dnn/dnn.hpp> | ||
#include <opencv2/highgui.hpp> | ||
#include <opencv2/imgproc.hpp> | ||
#include <vector> | ||
|
||
CLama::CLama(const char * model, int kernel /* = 9 */) : | ||
dilate_kernel(kernel, kernel) { | ||
static Ort::Env env(ORT_LOGGING_LEVEL_ERROR, "LAMA"); | ||
|
||
session = std::make_unique<Ort::Session>(env, | ||
model, | ||
Ort::SessionOptions(nullptr)); | ||
} | ||
|
||
CLama::~CLama() { | ||
session->release(); | ||
} | ||
|
||
int CLama::inpainting(const cv::Mat& image, const cv::Mat& mask, cv::Mat& result) { | ||
cv::Mat proc_image, proc_mask, blob_image, blob_mask; | ||
cv::resize(image, proc_image, lama_image_size); | ||
cv::resize(mask, proc_mask, lama_image_size); | ||
|
||
cv::cvtColor(proc_image, proc_image, cv::COLOR_BGR2RGB); | ||
|
||
cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, dilate_kernel); | ||
cv::dilate(proc_mask, proc_mask, kernel); | ||
cv::threshold(proc_mask, proc_mask, 127, 255, cv::THRESH_BINARY); | ||
|
||
proc_image.convertTo(proc_image, CV_32F, 1.0 / 255.0); | ||
proc_mask.convertTo(proc_mask, CV_32FC1, 1.0 / 255.0); | ||
|
||
blob_image = cv::dnn::blobFromImage(proc_image); | ||
blob_mask = cv::dnn::blobFromImage(proc_mask); | ||
|
||
std::vector<int64_t> image_shape = { 1, 3, lama_image_size.height, lama_image_size.width }; | ||
Ort::Value tensor_image = Ort::Value::CreateTensor<float>(memory_info, | ||
reinterpret_cast<float *>(blob_image.data), | ||
blob_image.total(), | ||
image_shape.data(), | ||
image_shape.size()); | ||
std::vector<int64_t> mask_shape = { 1, 1, lama_image_size.height, lama_image_size.width }; | ||
Ort::Value tensor_mask = Ort::Value::CreateTensor<float>(memory_info, | ||
reinterpret_cast<float *>(blob_mask.data), | ||
blob_mask.total(), | ||
mask_shape.data(), | ||
mask_shape.size()); | ||
|
||
std::vector<Ort::Value> inputs; | ||
inputs.emplace_back(std::move(tensor_image)); | ||
inputs.emplace_back(std::move(tensor_mask)); | ||
|
||
auto outputs = session->Run(Ort::RunOptions(nullptr), | ||
input_names.data(), | ||
inputs.data(), | ||
inputs.size(), | ||
output_names.data(), | ||
output_names.size()); | ||
|
||
auto output_shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape(); | ||
const uint8_t * output_data = outputs[0].GetTensorData<uint8_t>(); | ||
int output_data_area = output_shape[2] * output_shape[3]; | ||
std::vector<cv::Mat> channels(output_shape[1]); | ||
for (int i=0; i<channels.size(); ++i) { | ||
channels[i] = cv::Mat(output_shape[2], | ||
output_shape[3], | ||
CV_8UC1, | ||
(void *)(output_data + i * output_data_area)); | ||
} | ||
|
||
cv::Mat output_image; | ||
cv::merge(channels, output_image); | ||
|
||
cv::cvtColor(output_image, output_image, cv::COLOR_RGB2BGR); | ||
cv::resize(output_image, result, image.size()); | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#pragma once | ||
|
||
#include <memory> | ||
#include <opencv2/opencv.hpp> | ||
#include <onnxruntime/onnxruntime_cxx_api.h> | ||
#include <vector> | ||
|
||
class CLama { | ||
const cv::Size lama_image_size = { 512, 512 }; | ||
|
||
public: | ||
CLama(const char * model, int kernel = 9); | ||
~CLama(); | ||
|
||
public: | ||
int inpainting(const cv::Mat& image, const cv::Mat& mask, cv::Mat& result); | ||
|
||
private: | ||
std::unique_ptr<Ort::Session> session; | ||
|
||
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, | ||
OrtMemTypeCPU); | ||
|
||
const std::vector<const char *> input_names = { "image", "mask" }; | ||
const std::vector<const char *> output_names = { "output" }; | ||
|
||
cv::Size dilate_kernel; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#include <opencv2/core.hpp> | ||
#include <opencv2/core/types.hpp> | ||
#include <opencv2/highgui.hpp> | ||
#include <opencv2/imgcodecs.hpp> | ||
#include <opencv2/imgproc.hpp> | ||
#include <opencv2/opencv.hpp> | ||
#include <string> | ||
#include <vector> | ||
#include <boost/program_options.hpp> | ||
|
||
#include "gdino.h" | ||
#include "sam2.h" | ||
#include "lama.h" | ||
|
||
namespace po = boost::program_options; | ||
|
||
int main(int argc, char * argv[]) { | ||
po::options_description opts("Allowed options"); | ||
opts.add_options() | ||
("help", "help message") | ||
("image,i", po::value<std::string>(), "input image") | ||
("prompt,p", po::value<std::string>(), "prompt") | ||
("kernel,k", po::value<int>()->default_value(9), "kernel size") | ||
; | ||
|
||
po::variables_map vm; | ||
po::store(po::parse_command_line(argc, argv, opts), vm); | ||
if (vm.count("help") > 0 || | ||
vm.count("image") == 0 || | ||
vm.count("prompt") == 0) { | ||
std::cout << opts << std::endl; | ||
return 0; | ||
} | ||
|
||
cv::Mat image = cv::imread(vm["image"].as<std::string>()); | ||
std::string caption = vm["prompt"].as<std::string>(); | ||
|
||
CGDINOModel gdino("models/gdino.encoder.onnx", | ||
"models/gdino.decoder.onnx", | ||
"models/vocab.txt"); | ||
CSam2 sam2("models/sam2_1.encoder.onnx", | ||
"models/sam2_1.decoder.box.onnx"); | ||
CLama lama("models/lama.onnx", vm["kernel"].as<int>()); | ||
|
||
std::vector<std::tuple<float, std::string, cv::Rect>> results; | ||
gdino.process(image, caption, results); | ||
|
||
std::vector<cv::Rect> boxes; | ||
cv::Mat result_image = image.clone(); | ||
for (auto& result: results) { | ||
float score; | ||
std::string caption; | ||
cv::Rect box; | ||
|
||
tie(score, caption, box) = result; | ||
|
||
boxes.emplace_back(box); | ||
|
||
cv::rectangle(result_image, box, cv::Scalar(0, 0, 255), 2); | ||
} | ||
|
||
cv::imshow("image", result_image); | ||
cv::waitKey(); | ||
|
||
cv::Mat mask; | ||
sam2.process(image, boxes, mask); | ||
|
||
cv::cvtColor(result_image, result_image, cv::COLOR_BGR2BGRA); | ||
|
||
cv::Mat result_mask; | ||
cv::cvtColor(mask, result_mask, cv::COLOR_GRAY2BGRA); | ||
result_mask.setTo(cv::Scalar(0, 203, 255, static_cast<uchar>(255 * 0.73)), | ||
(result_mask > 128)); | ||
cv::addWeighted(result_image, 1.0, result_mask, 0.3, 0.0, result_image); | ||
cv::imshow("image", result_image); | ||
cv::waitKey(); | ||
|
||
|
||
lama.inpainting(image, mask, result_image); | ||
|
||
cv::imshow("image", result_image); | ||
cv::waitKey(); | ||
|
||
cv::destroyAllWindows(); | ||
|
||
return 0; | ||
} |
Oops, something went wrong.