From b963008a6591600e60ed6746d208e82e107f6a89 Mon Sep 17 00:00:00 2001 From: Sergio Guadarrama Date: Tue, 7 Apr 2015 17:51:22 -0700 Subject: [PATCH] Allow Transform of encoded datum. Allow initialize transformed_blob from datum or transform params. Allow force_color and force_gray as transform params. --- src/caffe/data_transformer.cpp | 59 +++++++++++++++++++++++++++++++++++++++--- src/caffe/proto/caffe.proto | 4 +++ 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index b0b98e4..454dabb 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -125,10 +125,40 @@ void DataTransformer::Transform(const Datum& datum, template void DataTransformer::Transform(const Datum& datum, Blob* transformed_blob) { + // If datum is encoded, decode it and transform the resulting cv::Mat. + if (datum.encoded()) { + CHECK(!(param_.force_color() && param_.force_gray())) + << "cannot set both force_color and force_gray"; + cv::Mat cv_img; + if (param_.force_color() || param_.force_gray()) { + // If force_color then decode in color otherwise decode in gray. + cv_img = DecodeDatumToCVMat(datum, param_.force_color()); + } else { + cv_img = DecodeDatumToCVMatNative(datum); + } + // Transform the decoded cv::Mat into the blob. + return Transform(cv_img, transformed_blob); + } else { + if (param_.force_color() || param_.force_gray()) { + LOG(ERROR) << "force_color and force_gray only for encoded datum"; + } + } + const int datum_channels = datum.channels(); const int datum_height = datum.height(); const int datum_width = datum.width(); + const int crop_size = param_.crop_size(); + + if (transformed_blob->count() == 0) { + // Initialize it. + if (crop_size) { + transformed_blob->Reshape(1, datum_channels, crop_size, crop_size); + } else { + transformed_blob->Reshape(1, datum_channels, datum_height, datum_width); + } + } + // Check dimensions. 
const int channels = transformed_blob->channels(); const int height = transformed_blob->height(); const int width = transformed_blob->width(); @@ -139,8 +169,6 @@ void DataTransformer::Transform(const Datum& datum, CHECK_LE(width, datum_width); CHECK_GE(num, 1); - const int crop_size = param_.crop_size(); - if (crop_size) { CHECK_EQ(crop_size, height); CHECK_EQ(crop_size, width); @@ -200,6 +228,17 @@ void DataTransformer::Transform(const cv::Mat& cv_img, const int img_height = cv_img.rows; const int img_width = cv_img.cols; + const int crop_size = param_.crop_size(); + + if (transformed_blob->count() == 0) { + // Initialize it. + if (crop_size) { + transformed_blob->Reshape(1, img_channels, crop_size, crop_size); + } else { + transformed_blob->Reshape(1, img_channels, img_height, img_width); + } + } + const int channels = transformed_blob->channels(); const int height = transformed_blob->height(); const int width = transformed_blob->width(); @@ -212,7 +251,6 @@ void DataTransformer::Transform(const cv::Mat& cv_img, CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte"; - const int crop_size = param_.crop_size(); const Dtype scale = param_.scale(); const bool do_mirror = param_.mirror() && Rand(2); const bool has_mean_file = param_.has_mean_file(); @@ -302,6 +340,19 @@ void DataTransformer::Transform(Blob* input_blob, const int input_height = input_blob->height(); const int input_width = input_blob->width(); + const int crop_size = param_.crop_size(); + + if (transformed_blob->count() == 0) { + // Initialize it. 
+ if (crop_size) { + transformed_blob->Reshape(input_num, input_channels, + crop_size, crop_size); + } else { + transformed_blob->Reshape(input_num, input_channels, + input_height, input_width); + } + } + const int num = transformed_blob->num(); const int channels = transformed_blob->channels(); const int height = transformed_blob->height(); @@ -313,7 +364,7 @@ void DataTransformer::Transform(Blob* input_blob, CHECK_GE(input_height, height); CHECK_GE(input_width, width); - const int crop_size = param_.crop_size(); + const Dtype scale = param_.scale(); const bool do_mirror = param_.mirror() && Rand(2); const bool has_mean_file = param_.has_mean_file(); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 5b21cf2..d66167e 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -351,6 +351,10 @@ message TransformationParameter { // or can be repeated the same number of times as channels // (would subtract them from the corresponding channel) repeated float mean_value = 5; + // Force the decoded image to have 3 color channels. + optional bool force_color = 6 [default = false]; + // Force the decoded image to have 1 color channel. + optional bool force_gray = 7 [default = false]; } // Message that stores parameters shared by loss layers -- 2.7.4