From: manuele Date: Mon, 10 Nov 2014 18:24:02 +0000 (+0100) Subject: MemoryDataLayer now accepts dynamic batch_size X-Git-Tag: submit/tizen/20180823.020014~572^2~30^2~4 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=cedefd77625db978623fcf00ecc32d899eed954b;p=platform%2Fupstream%2Fcaffeonacl.git MemoryDataLayer now accepts dynamic batch_size --- diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index b5449b3..117b1b5 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -284,6 +284,7 @@ class MemoryDataLayer : public BaseDataLayer { Blob added_data_; Blob added_label_; bool has_new_data_; + bool needs_reshape_; }; /** diff --git a/src/caffe/layers/memory_data_layer.cpp b/src/caffe/layers/memory_data_layer.cpp index 67990f3..f36eb6f 100644 --- a/src/caffe/layers/memory_data_layer.cpp +++ b/src/caffe/layers/memory_data_layer.cpp @@ -32,7 +32,15 @@ void MemoryDataLayer::AddDatumVector(const vector& datum_vector) { CHECK(!has_new_data_) << "Can't add Datum when earlier ones haven't been consumed" << " by the upper layers"; + size_t num = datum_vector.size(); + if (batch_size_ != num) { + needs_reshape_ = true; + batch_size_ = num; + added_data_.Reshape(batch_size_, channels_, height_, width_); + added_label_.Reshape(batch_size_, 1, 1, 1); + } + CHECK_GT(num, 0) << "There is no datum to add"; CHECK_LE(num, batch_size_) << "The number of added datum must be no greater than the batch size"; @@ -53,10 +61,22 @@ void MemoryDataLayer::AddDatumVector(const vector& datum_vector) { template void MemoryDataLayer::AddMatVector(const vector& mat_vector, const vector& labels) { + CHECK(!has_new_data_) << "Can't add Mat when earlier ones haven't been consumed" << " by the upper layers"; + + CHECK_EQ(mat_vector.size(), labels.size()) << + "vector of labels and vector of mats need to be of the same size"; + size_t num = mat_vector.size(); + if (batch_size_ != num) { + needs_reshape_ = true; + batch_size_ = num; + added_data_.Reshape(batch_size_, channels_, height_, width_); + added_label_.Reshape(batch_size_, 1, 1, 1); + } + CHECK_GT(num, 0) << "There is no mat to add"; CHECK_LE(num, batch_size_) << "The number of added mat must be no greater than the batch size"; @@ -89,10 +109,15 @@ template void MemoryDataLayer::Forward_cpu(const vector*>& bottom, const vector*>& top) { CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset"; + if (needs_reshape_) { + top[0]->Reshape(batch_size_, channels_, height_, width_); + top[1]->Reshape(batch_size_, 1, 1, 1); + } top[0]->set_cpu_data(data_ + pos_ * size_); top[1]->set_cpu_data(labels_ + pos_); pos_ = (pos_ + batch_size_) % n_; has_new_data_ = false; + needs_reshape_ = false; } INSTANTIATE_CLASS(MemoryDataLayer);