Skip to content

Commit

Permalink
added a force_encoded_color flag to the data layer. Printing a warnin…
Browse files Browse the repository at this point in the history
…g if images of different channel dimensions are encoded together
  • Loading branch information
philkr committed Feb 20, 2015
1 parent 5246587 commit a2f7f47
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 6 deletions.
Binary file added examples/images/cat_gray.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 2 additions & 0 deletions include/caffe/util/io.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ inline bool ReadImageToDatum(const string& filename, const int label,
}

bool DecodeDatumNative(Datum* datum);
// Decodes an encoded Datum in place: the image is decoded with
// DecodeDatumToCVMat(*datum, is_color) and written back via CVMatToDatum.
// Returns true if the Datum was encoded, false (Datum untouched) otherwise.
bool DecodeDatum(Datum* datum, bool is_color);

cv::Mat ReadImageToCVMat(const string& filename,
const int height, const int width, const bool is_color);
Expand All @@ -135,6 +136,7 @@ cv::Mat ReadImageToCVMat(const string& filename,
cv::Mat ReadImageToCVMat(const string& filename);

cv::Mat DecodeDatumToCVMatNative(const Datum& datum);
// Decodes the encoded image bytes stored in datum.data() with cv::imdecode,
// using CV_LOAD_IMAGE_COLOR when is_color is true and
// CV_LOAD_IMAGE_GRAYSCALE otherwise. The Datum must be encoded (CHECKed).
// If decoding fails an error is logged and the returned Mat has no data.
cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color);

void CVMatToDatum(const cv::Mat& cv_img, Datum* datum);

Expand Down
15 changes: 13 additions & 2 deletions src/caffe/layers/data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
Datum datum;
datum.ParseFromString(cursor_->value());

if (DecodeDatumNative(&datum)) {
bool force_color = this->layer_param_.data_param().force_encoded_color();
if ((force_color && DecodeDatum(&datum, true)) ||
DecodeDatumNative(&datum)) {
LOG(INFO) << "Decoding Datum";
}
// image
Expand Down Expand Up @@ -90,6 +92,7 @@ void DataLayer<Dtype>::InternalThreadEntry() {
top_label = this->prefetch_label_.mutable_cpu_data();
}
const int batch_size = this->layer_param_.data_param().batch_size();
bool force_color = this->layer_param_.data_param().force_encoded_color();
for (int item_id = 0; item_id < batch_size; ++item_id) {
timer.Start();
// get a blob
Expand All @@ -98,7 +101,15 @@ void DataLayer<Dtype>::InternalThreadEntry() {

cv::Mat cv_img;
if (datum.encoded()) {
cv_img = DecodeDatumToCVMatNative(datum);
if (force_color)
cv_img = DecodeDatumToCVMat(datum, true);
else
cv_img = DecodeDatumToCVMatNative(datum);
if (cv_img.channels() != this->transformed_data_.channels())
LOG(WARNING) << "Your dataset contains encoded images with mixed "
<< "channel sizes. Consider adding a 'force_color' flag to the "
<< "model definition, or rebuild your dataset using "
<< "convert_imageset.";
}
read_time += timer.MicroSeconds();
timer.Start();
Expand Down
2 changes: 1 addition & 1 deletion src/caffe/layers/window_data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ void WindowDataLayer<Dtype>::InternalThreadEntry() {
if (this->cache_images_) {
pair<std::string, Datum> image_cached =
image_database_cache_[window[WindowDataLayer<Dtype>::IMAGE_INDEX]];
cv_img = DecodeDatumToCVMatNative(image_cached.second);
cv_img = DecodeDatumToCVMat(image_cached.second, true);
} else {
cv_img = cv::imread(image.first, CV_LOAD_IMAGE_COLOR);
if (!cv_img.data) {
Expand Down
2 changes: 2 additions & 0 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ message DataParameter {
// DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror
// data.
optional bool mirror = 6 [default = false];
// Force the encoded image to have 3 color channels
optional bool force_encoded_color = 9 [default = false];
}

// Message that stores parameters used by DropoutLayer
Expand Down
88 changes: 86 additions & 2 deletions src/caffe/test/test_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,60 @@ TEST_F(IOTest, TestReadFileToDatum) {
}

// Verifies that DecodeDatum() with is_color=true decodes an encoded JPEG in
// place and produces the same shape and bytes as the reference reader.
TEST_F(IOTest, TestDecodeDatum) {
  string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
  Datum datum;
  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
  EXPECT_TRUE(DecodeDatum(&datum, true));
  // DecodeDatum() returns false for a Datum that is not encoded, which is
  // what the test expects after the first successful decode.
  EXPECT_FALSE(DecodeDatum(&datum, true));
  Datum datum_ref;
  ReadImageToDatumReference(filename, 0, 0, 0, true, &datum_ref);
  EXPECT_EQ(datum.channels(), datum_ref.channels());
  EXPECT_EQ(datum.height(), datum_ref.height());
  EXPECT_EQ(datum.width(), datum_ref.width());
  // Fatal assertion: the loop below indexes both buffers with the same
  // index, so a size mismatch must stop the test instead of reading past
  // the end of the shorter string.
  ASSERT_EQ(datum.data().size(), datum_ref.data().size());

  const string& data = datum.data();
  const string& data_ref = datum_ref.data();
  // size_t index avoids a signed/unsigned comparison against size();
  // EXPECT_EQ reports both byte values when they differ.
  for (size_t i = 0; i < data.size(); ++i) {
    EXPECT_EQ(data[i], data_ref[i]);
  }
}

// Decodes the same encoded JPEG once as color and once as grayscale and
// checks the channel count and image size of each result.
TEST_F(IOTest, TestDecodeDatumToCVMat) {
  const string image_path = EXAMPLES_SOURCE_DIR "images/cat.jpg";
  Datum encoded_datum;
  EXPECT_TRUE(ReadFileToDatum(image_path, &encoded_datum));

  // is_color = true: expect three channels.
  cv::Mat decoded = DecodeDatumToCVMat(encoded_datum, true);
  EXPECT_EQ(decoded.channels(), 3);
  EXPECT_EQ(decoded.rows, 360);
  EXPECT_EQ(decoded.cols, 480);

  // is_color = false: expect a single channel at the same resolution.
  decoded = DecodeDatumToCVMat(encoded_datum, false);
  EXPECT_EQ(decoded.channels(), 1);
  EXPECT_EQ(decoded.rows, 360);
  EXPECT_EQ(decoded.cols, 480);
}

// Compares the pixels of a color-decoded, "jpg"-encoded Datum against the
// image read directly from disk with ReadImageToCVMat.
TEST_F(IOTest, TestDecodeDatumToCVMatContent) {
  const string image_path = EXAMPLES_SOURCE_DIR "images/cat.jpg";
  Datum encoded_datum;
  EXPECT_TRUE(
      ReadImageToDatum(image_path, 0, std::string("jpg"), &encoded_datum));
  cv::Mat decoded = DecodeDatumToCVMat(encoded_datum, true);
  cv::Mat reference = ReadImageToCVMat(image_path);
  EXPECT_EQ(reference.channels(), decoded.channels());
  EXPECT_EQ(reference.rows, decoded.rows);
  EXPECT_EQ(reference.cols, decoded.cols);

  // NOTE(review): the bounds come from the (still encoded) Datum's metadata,
  // not from the decoded Mat. If ReadImageToDatum leaves channels/height/width
  // unset for encoded data, this loop never runs — confirm against io.cpp.
  const int num_channels = encoded_datum.channels();
  const int num_rows = encoded_datum.height();
  const int num_cols = encoded_datum.width();
  for (int c = 0; c < num_channels; ++c) {
    for (int h = 0; h < num_rows; ++h) {
      for (int w = 0; w < num_cols; ++w) {
        const cv::Vec3b& actual = decoded.at<cv::Vec3b>(h, w);
        const cv::Vec3b& expected = reference.at<cv::Vec3b>(h, w);
        EXPECT_TRUE(actual[c] == expected[c]);
      }
    }
  }
}

TEST_F(IOTest, TestDecodeDatumNative) {
string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
Datum datum;
EXPECT_TRUE(ReadFileToDatum(filename, &datum));
Expand All @@ -305,7 +359,7 @@ TEST_F(IOTest, TestDecodeDatum) {
}
}

TEST_F(IOTest, TestDecodeDatumToCVMat) {
TEST_F(IOTest, TestDecodeDatumToCVMatNative) {
string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
Datum datum;
EXPECT_TRUE(ReadFileToDatum(filename, &datum));
Expand All @@ -315,7 +369,37 @@ TEST_F(IOTest, TestDecodeDatumToCVMat) {
EXPECT_EQ(cv_img.cols, 480);
}

TEST_F(IOTest, TestDecodeDatumToCVMatContent) {
// Verifies that DecodeDatumNative() decodes an encoded grayscale JPEG in
// place and matches the reference reader invoked with is_color=false.
TEST_F(IOTest, TestDecodeDatumNativeGray) {
  string filename = EXAMPLES_SOURCE_DIR "images/cat_gray.jpg";
  Datum datum;
  EXPECT_TRUE(ReadFileToDatum(filename, &datum));
  EXPECT_TRUE(DecodeDatumNative(&datum));
  // The test expects a second call to find nothing left to decode.
  EXPECT_FALSE(DecodeDatumNative(&datum));
  Datum datum_ref;
  ReadImageToDatumReference(filename, 0, 0, 0, false, &datum_ref);
  EXPECT_EQ(datum.channels(), datum_ref.channels());
  EXPECT_EQ(datum.height(), datum_ref.height());
  EXPECT_EQ(datum.width(), datum_ref.width());
  // Fatal assertion: the loop below indexes both buffers with the same
  // index, so a size mismatch must stop the test instead of reading past
  // the end of the shorter string.
  ASSERT_EQ(datum.data().size(), datum_ref.data().size());

  const string& data = datum.data();
  const string& data_ref = datum_ref.data();
  // size_t index avoids a signed/unsigned comparison against size();
  // EXPECT_EQ reports both byte values when they differ.
  for (size_t i = 0; i < data.size(); ++i) {
    EXPECT_EQ(data[i], data_ref[i]);
  }
}

// Decodes an encoded grayscale JPEG with DecodeDatumToCVMatNative and checks
// that the result is single-channel with the expected image size.
TEST_F(IOTest, TestDecodeDatumToCVMatNativeGray) {
  const string gray_image_path = EXAMPLES_SOURCE_DIR "images/cat_gray.jpg";
  Datum encoded_datum;
  EXPECT_TRUE(ReadFileToDatum(gray_image_path, &encoded_datum));

  cv::Mat decoded = DecodeDatumToCVMatNative(encoded_datum);
  EXPECT_EQ(decoded.channels(), 1);
  EXPECT_EQ(decoded.rows, 360);
  EXPECT_EQ(decoded.cols, 480);
}

TEST_F(IOTest, TestDecodeDatumToCVMatContentNative) {
string filename = EXAMPLES_SOURCE_DIR "images/cat.jpg";
Datum datum;
EXPECT_TRUE(ReadImageToDatum(filename, 0, std::string("jpg"), &datum));
Expand Down
23 changes: 22 additions & 1 deletion src/caffe/util/io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,21 @@ cv::Mat DecodeDatumToCVMatNative(const Datum& datum) {
}
return cv_img;
}
// Decodes the encoded image bytes held in `datum` into a cv::Mat.
// `is_color` selects CV_LOAD_IMAGE_COLOR, otherwise CV_LOAD_IMAGE_GRAYSCALE.
// The datum must be encoded (CHECKed). On a decode failure an error is
// logged and the returned Mat has no data.
cv::Mat DecodeDatumToCVMat(const Datum& datum, bool is_color) {
  CHECK(datum.encoded()) << "Datum not encoded";

  // cv::imdecode wants a contiguous buffer, so copy the bytes out of the
  // Datum's string payload.
  const string& encoded_bytes = datum.data();
  std::vector<char> buffer(encoded_bytes.begin(), encoded_bytes.end());

  int read_mode;
  if (is_color) {
    read_mode = CV_LOAD_IMAGE_COLOR;
  } else {
    read_mode = CV_LOAD_IMAGE_GRAYSCALE;
  }

  cv::Mat decoded = cv::imdecode(buffer, read_mode);
  if (!decoded.data) {
    LOG(ERROR) << "Could not decode datum ";
  }
  return decoded;
}

// If Datum is encoded it will be decoded using DecodeDatumToCVMat and CVMatToDatum
// if height and width are set it will resize it
// If Datum is not encoded will do nothing
bool DecodeDatumNative(Datum* datum) {
if (datum->encoded()) {
Expand All @@ -180,6 +192,15 @@ bool DecodeDatumNative(Datum* datum) {
return false;
}
}
// Decodes an encoded Datum in place with the requested color mode.
// Returns false without touching the Datum when it is not encoded; otherwise
// decodes via DecodeDatumToCVMat, stores the result with CVMatToDatum and
// returns true.
bool DecodeDatum(Datum* datum, bool is_color) {
  if (!datum->encoded()) {
    return false;
  }
  const cv::Mat decoded = DecodeDatumToCVMat(*datum, is_color);
  CVMatToDatum(decoded, datum);
  return true;
}

void CVMatToDatum(const cv::Mat& cv_img, Datum* datum) {
CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";
Expand Down

0 comments on commit a2f7f47

Please sign in to comment.