add size accessors to MemoryDataLayer
author: Jonathan L Long <jonlong@cs.berkeley.edu>
Fri, 4 Apr 2014 22:20:02 +0000 (15:20 -0700)
committer: Jonathan L Long <jonlong@cs.berkeley.edu>
Fri, 2 May 2014 20:25:51 +0000 (13:25 -0700)
This will facilitate input size checking for pycaffe (and potentially
others).

include/caffe/vision_layers.hpp
src/caffe/layers/memory_data_layer.cpp

index 8fb8fdb..4765398 100644 (file)
@@ -634,6 +634,10 @@ class MemoryDataLayer : public Layer<Dtype> {
   // Reset should accept const pointers, but can't, because the memory
   //  will be given to Blob, which is mutable
   void Reset(Dtype* data, Dtype* label, int n);
+  int datum_channels() { return datum_channels_; }
+  int datum_height() { return datum_height_; }
+  int datum_width() { return datum_width_; }
+  int batch_size() { return batch_size_; }
 
  protected:
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
@@ -645,6 +649,9 @@ class MemoryDataLayer : public Layer<Dtype> {
 
   Dtype* data_;
   Dtype* labels_;
+  int datum_channels_;
+  int datum_height_;
+  int datum_width_;
   int datum_size_;
   int batch_size_;
   int n_;
index 7a1f3ff..60bce27 100644 (file)
@@ -13,13 +13,13 @@ void MemoryDataLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_EQ(bottom.size(), 0) << "Memory Data Layer takes no blobs as input.";
   CHECK_EQ(top->size(), 2) << "Memory Data Layer takes two blobs as output.";
   batch_size_ = this->layer_param_.memory_data_param().batch_size();
-  int channels = this->layer_param_.memory_data_param().channels();
-  int height = this->layer_param_.memory_data_param().height();
-  int width = this->layer_param_.memory_data_param().width();
-  datum_size_ = channels * height * width;
+  datum_channels_ = this->layer_param_.memory_data_param().channels();
+  datum_height_ = this->layer_param_.memory_data_param().height();
+  datum_width_ = this->layer_param_.memory_data_param().width();
+  datum_size_ = datum_channels_ * datum_height_ * datum_width_;
   CHECK_GT(batch_size_ * datum_size_, 0) << "batch_size, channels, height,"
     " and width must be specified and positive in memory_data_param";
-  (*top)[0]->Reshape(batch_size_, channels, height, width);
+  (*top)[0]->Reshape(batch_size_, datum_channels_, datum_height_, datum_width_);
   (*top)[1]->Reshape(batch_size_, 1, 1, 1);
   data_ = NULL;
   labels_ = NULL;