split off Reshape for vision layers
author Jonathan L Long <jonlong@cs.berkeley.edu>
Thu, 11 Sep 2014 06:31:33 +0000 (23:31 -0700)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 18 Sep 2014 20:17:43 +0000 (13:17 -0700)
Note that we are dropping some checks from the LRN layer. However, these
checks are fairly redundant; something is very wrong if these layers
produce top blobs sized differently from their inputs, and tests are the
right place to catch that. The thing that really should be checked (but
currently isn't) is that local_size is odd; this check will be added in a
future commit.
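
For context (not part of this patch): after the split, shape-independent
configuration stays in LayerSetUp, while tops and internal buffers are sized
in Reshape, which can then be re-run whenever the bottom shapes change. The
sketch below shows the assumed driver in Layer<Dtype>::SetUp and a
hypothetical form of the future local_size check; both are illustrations of
the intent, not code from this commit.

    // Assumed driver (outside this patch): run shape-independent setup once,
    // then size tops and buffers; Reshape repeats whenever shapes change.
    template <typename Dtype>
    void Layer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
        vector<Blob<Dtype>*>* top) {
      LayerSetUp(bottom, top);  // configuration independent of bottom shapes
      Reshape(bottom, top);     // size the tops and any internal buffers
    }

    // Hypothetical odd-size check for LRNLayer<Dtype>::LayerSetUp:
    size_ = this->layer_param_.lrn_param().local_size();
    CHECK_EQ(size_ % 2, 1) << "LRN only supports odd values for local_size";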

include/caffe/vision_layers.hpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/cudnn_conv_layer.cpp
src/caffe/layers/cudnn_pooling_layer.cpp
src/caffe/layers/im2col_layer.cpp
src/caffe/layers/lrn_layer.cpp
src/caffe/layers/pooling_layer.cpp

include/caffe/vision_layers.hpp
index 9c8656f..1e7f3fc 100644
@@ -67,6 +67,8 @@ class ConvolutionLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_CONVOLUTION;
@@ -131,6 +133,8 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
       : ConvolutionLayer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual ~CuDNNConvolutionLayer();
 
  protected:
@@ -163,6 +167,8 @@ class Im2colLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_IM2COL;
@@ -203,6 +209,8 @@ class LRNLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_LRN;
@@ -278,6 +286,8 @@ class PoolingLayer : public Layer<Dtype> {
       : Layer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_POOLING;
@@ -323,6 +333,8 @@ class CuDNNPoolingLayer : public PoolingLayer<Dtype> {
       : PoolingLayer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual ~CuDNNPoolingLayer();
 
  protected:

src/caffe/layers/conv_layer.cpp
index 769dfa6..58918fd 100644
@@ -47,47 +47,18 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     stride_h_ = conv_param.stride_h();
     stride_w_ = conv_param.stride_w();
   }
-  num_ = bottom[0]->num();
+  // Configure output channels and groups.
   channels_ = bottom[0]->channels();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
-  // TODO: generalize to handle inputs of different shapes.
-  for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) {
-    CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num.";
-    CHECK_EQ(channels_, bottom[bottom_id]->channels())
-        << "Inputs must have same channels.";
-    CHECK_EQ(height_, bottom[bottom_id]->height())
-        << "Inputs must have same height.";
-    CHECK_EQ(width_, bottom[bottom_id]->width())
-        << "Inputs must have same width.";
-  }
-  // Configure output channels, groups, and spatial dimensions.
   num_output_ = this->layer_param_.convolution_param().num_output();
   CHECK_GT(num_output_, 0);
   group_ = this->layer_param_.convolution_param().group();
   CHECK_EQ(channels_ % group_, 0);
   CHECK_EQ(num_output_ % group_, 0)
       << "Number of output should be multiples of group.";
-  height_out_ =
-      (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
-  width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
-  for (int top_id = 0; top_id < top->size(); ++top_id) {
-    (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
-  }
-  // Prepare the matrix multiplication computation.
-  // Each input will be convolved as a single GEMM.
-  M_ = num_output_ / group_;
-  K_ = channels_ * kernel_h_ * kernel_w_ / group_;
-  N_ = height_out_ * width_out_;
-  // The im2col result buffer holds one image at a time to avoid
-  // overly large memory usage.
-  col_buffer_.Reshape(
-      1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
   // Handle the parameters: weights and biases.
   // - blobs_[0] holds the filter weights
   // - blobs_[1] holds the biases (optional)
   bias_term_ = this->layer_param_.convolution_param().bias_term();
-  // Check if we need to set up the weights.
   if (this->blobs_.size() > 0) {
     LOG(INFO) << "Skipping parameter initialization";
   } else {
@@ -112,16 +83,54 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       bias_filler->Fill(this->blobs_[1].get());
     }
   }
+  // Propagate gradients to the parameters (as directed by backward pass).
+  this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void ConvolutionLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  num_ = bottom[0]->num();
+  height_ = bottom[0]->height();
+  width_ = bottom[0]->width();
+  CHECK_EQ(bottom[0]->channels(), channels_) << "Input size incompatible with"
+    " convolution kernel.";
+  // TODO: generalize to handle inputs of different shapes.
+  for (int bottom_id = 1; bottom_id < bottom.size(); ++bottom_id) {
+    CHECK_EQ(num_, bottom[bottom_id]->num()) << "Inputs must have same num.";
+    CHECK_EQ(channels_, bottom[bottom_id]->channels())
+        << "Inputs must have same channels.";
+    CHECK_EQ(height_, bottom[bottom_id]->height())
+        << "Inputs must have same height.";
+    CHECK_EQ(width_, bottom[bottom_id]->width())
+        << "Inputs must have same width.";
+  }
+  // Shape the tops.
+  height_out_ =
+      (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
+  width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
+  for (int top_id = 0; top_id < top->size(); ++top_id) {
+    (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
+  }
+  // Prepare the matrix multiplication computation.
+  // Each input will be convolved as a single GEMM.
+  M_ = num_output_ / group_;
+  K_ = channels_ * kernel_h_ * kernel_w_ / group_;
+  N_ = height_out_ * width_out_;
+  // The im2col result buffer will only hold one image at a time to avoid
+  // overly large memory usage.
+  col_buffer_.Reshape(
+      1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
+  for (int top_id = 0; top_id < top->size(); ++top_id) {
+    (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
+  }
   // Set up the all ones "bias multiplier" for adding biases by BLAS
   if (bias_term_) {
     bias_multiplier_.Reshape(1, 1, 1, N_);
     caffe_set(N_, Dtype(1), bias_multiplier_.mutable_cpu_data());
   }
-  // Propagate gradients to the parameters (as directed by backward pass).
-  this->param_propagate_down_.resize(this->blobs_.size(), true);
 }
 
-
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {

src/caffe/layers/cudnn_conv_layer.cpp
index eaacddc..137bbab 100644
@@ -21,7 +21,7 @@ template <typename Dtype>
 void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   ConvolutionLayer<Dtype>::LayerSetUp(bottom, top);
-  // Initialize CUDA streams and cuNN.
+  // Initialize CUDA streams and cuDNN.
   stream_         = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
   handle_         = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
 
@@ -32,10 +32,6 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
   }
 
   // Set the indexing parameters.
-  bottom_offset_ = (this->channels_ / this->group_)
-      * this->height_ * this->width_;
-  top_offset_ = (this->num_output_ / this->group_)
-      * this->height_out_ * this->width_out_;
   weight_offset_ = (this->num_output_ / this->group_)
       * (this->channels_ / this->group_) * this->kernel_h_ * this->kernel_w_;
   bias_offset_ = (this->num_output_ / this->group_);
@@ -48,33 +44,54 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
   // Create tensor descriptor(s) for data and corresponding convolution(s).
   for (int i = 0; i < bottom.size(); i++) {
     cudnnTensor4dDescriptor_t bottom_desc;
-    cudnn::createTensor4dDesc<Dtype>(&bottom_desc,
+    cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
+    bottom_descs_.push_back(bottom_desc);
+    cudnnTensor4dDescriptor_t top_desc;
+    cudnn::createTensor4dDesc<Dtype>(&top_desc);
+    top_descs_.push_back(top_desc);
+    cudnnConvolutionDescriptor_t conv_desc;
+    cudnn::createConvolutionDesc<Dtype>(&conv_desc);
+    conv_descs_.push_back(conv_desc);
+  }
+
+  // Tensor descriptor for bias.
+  if (this->bias_term_) {
+    cudnn::createTensor4dDesc<Dtype>(&bias_desc_);
+  }
+}
+
+template <typename Dtype>
+void CuDNNConvolutionLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  ConvolutionLayer<Dtype>::Reshape(bottom, top);
+  bottom_offset_ = (this->channels_ / this->group_)
+      * this->height_ * this->width_;
+  top_offset_ = (this->num_output_ / this->group_)
+      * this->height_out_ * this->width_out_;
+
+  for (int i = 0; i < bottom.size(); i++) {
+    cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
         this->num_,
         this->channels_ / this->group_,
         this->height_, this->width_,
         this->channels_ * this->height_ * this->width_,
         this->height_ * this->width_,
         this->width_, 1);
-    bottom_descs_.push_back(bottom_desc);
-    cudnnTensor4dDescriptor_t top_desc;
-    cudnn::createTensor4dDesc<Dtype>(&top_desc,
+    cudnn::setTensor4dDesc<Dtype>(&top_descs_[i],
         this->num_,
         this->num_output_ / this->group_,
         this->height_out_, this->width_out_,
         this->num_output_ * this->height_out_ * this->width_out_,
         this->height_out_ * this->width_out_,
         this->width_out_, 1);
-    top_descs_.push_back(top_desc);
-    cudnnConvolutionDescriptor_t conv_desc;
-    cudnn::createConvolutionDesc<Dtype>(&conv_desc, bottom_desc,
+    cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i], bottom_descs_[i],
         filter_desc_, this->pad_h_, this->pad_w_,
         this->stride_h_, this->stride_w_);
-    conv_descs_.push_back(conv_desc);
   }
 
   // Tensor descriptor for bias.
   if (this->bias_term_) {
-    cudnn::createTensor4dDesc<Dtype>(&bias_desc_,
+    cudnn::setTensor4dDesc<Dtype>(&bias_desc_,
         1, this->num_output_ / this->group_, 1, 1);
   }
 }

src/caffe/layers/cudnn_pooling_layer.cpp
index 23c5201..5aea0dc 100644
@@ -15,16 +15,24 @@ void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   PoolingLayer<Dtype>::LayerSetUp(bottom, top);
 
   CUDNN_CHECK(cudnnCreate(&handle_));
-  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, bottom[0]->num(),
-      this->channels_, this->height_, this->width_);
-  cudnn::createTensor4dDesc<Dtype>(&top_desc_, bottom[0]->num(),
-      this->channels_, this->pooled_height_, this->pooled_width_);
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
   cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
       this->layer_param_.pooling_param().pool(), &mode_,
       this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_);
 }
 
 template <typename Dtype>
+void CuDNNPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  PoolingLayer<Dtype>::Reshape(bottom, top);
+  cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, bottom[0]->num(),
+      this->channels_, this->height_, this->width_);
+  cudnn::setTensor4dDesc<Dtype>(&top_desc_, bottom[0]->num(),
+      this->channels_, this->pooled_height_, this->pooled_width_);
+}
+
+template <typename Dtype>
 CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
   cudnnDestroyTensor4dDescriptor(bottom_desc_);
   cudnnDestroyTensor4dDescriptor(top_desc_);

src/caffe/layers/im2col_layer.cpp
index 02f33f1..870d5a9 100644
@@ -45,6 +45,11 @@ void Im2colLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     stride_h_ = conv_param.stride_h();
     stride_w_ = conv_param.stride_w();
   }
+}
+
+template <typename Dtype>
+void Im2colLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
   channels_ = bottom[0]->channels();
   height_ = bottom[0]->height();
   width_ = bottom[0]->width();

src/caffe/layers/lrn_layer.cpp
index e81a32b..55b7673 100644
@@ -9,88 +9,80 @@ namespace caffe {
 template <typename Dtype>
 void LRNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
-  num_ = bottom[0]->num();
-  channels_ = bottom[0]->channels();
-  height_ = bottom[0]->height();
-  width_ = bottom[0]->width();
   size_ = this->layer_param_.lrn_param().local_size();
   pre_pad_ = (size_ - 1) / 2;
   alpha_ = this->layer_param_.lrn_param().alpha();
   beta_ = this->layer_param_.lrn_param().beta();
+  if (this->layer_param_.lrn_param().norm_region() ==
+      LRNParameter_NormRegion_WITHIN_CHANNEL) {
+    // Set up split_layer_ to use inputs in the numerator and denominator.
+    split_top_vec_.clear();
+    split_top_vec_.push_back(&product_input_);
+    split_top_vec_.push_back(&square_input_);
+    LayerParameter split_param;
+    split_layer_.reset(new SplitLayer<Dtype>(split_param));
+    split_layer_->SetUp(bottom, &split_top_vec_);
+    // Set up square_layer_ to square the inputs.
+    square_bottom_vec_.clear();
+    square_top_vec_.clear();
+    square_bottom_vec_.push_back(&square_input_);
+    square_top_vec_.push_back(&square_output_);
+    LayerParameter square_param;
+    square_param.mutable_power_param()->set_power(Dtype(2));
+    square_layer_.reset(new PowerLayer<Dtype>(square_param));
+    square_layer_->SetUp(square_bottom_vec_, &square_top_vec_);
+    // Set up pool_layer_ to sum over square neighborhoods of the input.
+    pool_top_vec_.clear();
+    pool_top_vec_.push_back(&pool_output_);
+    LayerParameter pool_param;
+    pool_param.mutable_pooling_param()->set_pool(
+        PoolingParameter_PoolMethod_AVE);
+    pool_param.mutable_pooling_param()->set_pad(pre_pad_);
+    pool_param.mutable_pooling_param()->set_kernel_size(size_);
+    pool_layer_.reset(new PoolingLayer<Dtype>(pool_param));
+    pool_layer_->SetUp(square_top_vec_, &pool_top_vec_);
+    // Set up power_layer_ to compute (1 + alpha_/N^2 s)^-beta_, where s is
+    // the sum of a squared neighborhood (the output of pool_layer_).
+    power_top_vec_.clear();
+    power_top_vec_.push_back(&power_output_);
+    LayerParameter power_param;
+    power_param.mutable_power_param()->set_power(-beta_);
+    power_param.mutable_power_param()->set_scale(alpha_);
+    power_param.mutable_power_param()->set_shift(Dtype(1));
+    power_layer_.reset(new PowerLayer<Dtype>(power_param));
+    power_layer_->SetUp(pool_top_vec_, &power_top_vec_);
+    // Set up a product_layer_ to compute outputs by multiplying inputs by the
+    // inverse denominator computed by the power layer.
+    product_bottom_vec_.clear();
+    product_bottom_vec_.push_back(&product_input_);
+    product_bottom_vec_.push_back(&power_output_);
+    LayerParameter product_param;
+    EltwiseParameter* eltwise_param = product_param.mutable_eltwise_param();
+    eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD);
+    product_layer_.reset(new EltwiseLayer<Dtype>(product_param));
+    product_layer_->SetUp(product_bottom_vec_, top);
+  }
+}
+
+template <typename Dtype>
+void LRNLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  num_ = bottom[0]->num();
+  channels_ = bottom[0]->channels();
+  height_ = bottom[0]->height();
+  width_ = bottom[0]->width();
   switch (this->layer_param_.lrn_param().norm_region()) {
   case LRNParameter_NormRegion_ACROSS_CHANNELS:
     (*top)[0]->Reshape(num_, channels_, height_, width_);
     scale_.Reshape(num_, channels_, height_, width_);
     break;
   case LRNParameter_NormRegion_WITHIN_CHANNEL:
-    {
-      // Set up split_layer_ to use inputs in the numerator and denominator.
-      split_top_vec_.clear();
-      split_top_vec_.push_back(&product_input_);
-      split_top_vec_.push_back(&square_input_);
-      LayerParameter split_param;
-      split_layer_.reset(new SplitLayer<Dtype>(split_param));
-      split_layer_->SetUp(bottom, &split_top_vec_);
-      // Set up square_layer_ to square the inputs.
-      square_input_.Reshape(num_, channels_, height_, width_);
-      square_bottom_vec_.clear();
-      square_top_vec_.clear();
-      square_bottom_vec_.push_back(&square_input_);
-      square_top_vec_.push_back(&square_output_);
-      LayerParameter square_param;
-      square_param.mutable_power_param()->set_power(Dtype(2));
-      square_layer_.reset(new PowerLayer<Dtype>(square_param));
-      square_layer_->SetUp(square_bottom_vec_, &square_top_vec_);
-      CHECK_EQ(square_output_.num(), num_);
-      CHECK_EQ(square_output_.channels(), channels_);
-      CHECK_EQ(square_output_.height(), height_);
-      CHECK_EQ(square_output_.width(), width_);
-      // Set up pool_layer_ to sum over square neighborhoods of the input.
-      pool_top_vec_.clear();
-      pool_top_vec_.push_back(&pool_output_);
-      LayerParameter pool_param;
-      pool_param.mutable_pooling_param()->set_pool(
-          PoolingParameter_PoolMethod_AVE);
-      pool_param.mutable_pooling_param()->set_pad(pre_pad_);
-      pool_param.mutable_pooling_param()->set_kernel_size(size_);
-      pool_layer_.reset(new PoolingLayer<Dtype>(pool_param));
-      pool_layer_->SetUp(square_top_vec_, &pool_top_vec_);
-      CHECK_EQ(pool_output_.num(), num_);
-      CHECK_EQ(pool_output_.channels(), channels_);
-      CHECK_EQ(pool_output_.height(), height_);
-      CHECK_EQ(pool_output_.width(), width_);
-      // Set up power_layer_ to compute (1 + alpha_/N^2 s)^-beta_, where s is
-      // the sum of a squared neighborhood (the output of pool_layer_).
-      power_top_vec_.clear();
-      power_top_vec_.push_back(&power_output_);
-      LayerParameter power_param;
-      power_param.mutable_power_param()->set_power(-beta_);
-      power_param.mutable_power_param()->set_scale(alpha_);
-      power_param.mutable_power_param()->set_shift(Dtype(1));
-      power_layer_.reset(new PowerLayer<Dtype>(power_param));
-      power_layer_->SetUp(pool_top_vec_, &power_top_vec_);
-      CHECK_EQ(power_output_.num(), num_);
-      CHECK_EQ(power_output_.channels(), channels_);
-      CHECK_EQ(power_output_.height(), height_);
-      CHECK_EQ(power_output_.width(), width_);
-      // Set up a product_layer_ to compute outputs by multiplying inputs by the
-      // inverse demoninator computed by the power layer.
-      product_bottom_vec_.clear();
-      product_bottom_vec_.push_back(&product_input_);
-      product_bottom_vec_.push_back(&power_output_);
-      LayerParameter product_param;
-      EltwiseParameter* eltwise_param = product_param.mutable_eltwise_param();
-      eltwise_param->set_operation(EltwiseParameter_EltwiseOp_PROD);
-      product_layer_.reset(new EltwiseLayer<Dtype>(product_param));
-      product_layer_->SetUp(product_bottom_vec_, top);
-      CHECK_EQ((*top)[0]->num(), num_);
-      CHECK_EQ((*top)[0]->channels(), channels_);
-      CHECK_EQ((*top)[0]->height(), height_);
-      CHECK_EQ((*top)[0]->width(), width_);
-    }
+    split_layer_->Reshape(bottom, &split_top_vec_);
+    square_layer_->Reshape(square_bottom_vec_, &square_top_vec_);
+    pool_layer_->Reshape(square_top_vec_, &pool_top_vec_);
+    power_layer_->Reshape(pool_top_vec_, &power_top_vec_);
+    product_layer_->Reshape(product_bottom_vec_, top);
     break;
-  default:
-    LOG(FATAL) << "Unknown normalization region.";
   }
 }
 
src/caffe/layers/pooling_layer.cpp
index 9e77fa2..8e8ffad 100644
@@ -60,6 +60,11 @@ void PoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
     CHECK_LT(pad_h_, kernel_h_);
     CHECK_LT(pad_w_, kernel_w_);
   }
+}
+
+template <typename Dtype>
+void PoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
   channels_ = bottom[0]->channels();
   height_ = bottom[0]->height();
   width_ = bottom[0]->width();