From d1ccbe325715646e21631829975412fc9b32a344 Mon Sep 17 00:00:00 2001
From: Sergio
Date: Thu, 25 Sep 2014 19:58:52 -0700
Subject: [PATCH] Refactor common code

Make lint happy

Conflicts:
	src/caffe/data_transformer.cpp
---
 include/caffe/data_transformer.hpp   |   3 +-
 src/caffe/data_transformer.cpp       | 147 +++++++++--------------------------
 src/caffe/layers/base_data_layer.cpp |   3 +-
 3 files changed, 41 insertions(+), 112 deletions(-)

diff --git a/include/caffe/data_transformer.hpp b/include/caffe/data_transformer.hpp
index f2cbbd0..4a2afda 100644
--- a/include/caffe/data_transformer.hpp
+++ b/include/caffe/data_transformer.hpp
@@ -44,8 +44,6 @@ class DataTransformer {
    * within the blob's data.
    */
-  void Transform(const Datum& datum, Dtype* transformed_data);
-
   void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
 
   void Transform(const vector<Datum> & datum_vector,
@@ -66,6 +64,7 @@ class DataTransformer {
    */
   virtual int Rand(int n);
 
+  void Transform(const Datum& datum, Dtype* transformed_data);
   // Tranformation parameters
   TransformationParameter param_;
 
diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp
index af1db34..cdb6e10 100644
--- a/src/caffe/data_transformer.cpp
+++ b/src/caffe/data_transformer.cpp
@@ -31,7 +31,6 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
   const int datum_channels = datum.channels();
   const int datum_height = datum.height();
   const int datum_width = datum.width();
-  const int size = datum.channels() * datum.height() * datum.width();
 
   const int crop_size = param_.crop_size();
   const Dtype scale = param_.scale();
@@ -51,10 +50,14 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
     mean = data_mean_.mutable_cpu_data();
   }
 
+  int height = datum_height;
+  int width = datum_width;
+
   int h_off = 0;
   int w_off = 0;
-  Dtype datum_element;
   if (crop_size) {
+    height = crop_size;
+    width = crop_size;
     // We only do random crop when we do training.
     if (phase_ == Caffe::TRAIN) {
       h_off = Rand(height - crop_size + 1);
@@ -63,49 +66,31 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
       h_off = (datum_height - crop_size) / 2;
       w_off = (datum_width - crop_size) / 2;
     }
+  }
 
-    int top_index, data_index;
-    for (int c = 0; c < datum_channels; ++c) {
-      int top_index_c = c * crop_size;
-      int data_index_c = c * datum_height + h_off;
-      for (int h = 0; h < crop_size; ++h) {
-        int top_index_h = (top_index_c + h) * crop_size;
-        int data_index_h = (data_index_c + h) * datum_width + w_off;
-        for (int w = 0; w < crop_size; ++w) {
-          data_index = data_index_h + w;
-          if (do_mirror) {
-            top_index = top_index_h + (crop_size - 1 - w);
-          } else {
-            top_index = top_index_h + w;
-          }
-          if (has_uint8) {
-            datum_element =
-              static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
-          } else {
-            datum_element = datum.float_data(data_index);
-          }
-          if (has_mean_file) {
-            transformed_data[top_index] =
-              (datum_element - mean[data_index]) * scale;
-          } else {
-            transformed_data[top_index] = datum_element * scale;
-          }
+  Dtype datum_element;
+  int top_index, data_index;
+  for (int c = 0; c < datum_channels; ++c) {
+    for (int h = 0; h < height; ++h) {
+      for (int w = 0; w < width; ++w) {
+        data_index = (c * datum_height + h_off + h) * datum_width + w_off + w;
+        if (do_mirror) {
+          top_index = (c * height + h) * width + (width - 1 - w);
+        } else {
+          top_index = (c * height + h) * width + w;
         }
-      }
-    }
-  } else {
-    for (int j = 0; j < size; ++j) {
-      if (has_uint8) {
+        if (has_uint8) {
           datum_element =
-          static_cast<Dtype>(static_cast<uint8_t>(data[j]));
-      } else {
-        datum_element = datum.float_data(j);
-      }
-      if (has_mean_file) {
-        transformed_data[j] =
-          (datum_element - mean[j]) * scale;
-      } else {
-        transformed_data[j] = datum_element * scale;
+          static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
+        } else {
+          datum_element = datum.float_data(data_index);
+        }
+        if (has_mean_file) {
+          transformed_data[top_index] =
+            (datum_element - mean[data_index]) * scale;
+        } else {
+          transformed_data[top_index] = datum_element * scale;
+        }
       }
     }
   }
@@ -114,7 +99,6 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(const Datum& datum,
                                        Blob<Dtype>* transformed_blob) {
-  const string& data = datum.data();
   const int datum_channels = datum.channels();
   const int datum_height = datum.height();
   const int datum_width = datum.width();
@@ -122,78 +106,25 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
   const int width = transformed_blob->width();
+  const int num = transformed_blob->num();
 
-  CHECK_EQ(datum_channels, channels);
-  CHECK_GE(datum_height, height);
-  CHECK_GE(datum_width, width);
-
-  CHECK_EQ(transformed_blob->num(), 1) <<
-    "transformed_blob should have num() = 1";
+  CHECK_EQ(channels, datum_channels);
+  CHECK_LE(height, datum_height);
+  CHECK_LE(width, datum_width);
+  CHECK_GE(num, 1);
 
   const int crop_size = param_.crop_size();
-  const Dtype scale = param_.scale();
-  const bool do_mirror = param_.mirror() && Rand() % 2;
-  const bool has_mean_file = param_.has_mean_file();
-  const bool has_uint8 = data.size() > 0;
 
-  int h_off = 0;
-  int w_off = 0;
   if (crop_size) {
     CHECK_EQ(crop_size, height);
     CHECK_EQ(crop_size, width);
-    // We only do random crop when we do training.
-    if (phase_ == Caffe::TRAIN) {
-      h_off = Rand() % (datum_height - crop_size);
-      w_off = Rand() % (datum_width - crop_size);
-    } else {
-      h_off = (datum_height - crop_size) / 2;
-      w_off = (datum_width - crop_size) / 2;
-    }
   } else {
     CHECK_EQ(datum_height, height);
     CHECK_EQ(datum_width, width);
   }
 
   Dtype* transformed_data = transformed_blob->mutable_cpu_data();
-
-  Dtype* mean = NULL;
-  if (has_mean_file) {
-    CHECK_EQ(datum_channels, data_mean_.channels());
-    CHECK_EQ(datum_height, data_mean_.height());
-    CHECK_EQ(datum_width, data_mean_.width());
-    mean = data_mean_.mutable_cpu_data();
-  }
-
-  Dtype datum_element;
-  int top_index, data_index;
-  for (int c = 0; c < channels; ++c) {
-    int top_index_c = c * height;
-    int data_index_c = c * datum_height + h_off;
-    for (int h = 0; h < height; ++h) {
-      int top_index_h = (top_index_c + h) * width;
-      int data_index_h = (data_index_c + h) * datum_width + w_off;
-      for (int w = 0; w < width; ++w) {
-        data_index = data_index_h + w;
-        if (do_mirror) {
-          top_index = top_index_h + (width - 1 - w);
-        } else {
-          top_index = top_index_h + w;
-        }
-        if (has_uint8) {
-          datum_element =
-            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
-        } else {
-          datum_element = datum.float_data(data_index);
-        }
-        if (has_mean_file) {
-          transformed_data[top_index] =
-            (datum_element - mean[data_index]) * scale;
-        } else {
-          transformed_data[top_index] = datum_element * scale;
-        }
-      }
-    }
-  }
+  Transform(datum, transformed_data);
 }
 
 template<typename Dtype>
@@ -226,13 +157,12 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
   const int channels = transformed_blob->channels();
   const int height = transformed_blob->height();
   const int width = transformed_blob->width();
+  const int num = transformed_blob->num();
 
-  CHECK_EQ(img_channels, channels);
-  CHECK_GE(img_height, height);
-  CHECK_GE(img_width, width);
-
-  CHECK_EQ(transformed_blob->num(), 1) <<
-    "transformed_blob should have num() = 1";
+  CHECK_EQ(channels, img_channels);
+  CHECK_LE(height, img_height);
+  CHECK_LE(width, img_width);
+  CHECK_GE(num, 1);
 
   const int crop_size = param_.crop_size();
   const Dtype scale = param_.scale();
@@ -293,7 +223,6 @@ void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
     }
   }
 }
-
 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                        Blob<Dtype>* transformed_blob) {
diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp
index d7d4752..5ce52f0 100644
--- a/src/caffe/layers/base_data_layer.cpp
+++ b/src/caffe/layers/base_data_layer.cpp
@@ -60,7 +60,8 @@ void BasePrefetchingDataLayer<Dtype>::Forward_cpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   // First, join the thread
   if (this->timer_forward_.has_run_at_least_once()) {
-    DLOG(INFO) << "Proccessing: " << this->timer_forward_.MilliSeconds() << "ms.";
+    DLOG(INFO) << "Proccessing: " <<
+      this->timer_forward_.MilliSeconds() << "ms.";
   }
   JoinPrefetchThread();
   DLOG(INFO) << "Thread joined";
-- 
2.7.4
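
Reviewer note (illustration only, not part of the patch): the refactor collapses the two per-pixel copy paths into one loop that handles crop, mirror, mean subtraction and scaling, and the Blob<Dtype> overload now only checks shapes and calls Transform(datum, transformed_data) on mutable_cpu_data(). Below is a minimal standalone sketch of that shared kernel, using hypothetical names rather than Caffe's actual DataTransformer API, and a fixed center crop instead of the patch's Rand()-based offsets.

// Sketch of the shared transform kernel: one loop covers crop, mirror,
// mean subtraction and scaling, writing into a caller-provided array.
// Names and signature are illustrative only.
#include <cstdint>
#include <vector>

void transform_to_array(const std::vector<uint8_t>& data,   // C*H*W input
                        int channels, int datum_height, int datum_width,
                        int crop_size, bool mirror, float scale,
                        const std::vector<float>& mean,      // empty if unused
                        float* out) {
  int height = datum_height;
  int width = datum_width;
  int h_off = 0, w_off = 0;
  if (crop_size) {
    height = crop_size;
    width = crop_size;
    h_off = (datum_height - crop_size) / 2;  // center crop for this sketch
    w_off = (datum_width - crop_size) / 2;
  }
  for (int c = 0; c < channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Read from the (possibly cropped) source position...
        int data_index = (c * datum_height + h_off + h) * datum_width
            + w_off + w;
        // ...and write to the (possibly mirrored) destination position.
        int top_index = mirror ? (c * height + h) * width + (width - 1 - w)
                               : (c * height + h) * width + w;
        float v = static_cast<float>(data[data_index]);
        if (!mean.empty()) v -= mean[data_index];
        out[top_index] = v * scale;
      }
    }
  }
}

Keeping the index math in a single place is what lets the cropped and uncropped cases share code: with crop_size == 0 the offsets stay 0 and height/width fall back to the datum dimensions, which is exactly the duplication this patch removes from the two Transform overloads.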