Added InferBlobShape to data_transformer.
authorSergio Guadarrama <sguada@google.com>
Wed, 8 Apr 2015 19:08:56 +0000 (12:08 -0700)
committerSergio Guadarrama <sguada@google.com>
Thu, 9 Apr 2015 00:32:03 +0000 (17:32 -0700)
include/caffe/data_transformer.hpp
src/caffe/data_transformer.cpp

index 8803566..0ad68c8 100644 (file)
@@ -62,6 +62,7 @@ class DataTransformer {
    */
   void Transform(const vector<cv::Mat> & mat_vector,
                 Blob<Dtype>* transformed_blob);
+
   /**
    * @brief Applies the transformation defined in the data layer's
    * transform_param block to a cv::Mat
@@ -87,6 +88,41 @@ class DataTransformer {
    */
   void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);
 
+  /**
+   * @brief Infers the shape of transformed_blob will have when
+   *    the transformation is applied to the data.
+   *
+   * @param datum
+   *    Datum containing the data to be transformed.
+   */
+  vector<int> InferBlobShape(const Datum& datum);
+  /**
+   * @brief Infers the shape of transformed_blob will have when
+   *    the transformation is applied to the data.
+   *    It uses the first element to infer the shape of the blob.
+   *
+   * @param datum_vector
+   *    A vector of Datum containing the data to be transformed.
+   */
+  vector<int> InferBlobShape(const vector<Datum> & datum_vector);
+  /**
+   * @brief Infers the shape of transformed_blob will have when
+   *    the transformation is applied to the data.
+   *    It uses the first element to infer the shape of the blob.
+   *
+   * @param mat_vector
+   *    A vector of Mat containing the data to be transformed.
+   */
+  vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);
+  /**
+   * @brief Infers the shape of transformed_blob will have when
+   *    the transformation is applied to the data.
+   *
+   * @param cv_img
+   *    cv::Mat containing the data to be transformed.
+   */
+  vector<int> InferBlobShape(const cv::Mat& cv_img);
+
  protected:
    /**
    * @brief Generates a random integer from Uniform({0, 1, ..., n-1}).
index 454dabb..6f75bdb 100644 (file)
@@ -144,20 +144,11 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
     }
   }
 
+  const int crop_size = param_.crop_size();
   const int datum_channels = datum.channels();
   const int datum_height = datum.height();
   const int datum_width = datum.width();
 
-  const int crop_size = param_.crop_size();
-
-  if (transformed_blob->count() == 0) {
-    // Initialize it.
-    if (crop_size) {
-      transformed_blob->Reshape(1, datum_channels, crop_size, crop_size);
-    } else {
-      transformed_blob->Reshape(1, datum_channels, datum_height, datum_width);
-    }
-  }
   // Check dimensions.
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
@@ -224,21 +215,12 @@ void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                        Blob<Dtype>* transformed_blob) {
+  const int crop_size = param_.crop_size();
   const int img_channels = cv_img.channels();
   const int img_height = cv_img.rows;
   const int img_width = cv_img.cols;
 
-  const int crop_size = param_.crop_size();
-
-  if (transformed_blob->count() == 0) {
-    // Initialize it.
-    if (crop_size) {
-      transformed_blob->Reshape(1, img_channels, crop_size, crop_size);
-    } else {
-      transformed_blob->Reshape(1, img_channels, img_height, img_width);
-    }
-  }
-
+  // Check dimensions.
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
   const int width = transformed_blob->width();
@@ -335,15 +317,14 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                        Blob<Dtype>* transformed_blob) {
+  const int crop_size = param_.crop_size();
   const int input_num = input_blob->num();
   const int input_channels = input_blob->channels();
   const int input_height = input_blob->height();
   const int input_width = input_blob->width();
 
-  const int crop_size = param_.crop_size();
-
   if (transformed_blob->count() == 0) {
-    // Initialize it.
+    // Initialize transformed_blob with the right shape.
     if (crop_size) {
       transformed_blob->Reshape(input_num, input_channels,
                                 crop_size, crop_size);
@@ -446,6 +427,82 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
   }
 }
 
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
+  if (datum.encoded()) {
+    CHECK(!param_.force_color() && !param_.force_gray())
+        << "cannot set both force_color and force_gray";
+    cv::Mat cv_img;
+    if (param_.force_color() || param_.force_gray()) {
+    // If force_color then decode in color otherwise decode in gray.
+      cv_img = DecodeDatumToCVMat(datum, param_.force_color());
+    } else {
+      cv_img = DecodeDatumToCVMatNative(datum);
+    }
+    // InferBlobShape using the cv::image.
+    return InferBlobShape(cv_img);
+  }
+
+  const int crop_size = param_.crop_size();
+  const int datum_channels = datum.channels();
+  const int datum_height = datum.height();
+  const int datum_width = datum.width();
+  // Check dimensions.
+  CHECK_GT(datum_channels, 0);
+  CHECK_GE(datum_height, crop_size);
+  CHECK_GE(datum_width, crop_size);
+  // Build BlobShape.
+  vector<int> shape(4);
+  shape[0] = 1;
+  shape[1] = datum_channels;
+  shape[2] = (crop_size)? crop_size: datum_height;
+  shape[3] = (crop_size)? crop_size: datum_width;
+  return shape;
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::InferBlobShape(
+    const vector<Datum> & datum_vector) {
+  const int num = datum_vector.size();
+  CHECK_GT(num, 0) << "There is no datum to in the vector";
+  // Use first datum in the vector to InferBlobShape.
+  vector<int> shape = InferBlobShape(datum_vector[0]);
+  // Adjust num to the size of the vector.
+  shape[0] = num;
+  return shape;
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
+  const int crop_size = param_.crop_size();
+  const int img_channels = cv_img.channels();
+  const int img_height = cv_img.rows;
+  const int img_width = cv_img.cols;
+  // Check dimensions.
+  CHECK_GT(img_channels, 0);
+  CHECK_GE(img_height, crop_size);
+  CHECK_GE(img_width, crop_size);
+  // Build BlobShape.
+  vector<int> shape(4);
+  shape[0] = 1;
+  shape[1] = img_channels;
+  shape[2] = (crop_size)? crop_size: img_height;
+  shape[3] = (crop_size)? crop_size: img_width;
+  return shape;
+}
+
+template<typename Dtype>
+vector<int> DataTransformer<Dtype>::InferBlobShape(
+    const vector<cv::Mat> & mat_vector) {
+  const int num = mat_vector.size();
+  CHECK_GT(num, 0) << "There is no cv_img to in the vector";
+  // Use first cv_img in the vector to InferBlobShape.
+  vector<int> shape = InferBlobShape(mat_vector[0]);
+  // Adjust num to the size of the vector.
+  shape[0] = num;
+  return shape;
+}
+
 template <typename Dtype>
 void DataTransformer<Dtype>::InitRand() {
   const bool needs_rand = param_.mirror() ||