Skip to content

Commit aeba1c1

Browse files
committed
[Container] Add more constructors for MemRef and Img container.
1 parent d155769 commit aeba1c1

File tree

4 files changed

+52
-5
lines changed

4 files changed

+52
-5
lines changed

include/Interface/buddy/core/Container.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@ template <typename T, size_t N> class MemRef {
3333
public:
3434
// Constructor from shape.
3535
MemRef(intptr_t sizes[N], T init = T(0));
36+
MemRef(std::vector<size_t> sizes, T init = T(0));
3637
// Constructor from data.
3738
MemRef(const T *data, intptr_t sizes[N], intptr_t offset = 0);
3839
// Constructor from a unique_ptr, taking over.
39-
MemRef(std::unique_ptr<T>& uptr, intptr_t sizes[N], intptr_t offset = 0);
40+
MemRef(std::unique_ptr<T> &uptr, intptr_t sizes[N], intptr_t offset = 0);
4041
// Copy constructor.
4142
MemRef(const MemRef<T, N> &other);
4243
// Copy assignment operator.

include/Interface/buddy/core/ImageContainer.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@
2626
// Image container.
2727
// - T represents the type of the elements.
2828
// - N represents the number of dimensions.
29+
// - image represents the OpenCV Mat object.
30+
// - norm indicates whether to perform normalization, and the normalization is
31+
// disabled by default.
2932
template <typename T, size_t N> class Img : public MemRef<T, N> {
3033
public:
31-
Img(cv::Mat image);
34+
Img(cv::Mat image, bool norm = false);
3235
};
3336

3437
#include "Interface/core/ImageContainer.cpp"

lib/Interface/core/Container.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,21 @@ MemRef<T, N>::MemRef(intptr_t sizes[N], T init) {
4444
std::fill(aligned, aligned + size, init);
4545
}
4646

47+
template <typename T, std::size_t N>
48+
MemRef<T, N>::MemRef(std::vector<size_t> sizes, T init) {
49+
if (sizes.size() != N) {
50+
throw std::runtime_error("Invalid number of dimensions.");
51+
}
52+
for (size_t i = 0; i < N; i++) {
53+
this->sizes[i] = sizes[i];
54+
}
55+
setStrides();
56+
size = product(this->sizes);
57+
allocated = new T[size];
58+
aligned = allocated;
59+
std::fill(aligned, aligned + size, init);
60+
}
61+
4762
// MemRef Array Constructor.
4863
// Construct a MemRef object from the data pointer, sizes, and offset.
4964
// The default offset is 0.
@@ -210,7 +225,7 @@ size_t MemRef<T, N>::product(intptr_t sizes[N]) const {
210225
return size;
211226
}
212227
template <typename T, size_t N>
213-
MemRef<T, N>::MemRef(std::unique_ptr<T>& uptr, intptr_t *sizes,
228+
MemRef<T, N>::MemRef(std::unique_ptr<T> &uptr, intptr_t *sizes,
214229
intptr_t offset) {
215230
if (!uptr)
216231
assert(0 && "Taking over an empty unique pointer.");

lib/Interface/core/ImageContainer.cpp

+30-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
#include <cassert>
2525

2626
// Image Constructor from OpenCV Mat.
27-
template <typename T, size_t N> Img<T, N>::Img(cv::Mat image) : MemRef<T, N>() {
27+
template <typename T, size_t N>
28+
Img<T, N>::Img(cv::Mat image, bool norm) : MemRef<T, N>() {
2829
if (image.channels() == 1) {
2930
assert((N == 2) &&
3031
"Input image type does not match the selected dimension.");
@@ -41,9 +42,36 @@ template <typename T, size_t N> Img<T, N>::Img(cv::Mat image) : MemRef<T, N>() {
4142
}
4243
}
4344
this->setStrides();
45+
}
46+
// Use NHWC layout by default.
47+
else if (image.channels() == 3) {
48+
assert((N == 4) &&
49+
"Input image type does not match the selected dimension.");
50+
this->sizes[0] = 1;
51+
this->sizes[1] = image.rows;
52+
this->sizes[2] = image.cols;
53+
this->sizes[3] = 3;
54+
this->size = image.rows * image.cols * 3;
55+
this->allocated = new T[this->size];
56+
this->aligned = this->allocated;
57+
int k = 0;
58+
for (int i = 0; i < image.rows; i++) {
59+
for (int j = 0; j < image.cols; j++) {
60+
for (int color = 0; color < 3; color++) {
61+
// Reorder to RGB layout.
62+
if (norm) {
63+
this->aligned[k] = (T)image.at<cv::Vec3b>(i, j)[2 - color] / 255;
64+
} else {
65+
this->aligned[k] = (T)image.at<cv::Vec3b>(i, j)[2 - color];
66+
}
67+
k++;
68+
}
69+
}
70+
}
71+
this->setStrides();
4472
} else {
4573
// TODO: Add more image channels in this constructor.
46-
assert((N != 2) && "This image channels is not supported.");
74+
std::cerr << "This image channels is not supported." << std::endl;
4775
}
4876
}
4977

0 commit comments

Comments
 (0)