From 05ff293ca476feaeab584906bb3c30fc20575371 Mon Sep 17 00:00:00 2001 From: Yangqing Jia Date: Tue, 1 Oct 2013 16:50:44 -0700 Subject: [PATCH] datalayer random cropping, not tested --- src/caffe/layers/data_layer.cpp | 57 +++++++++++++++++++++++++++++------------ src/caffe/proto/caffe.proto | 4 ++- src/caffe/vision_layers.hpp | 3 +++ 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index d42a810..fdf9807 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -34,13 +34,24 @@ void DataLayer::SetUp(const vector*>& bottom, Datum datum; datum.ParseFromString(iter_->value().ToString()); // image - (*top)[0]->Reshape( - this->layer_param_.batchsize(), datum.channels(), datum.height(), - datum.width()); + int cropsize = this->layer_param_.cropsize(); + if (cropsize > 0) { + (*top)[0]->Reshape( + this->layer_param_.batchsize(), datum.channels(), cropsize, cropsize); + } else { + (*top)[0]->Reshape( + this->layer_param_.batchsize(), datum.channels(), datum.height(), + datum.width()); + } // label (*top)[1]->Reshape(this->layer_param_.batchsize(), 1, 1, 1); // datum size - datum_size_ = datum.channels() * datum.height() * datum.width(); + datum_channels_ = datum.channels(); + datum_height_ = datum.height(); + datum_width_ = datum.width(); + datum_size_ = datum.channels() * datum.height() * datum.width(); + CHECK_GT(datum_height_, cropsize); + CHECK_GT(datum_width_, cropsize); } template @@ -51,24 +62,38 @@ void DataLayer::Forward_cpu(const vector*>& bottom, Dtype* top_label = (*top)[1]->mutable_cpu_data(); const Dtype scale = this->layer_param_.scale(); const Dtype subtraction = this->layer_param_.subtraction(); - // LOG(ERROR) << "Debug code on"; - // if (true) { - // iter_->SeekToFirst(); - // } + int cropsize = this->layer_param_.cropsize(); for (int i = 0; i < this->layer_param_.batchsize(); ++i) { // get a blob datum.ParseFromString(iter_->value().ToString()); const string& data = datum.data(); - // we will prefer to use data() first, and then try float_data() - if (data.size()) { - for (int j = 0; j < datum_size_; ++j) { - top_data[i * datum_size_ + j] = - (static_cast((uint8_t)data[j]) * scale) - subtraction; + if (cropsize) { + CHECK(data.size()) << "Image cropping only support uint8 data"; + int h_offset = rand() % (datum_height_ - cropsize); + int w_offset = rand() % (datum_width_ - cropsize); + for (int c = 0; c < datum_channels_; ++i) { + for (int h = 0; h < cropsize; ++h) { + for (int w = 0; w < cropsize; ++w) { + top_data[((i * datum_channels_ + c) * cropsize + h) * cropsize + w] = + static_cast((uint8_t)data[ + (c * datum_height_ + h + h_offset) * datum_width_ + + w + w_offset] + ) * scale - subtraction; + } + } } } else { - for (int j = 0; j < datum_size_; ++j) { - top_data[i * datum_size_ + j] = - (datum.float_data(j) * scale) - subtraction; + // we will prefer to use data() first, and then try float_data() + if (data.size()) { + for (int j = 0; j < datum_size_; ++j) { + top_data[i * datum_size_ + j] = + (static_cast((uint8_t)data[j]) * scale) - subtraction; + } + } else { + for (int j = 0; j < datum_size_; ++j) { + top_data[i * datum_size_ + j] = + (datum.float_data(j) * scale) - subtraction; + } } } top_label[i] = datum.label(); diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index c3632c1..eef6058 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -63,8 +63,10 @@ message LayerParameter { // For data pre-processing, we can do simple scaling and constant subtraction optional float scale = 17 [ default = 1 ]; optional float subtraction = 18 [ default = 0 ]; - // For datay layers, specify the batch size. + // For data layers, specify the batch size. optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; // The blobs containing the numeric parameters of the layer repeated BlobProto blobs = 50; diff --git a/src/caffe/vision_layers.hpp b/src/caffe/vision_layers.hpp index 74c5978..23678bf 100644 --- a/src/caffe/vision_layers.hpp +++ b/src/caffe/vision_layers.hpp @@ -252,6 +252,9 @@ class DataLayer : public Layer { shared_ptr db_; shared_ptr iter_; + int datum_channels_; + int datum_height_; + int datum_width_; int datum_size_; }; -- 2.7.4