Allow Transform of encoded datum.
authorSergio Guadarrama <sguada@google.com>
Wed, 8 Apr 2015 00:51:22 +0000 (17:51 -0700)
committerSergio Guadarrama <sguada@google.com>
Thu, 9 Apr 2015 00:19:49 +0000 (17:19 -0700)
Allow initialize transformed_blob from datum or transform params.
Allow force_color and force_gray as transform params.

src/caffe/data_transformer.cpp
src/caffe/proto/caffe.proto

index b0b98e4..454dabb 100644 (file)
@@ -125,10 +125,40 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(const Datum& datum,
                                        Blob<Dtype>* transformed_blob) {
+  // If datum is encoded, decoded and transform the cv::image.
+  if (datum.encoded()) {
+    CHECK(!param_.force_color() && !param_.force_gray())
+        << "cannot set both force_color and force_gray";
+    cv::Mat cv_img;
+    if (param_.force_color() || param_.force_gray()) {
+    // If force_color then decode in color otherwise decode in gray.
+      cv_img = DecodeDatumToCVMat(datum, param_.force_color());
+    } else {
+      cv_img = DecodeDatumToCVMatNative(datum);
+    }
+    // Transform the cv::image into blob.
+    return Transform(cv_img, transformed_blob);
+  } else {
+    if (param_.force_color() || param_.force_gray()) {
+      LOG(ERROR) << "force_color and force_gray only for encoded datum";
+    }
+  }
+
   const int datum_channels = datum.channels();
   const int datum_height = datum.height();
   const int datum_width = datum.width();
 
+  const int crop_size = param_.crop_size();
+
+  if (transformed_blob->count() == 0) {
+    // Initialize it.
+    if (crop_size) {
+      transformed_blob->Reshape(1, datum_channels, crop_size, crop_size);
+    } else {
+      transformed_blob->Reshape(1, datum_channels, datum_height, datum_width);
+    }
+  }
+  // Check dimensions.
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
   const int width = transformed_blob->width();
@@ -139,8 +169,6 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
   CHECK_LE(width, datum_width);
   CHECK_GE(num, 1);
 
-  const int crop_size = param_.crop_size();
-
   if (crop_size) {
     CHECK_EQ(crop_size, height);
     CHECK_EQ(crop_size, width);
@@ -200,6 +228,17 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
   const int img_height = cv_img.rows;
   const int img_width = cv_img.cols;
 
+  const int crop_size = param_.crop_size();
+
+  if (transformed_blob->count() == 0) {
+    // Initialize it.
+    if (crop_size) {
+      transformed_blob->Reshape(1, img_channels, crop_size, crop_size);
+    } else {
+      transformed_blob->Reshape(1, img_channels, img_height, img_width);
+    }
+  }
+
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
   const int width = transformed_blob->width();
@@ -212,7 +251,6 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
 
   CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";
 
-  const int crop_size = param_.crop_size();
   const Dtype scale = param_.scale();
   const bool do_mirror = param_.mirror() && Rand(2);
   const bool has_mean_file = param_.has_mean_file();
@@ -302,6 +340,19 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
   const int input_height = input_blob->height();
   const int input_width = input_blob->width();
 
+  const int crop_size = param_.crop_size();
+
+  if (transformed_blob->count() == 0) {
+    // Initialize it.
+    if (crop_size) {
+      transformed_blob->Reshape(input_num, input_channels,
+                                crop_size, crop_size);
+    } else {
+      transformed_blob->Reshape(input_num, input_channels,
+                                input_height, input_width);
+    }
+  }
+
   const int num = transformed_blob->num();
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
@@ -313,7 +364,7 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
   CHECK_GE(input_height, height);
   CHECK_GE(input_width, width);
 
-  const int crop_size = param_.crop_size();
+
   const Dtype scale = param_.scale();
   const bool do_mirror = param_.mirror() && Rand(2);
   const bool has_mean_file = param_.has_mean_file();
index 5b21cf2..d66167e 100644 (file)
@@ -351,6 +351,10 @@ message TransformationParameter {
   // or can be repeated the same number of times as channels
   // (would subtract them from the corresponding channel)
   repeated float mean_value = 5;
+  // Force the decoded image to have 3 color channels.
+  optional bool force_color = 6 [default = false];
+  // Force the decoded image to have 1 color channels.
+  optional bool force_gray = 7 [default = false];
 }
 
 // Message that stores parameters shared by loss layers