Refactor common code
authorSergio <sguada@gmail.com>
Fri, 26 Sep 2014 02:58:52 +0000 (19:58 -0700)
committerSergio <sguada@gmail.com>
Fri, 3 Oct 2014 18:45:59 +0000 (11:45 -0700)
Make lint happy

Conflicts:
src/caffe/data_transformer.cpp

include/caffe/data_transformer.hpp
src/caffe/data_transformer.cpp
src/caffe/layers/base_data_layer.cpp

index f2cbbd0..4a2afda 100644 (file)
@@ -44,8 +44,6 @@ class DataTransformer {
    *    within the blob's data.
    */
 
-  void Transform(const Datum& datum, Dtype* transformed_data);
-
   void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
 
   void Transform(const vector<Datum> & datum_vector,
@@ -66,6 +64,7 @@ class DataTransformer {
    */
   virtual int Rand(int n);
 
+  void Transform(const Datum& datum, Dtype* transformed_data);
   // Tranformation parameters
   TransformationParameter param_;
 
index af1db34..cdb6e10 100644 (file)
@@ -31,7 +31,6 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
   const int datum_channels = datum.channels();
   const int datum_height = datum.height();
   const int datum_width = datum.width();
-  const int size = datum.channels() * datum.height() * datum.width();
 
   const int crop_size = param_.crop_size();
   const Dtype scale = param_.scale();
@@ -51,10 +50,14 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
     mean = data_mean_.mutable_cpu_data();
   }
 
+  int height = datum_height;
+  int width = datum_width;
+
   int h_off = 0;
   int w_off = 0;
-  Dtype datum_element;
   if (crop_size) {
+    height = crop_size;
+    width = crop_size;
     // We only do random crop when we do training.
     if (phase_ == Caffe::TRAIN) {
       h_off = Rand(height - crop_size + 1);
@@ -63,49 +66,31 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
       h_off = (datum_height - crop_size) / 2;
       w_off = (datum_width - crop_size) / 2;
     }
+  }
 
-    int top_index, data_index;
-    for (int c = 0; c < datum_channels; ++c) {
-      int top_index_c = c * crop_size;
-      int data_index_c = c * datum_height + h_off;
-      for (int h = 0; h < crop_size; ++h) {
-        int top_index_h = (top_index_c + h) * crop_size;
-        int data_index_h = (data_index_c + h) * datum_width + w_off;
-        for (int w = 0; w < crop_size; ++w) {
-          data_index = data_index_h + w;
-          if (do_mirror) {
-            top_index = top_index_h + (crop_size - 1 - w);
-          } else {
-            top_index = top_index_h + w;
-          }
-          if (has_unit8) {
-            datum_element =
-              static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
-          } else {
-            datum_element = datum.float_data(data_index);
-          }
-          if (has_mean_file) {
-            transformed_data[top_index] =
-              (datum_element - mean[data_index]) * scale;
-          } else {
-            transformed_data[top_index] = datum_element * scale;
-          }
+  Dtype datum_element;
+  int top_index, data_index;
+  for (int c = 0; c < datum_channels; ++c) {
+    for (int h = 0; h < height; ++h) {
+      for (int w = 0; w < width; ++w) {
+        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
+        if (do_mirror) {
+          top_index = (c * height + h) * width + (width - 1 - w);
+        } else {
+          top_index = (c * height + h) * width + w;
         }
-      }
-    }
-  } else {
-    for (int j = 0; j < size; ++j) {
-      if (has_unit8) {
+        if (has_unit8) {
           datum_element =
-            static_cast<Dtype>(static_cast<uint8_t>(data[j]));
-      } else {
-        datum_element = datum.float_data(j);
-      }
-      if (has_mean_file) {
-        transformed_data[j] =
-          (datum_element - mean[j]) * scale;
-      } else {
-        transformed_data[j] = datum_element * scale;
+            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
+        } else {
+          datum_element = datum.float_data(data_index);
+        }
+        if (has_mean_file) {
+          transformed_data[top_index] =
+            (datum_element - mean[data_index]) * scale;
+        } else {
+          transformed_data[top_index] = datum_element * scale;
+        }
       }
     }
   }
@@ -114,7 +99,6 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(const Datum& datum,
                                        Blob<Dtype>* transformed_blob) {
-  const string& data = datum.data();
   const int datum_channels = datum.channels();
   const int datum_height = datum.height();
   const int datum_width = datum.width();
@@ -122,78 +106,25 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
   const int width = transformed_blob->width();
+  const int num = transformed_blob->num();
 
-  CHECK_EQ(datum_channels, channels);
-  CHECK_GE(datum_height, height);
-  CHECK_GE(datum_width, width);
-
-  CHECK_EQ(transformed_blob->num(), 1) <<
-    "transformed_blob should have num() = 1";
+  CHECK_EQ(channels, datum_channels);
+  CHECK_LE(height, datum_height);
+  CHECK_LE(width, datum_width);
+  CHECK_GE(num, 1);
 
   const int crop_size = param_.crop_size();
-  const Dtype scale = param_.scale();
-  const bool do_mirror = param_.mirror() && Rand() % 2;
-  const bool has_mean_file = param_.has_mean_file();
-  const bool has_unit8 = data.size() > 0;
 
-  int h_off = 0;
-  int w_off = 0;
   if (crop_size) {
     CHECK_EQ(crop_size, height);
     CHECK_EQ(crop_size, width);
-    // We only do random crop when we do training.
-    if (phase_ == Caffe::TRAIN) {
-      h_off = Rand() % (datum_height - crop_size);
-      w_off = Rand() % (datum_width - crop_size);
-    } else {
-      h_off = (datum_height - crop_size) / 2;
-      w_off = (datum_width - crop_size) / 2;
-    }
   } else {
     CHECK_EQ(datum_height, height);
     CHECK_EQ(datum_width, width);
   }
 
   Dtype* transformed_data = transformed_blob->mutable_cpu_data();
-
-  Dtype* mean = NULL;
-  if (has_mean_file) {
-    CHECK_EQ(datum_channels, data_mean_.channels());
-    CHECK_EQ(datum_height, data_mean_.height());
-    CHECK_EQ(datum_width, data_mean_.width());
-    mean = data_mean_.mutable_cpu_data();
-  }
-
-  Dtype datum_element;
-  int top_index, data_index;
-  for (int c = 0; c < channels; ++c) {
-    int top_index_c = c * height;
-    int data_index_c = c * datum_height + h_off;
-    for (int h = 0; h < height; ++h) {
-      int top_index_h = (top_index_c + h) * width;
-      int data_index_h = (data_index_c + h) * datum_width + w_off;
-      for (int w = 0; w < width; ++w) {
-        data_index = data_index_h + w;
-        if (do_mirror) {
-          top_index = top_index_h + (width - 1 - w);
-        } else {
-          top_index = top_index_h + w;
-        }
-        if (has_unit8) {
-          datum_element =
-            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
-        } else {
-          datum_element = datum.float_data(data_index);
-        }
-        if (has_mean_file) {
-          transformed_data[top_index] =
-            (datum_element - mean[data_index]) * scale;
-        } else {
-          transformed_data[top_index] = datum_element * scale;
-        }
-      }
-    }
-  }
+  Transform(datum, transformed_data);
 }
 
 template<typename Dtype>
@@ -226,13 +157,12 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
   const int width = transformed_blob->width();
+  const int num = transformed_blob->num();
 
-  CHECK_EQ(img_channels, channels);
-  CHECK_GE(img_height, height);
-  CHECK_GE(img_width, width);
-
-  CHECK_EQ(transformed_blob->num(), 1) <<
-    "transformed_blob should have num() = 1";
+  CHECK_EQ(channels, img_channels);
+  CHECK_LE(height, img_height);
+  CHECK_LE(width, img_width);
+  CHECK_GE(num, 1);
 
   const int crop_size = param_.crop_size();
   const Dtype scale = param_.scale();
@@ -293,7 +223,6 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
   }
 }
 
-
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                        Blob<Dtype>* transformed_blob) {
index d7d4752..5ce52f0 100644 (file)
@@ -60,7 +60,8 @@ void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   // First, join the thread
   if (this->timer_forward_.has_run_at_least_once()) {
-    DLOG(INFO) << "Proccessing: " << this->timer_forward_.MilliSeconds() << "ms.";
+    DLOG(INFO) << "Proccessing: " <<
+                this->timer_forward_.MilliSeconds() << "ms.";
   }
   JoinPrefetchThread();
   DLOG(INFO) << "Thread joined";