From a4460972333507aec0c42db7f7cbaebcf16edcdf Mon Sep 17 00:00:00 2001
From: TANGUY Arnaud
Date: Thu, 21 Aug 2014 15:50:39 +0200
Subject: [PATCH] Refactor ImageDataLayer to use DataTransformer

---
 include/caffe/data_layers.hpp         |  6 ++-
 src/caffe/data_transformer.cpp        |  1 -
 src/caffe/layers/image_data_layer.cpp | 97 ++++-----------------------------
 src/caffe/proto/caffe.proto           | 25 ++++-----
 src/caffe/test/test_upgrade_proto.cpp | 10 ++--
 src/caffe/util/upgrade_proto.cpp      | 16 +++---
 6 files changed, 38 insertions(+), 117 deletions(-)

diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp
index 06508ea..bc29429 100644
--- a/include/caffe/data_layers.hpp
+++ b/include/caffe/data_layers.hpp
@@ -178,7 +178,8 @@ template <typename Dtype>
 class ImageDataLayer : public Layer<Dtype>, public InternalThread {
  public:
   explicit ImageDataLayer(const LayerParameter& param)
-      : Layer<Dtype>(param) {}
+      : Layer<Dtype>(param),
+        data_transformer_(param.image_data_param().transform_param()) {}
   virtual ~ImageDataLayer();
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
@@ -203,10 +204,11 @@ class ImageDataLayer : public Layer<Dtype>, public InternalThread {
 
   virtual void CreatePrefetchThread();
   virtual void JoinPrefetchThread();
-  virtual unsigned int PrefetchRand();
   virtual void InternalThreadEntry();
 
+  DataTransformer<Dtype> data_transformer_;
   shared_ptr<Caffe::RNG> prefetch_rng_;
+
   vector<std::pair<std::string, int> > lines_;
   int lines_id_;
   int datum_channels_;
diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp
index 609c06d..bdde028 100644
--- a/src/caffe/data_transformer.cpp
+++ b/src/caffe/data_transformer.cpp
@@ -11,7 +11,6 @@ void DataTransformer<Dtype>::Transform(const int batch_item_id,
                                        const Datum& datum,
                                        const Dtype* mean,
                                        Dtype* transformed_data) {
-
   const string& data = datum.data();
   const int channels = datum.channels();
   const int height = datum.height();
diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp
index c72bf9c..85ee988 100644
--- a/src/caffe/layers/image_data_layer.cpp
+++ b/src/caffe/layers/image_data_layer.cpp
@@ -20,22 +20,11 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
   Dtype* top_data = prefetch_data_.mutable_cpu_data();
   Dtype* top_label = prefetch_label_.mutable_cpu_data();
   ImageDataParameter image_data_param = this->layer_param_.image_data_param();
-  const Dtype scale = image_data_param.scale();
   const int batch_size = image_data_param.batch_size();
-  const int crop_size = image_data_param.crop_size();
-  const bool mirror = image_data_param.mirror();
   const int new_height = image_data_param.new_height();
   const int new_width = image_data_param.new_width();
 
-  if (mirror && crop_size == 0) {
-    LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
-               << "set at the same time.";
-  }
   // datum scales
-  const int channels = datum_channels_;
-  const int height = datum_height_;
-  const int width = datum_width_;
-  const int size = datum_size_;
   const int lines_size = lines_.size();
   const Dtype* mean = data_mean_.cpu_data();
   for (int item_id = 0; item_id < batch_size; ++item_id) {
@@ -46,62 +35,9 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
         new_height, new_width, &datum)) {
       continue;
     }
-    const string& data = datum.data();
-    if (crop_size) {
-      CHECK(data.size()) << "Image cropping only support uint8 data";
-      int h_off, w_off;
-      // We only do random crop when we do training.
-      if (phase_ == Caffe::TRAIN) {
-        h_off = PrefetchRand() % (height - crop_size);
-        w_off = PrefetchRand() % (width - crop_size);
-      } else {
-        h_off = (height - crop_size) / 2;
-        w_off = (width - crop_size) / 2;
-      }
-      if (mirror && PrefetchRand() % 2) {
-        // Copy mirrored version
-        for (int c = 0; c < channels; ++c) {
-          for (int h = 0; h < crop_size; ++h) {
-            for (int w = 0; w < crop_size; ++w) {
-              int top_index = ((item_id * channels + c) * crop_size + h)
-                  * crop_size + (crop_size - 1 - w);
-              int data_index = (c * height + h + h_off) * width + w + w_off;
-              Dtype datum_element =
-                  static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
-              top_data[top_index] = (datum_element - mean[data_index]) * scale;
-            }
-          }
-        }
-      } else {
-        // Normal copy
-        for (int c = 0; c < channels; ++c) {
-          for (int h = 0; h < crop_size; ++h) {
-            for (int w = 0; w < crop_size; ++w) {
-              int top_index = ((item_id * channels + c) * crop_size + h)
-                  * crop_size + w;
-              int data_index = (c * height + h + h_off) * width + w + w_off;
-              Dtype datum_element =
-                  static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
-              top_data[top_index] = (datum_element - mean[data_index]) * scale;
-            }
-          }
-        }
-      }
-    } else {
-      // Just copy the whole data
-      if (data.size()) {
-        for (int j = 0; j < size; ++j) {
-          Dtype datum_element =
-              static_cast<Dtype>(static_cast<uint8_t>(data[j]));
-          top_data[item_id * size + j] = (datum_element - mean[j]) * scale;
-        }
-      } else {
-        for (int j = 0; j < size; ++j) {
-          top_data[item_id * size + j] =
-              (datum.float_data(j) - mean[j]) * scale;
-        }
-      }
-    }
+
+    // Apply transformations (mirror, crop...) to the data
+    data_transformer_.Transform(item_id, datum, mean, top_data);
 
     top_label[item_id] = datum.label();
     // go to the next iter
@@ -163,9 +99,11 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
                          new_height, new_width, &datum));
   // image
-  const int crop_size = this->layer_param_.image_data_param().crop_size();
+  const int crop_size = this->layer_param_.image_data_param()
+      .transform_param().crop_size();
   const int batch_size = this->layer_param_.image_data_param().batch_size();
-  const string& mean_file = this->layer_param_.image_data_param().mean_file();
+  const string& mean_file = this->layer_param_.image_data_param()
+      .transform_param().mean_file();
   if (crop_size > 0) {
     (*top)[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
     prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
@@ -189,7 +127,7 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_GT(datum_height_, crop_size);
   CHECK_GT(datum_width_, crop_size);
   // check if we want to have mean
-  if (this->layer_param_.image_data_param().has_mean_file()) {
+  if (this->layer_param_.image_data_param().transform_param().has_mean_file()) {
     BlobProto blob_proto;
     LOG(INFO) << "Loading mean file from" << mean_file;
     ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto);
@@ -217,15 +155,9 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 template <typename Dtype>
 void ImageDataLayer<Dtype>::CreatePrefetchThread() {
   phase_ = Caffe::phase();
-  const bool prefetch_needs_rand =
-      this->layer_param_.image_data_param().shuffle() ||
-      this->layer_param_.image_data_param().crop_size();
-  if (prefetch_needs_rand) {
-    const unsigned int prefetch_rng_seed = caffe_rng_rand();
-    prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
-  } else {
-    prefetch_rng_.reset();
-  }
+
+  data_transformer_.InitRand();
+
   // Create the thread.
   CHECK(!StartInternalThread()) << "Pthread execution failed";
 }
@@ -244,13 +176,6 @@ void ImageDataLayer<Dtype>::JoinPrefetchThread() {
 }
 
 template <typename Dtype>
-unsigned int ImageDataLayer<Dtype>::PrefetchRand() {
-  caffe::rng_t* prefetch_rng =
-      static_cast<caffe::rng_t*>(prefetch_rng_->generator());
-  return (*prefetch_rng)();
-}
-
-template <typename Dtype>
 void ImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   // First, join the thread
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index b7c6bca..75f4b28 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -345,7 +345,7 @@ message ConvolutionParameter {
 }
 // Message that stores parameters used to apply transformation
-// to the data layer's data 
+// to the data layer's data
 message TransformationParameter {
   // For data pre-processing, we can do simple scaling and subtracting the
   // data mean, if provided. Note that the mean subtraction is always carried
   // out before scaling.
@@ -444,27 +444,20 @@ message HingeLossParameter {
 message ImageDataParameter {
   // Specify the data source.
   optional string source = 1;
-  // For data pre-processing, we can do simple scaling and subtracting the
-  // data mean, if provided. Note that the mean subtraction is always carried
-  // out before scaling.
-  optional float scale = 2 [default = 1];
-  optional string mean_file = 3;
   // Specify the batch size.
-  optional uint32 batch_size = 4;
-  // Specify if we would like to randomly crop an image.
-  optional uint32 crop_size = 5 [default = 0];
-  // Specify if we want to randomly mirror data.
-  optional bool mirror = 6 [default = false];
+  optional uint32 batch_size = 2;
   // The rand_skip variable is for the data layer to skip a few data points
   // to avoid all asynchronous sgd clients to start at the same point. The skip
   // point would be set as rand_skip * rand(0,1). Note that rand_skip should not
   // be larger than the number of keys in the leveldb.
-  optional uint32 rand_skip = 7 [default = 0];
+  optional uint32 rand_skip = 3 [default = 0];
   // Whether or not ImageLayer should shuffle the list of files at every epoch.
-  optional bool shuffle = 8 [default = false];
+  optional bool shuffle = 4 [default = false];
   // It will also resize images if new_height or new_width are not zero.
-  optional uint32 new_height = 9 [default = 0];
-  optional uint32 new_width = 10 [default = 0];
+  optional uint32 new_height = 5 [default = 0];
+  optional uint32 new_width = 6 [default = 0];
+  // Parameters for data pre-processing.
+  optional TransformationParameter transform_param = 7;
 }
 
 // Message that stores parameters InfogainLossLayer
@@ -505,7 +498,7 @@ message MemoryDataParameter {
 message MVNParameter {
   // This parameter can be set to false to normalize mean only
   optional bool normalize_variance = 1 [default = true];
-  
+
   // This parameter can be set to true to perform DNN-like MVN
   optional bool across_channels = 2 [default = false];
 }
diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp
index 3e9ab21..9d0cd58 100644
--- a/src/caffe/test/test_upgrade_proto.cpp
+++ b/src/caffe/test/test_upgrade_proto.cpp
@@ -1542,11 +1542,13 @@ TEST_F(V0UpgradeTest, TestAllParams) {
       "  type: IMAGE_DATA "
       "  image_data_param { "
       "    source: '/home/jiayq/Data/ILSVRC12/train-images' "
-      "    mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' "
       "    batch_size: 256 "
-      "    crop_size: 227 "
-      "    mirror: true "
-      "    scale: 0.25 "
+      "    transform_param {"
+      "      mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' "
+      "      crop_size: 227 "
+      "      mirror: true "
+      "      scale: 0.25 "
+      "    } "
       "    rand_skip: 73 "
       "    shuffle: true "
       "    new_height: 40 "
diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp
index 48eb579..21fc038 100644
--- a/src/caffe/util/upgrade_proto.cpp
+++ b/src/caffe/util/upgrade_proto.cpp
@@ -310,8 +310,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
       layer_param->mutable_data_param()->mutable_transform_param()->
           set_scale(v0_layer_param.scale());
     } else if (type == "images") {
-      layer_param->mutable_image_data_param()->set_scale(
-          v0_layer_param.scale());
+      layer_param->mutable_image_data_param()->mutable_transform_param()->
+          set_scale(v0_layer_param.scale());
     } else {
       LOG(ERROR) << "Unknown parameter scale for layer type " << type;
       is_fully_compatible = false;
@@ -322,8 +322,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
       layer_param->mutable_data_param()->mutable_transform_param()->
           set_mean_file(v0_layer_param.meanfile());
     } else if (type == "images") {
-      layer_param->mutable_image_data_param()->set_mean_file(
-          v0_layer_param.meanfile());
+      layer_param->mutable_image_data_param()->mutable_transform_param()->
+          set_mean_file(v0_layer_param.meanfile());
     } else if (type == "window_data") {
       layer_param->mutable_window_data_param()->set_mean_file(
           v0_layer_param.meanfile());
@@ -355,8 +355,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
       layer_param->mutable_data_param()->mutable_transform_param()->
           set_crop_size(v0_layer_param.cropsize());
     } else if (type == "images") {
-      layer_param->mutable_image_data_param()->set_crop_size(
-          v0_layer_param.cropsize());
+      layer_param->mutable_image_data_param()->mutable_transform_param()->
+          set_crop_size(v0_layer_param.cropsize());
     } else if (type == "window_data") {
       layer_param->mutable_window_data_param()->set_crop_size(
           v0_layer_param.cropsize());
@@ -370,8 +370,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
       layer_param->mutable_data_param()->mutable_transform_param()->
          set_mirror(v0_layer_param.mirror());
     } else if (type == "images") {
-      layer_param->mutable_image_data_param()->set_mirror(
-          v0_layer_param.mirror());
+      layer_param->mutable_image_data_param()->mutable_transform_param()->
+          set_mirror(v0_layer_param.mirror());
     } else if (type == "window_data") {
       layer_param->mutable_window_data_param()->set_mirror(
          v0_layer_param.mirror());
-- 
2.7.4
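
Note: after this refactor, the per-layer preprocessing options live in the
nested TransformationParameter rather than directly in ImageDataParameter.
As a rough sketch of the resulting configuration (the field values are copied
from the upgrade test above; the layer name and tops are illustrative and not
part of this patch), an IMAGE_DATA layer definition now looks like:

  layers {
    name: "images"
    type: IMAGE_DATA
    top: "data"
    top: "label"
    image_data_param {
      source: "/home/jiayq/Data/ILSVRC12/train-images"
      batch_size: 256
      rand_skip: 73
      shuffle: true
      new_height: 40
      transform_param {
        mean_file: "/home/jiayq/Data/ILSVRC12/image_mean.binaryproto"
        crop_size: 227
        mirror: true
        scale: 0.25
      }
    }
  }

V0-style definitions that still set scale, mean_file, crop_size or mirror
directly on the image data layer are rewritten into this nested form by
UpgradeLayerParameter() in src/caffe/util/upgrade_proto.cpp.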