split off Reshape for neuron layers
authorJonathan L Long <jonlong@cs.berkeley.edu>
Thu, 11 Sep 2014 04:48:51 +0000 (21:48 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 18 Sep 2014 19:41:46 +0000 (12:41 -0700)
include/caffe/neuron_layers.hpp
src/caffe/layers/cudnn_relu_layer.cpp
src/caffe/layers/cudnn_sigmoid_layer.cpp
src/caffe/layers/cudnn_tanh_layer.cpp
src/caffe/layers/dropout_layer.cpp
src/caffe/layers/neuron_layer.cpp

index 36acf96..0968a20 100644 (file)
@@ -26,7 +26,7 @@ class NeuronLayer : public Layer<Dtype> {
  public:
   explicit NeuronLayer(const LayerParameter& param)
      : Layer<Dtype>(param) {}
-  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
@@ -170,6 +170,8 @@ class DropoutLayer : public NeuronLayer<Dtype> {
       : NeuronLayer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
 
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_DROPOUT;
@@ -367,6 +369,8 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
       : ReLULayer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual ~CuDNNReLULayer();
 
  protected:
@@ -449,6 +453,8 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
       : SigmoidLayer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual ~CuDNNSigmoidLayer();
 
  protected:
@@ -533,6 +539,8 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
       : TanHLayer<Dtype>(param) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual ~CuDNNTanHLayer();
 
  protected:
index f8bf77f..083868f 100644 (file)
@@ -13,12 +13,20 @@ void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   ReLULayer<Dtype>::LayerSetUp(bottom, top);
   // initialize cuDNN
   CUDNN_CHECK(cudnnCreate(&handle_));
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+}
+
+template <typename Dtype>
+void CuDNNReLULayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  ReLULayer<Dtype>::Reshape(bottom, top);
   const int N = bottom[0]->num();
   const int K = bottom[0]->channels();
   const int H = bottom[0]->height();
   const int W = bottom[0]->width();
-  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
-  cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
+  cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
+  cudnn::setTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
 }
 
 template <typename Dtype>
index 488c754..3fe800d 100644 (file)
@@ -13,12 +13,20 @@ void CuDNNSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   SigmoidLayer<Dtype>::LayerSetUp(bottom, top);
   // initialize cuDNN
   CUDNN_CHECK(cudnnCreate(&handle_));
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+}
+
+template <typename Dtype>
+void CuDNNSigmoidLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  SigmoidLayer<Dtype>::Reshape(bottom, top);
   const int N = bottom[0]->num();
   const int K = bottom[0]->channels();
   const int H = bottom[0]->height();
   const int W = bottom[0]->width();
-  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
-  cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
+  cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
+  cudnn::setTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
 }
 
 template <typename Dtype>
index 32b6611..7a5c06f 100644 (file)
@@ -13,12 +13,20 @@ void CuDNNTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   TanHLayer<Dtype>::LayerSetUp(bottom, top);
   // initialize cuDNN
   CUDNN_CHECK(cudnnCreate(&handle_));
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+}
+
+template <typename Dtype>
+void CuDNNTanHLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  TanHLayer<Dtype>::Reshape(bottom, top);
   const int N = bottom[0]->num();
   const int K = bottom[0]->channels();
   const int H = bottom[0]->height();
   const int W = bottom[0]->width();
-  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
-  cudnn::createTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
+  cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, N, K, H, W);
+  cudnn::setTensor4dDesc<Dtype>(&top_desc_, N, K, H, W);
 }
 
 template <typename Dtype>
index 52537d1..47feb1d 100644 (file)
@@ -14,9 +14,6 @@ template <typename Dtype>
 void DropoutLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
   NeuronLayer<Dtype>::LayerSetUp(bottom, top);
-  // Set up the cache for random number generation
-  rand_vec_.Reshape(bottom[0]->num(), bottom[0]->channels(),
-      bottom[0]->height(), bottom[0]->width());
   threshold_ = this->layer_param_.dropout_param().dropout_ratio();
   DCHECK(threshold_ > 0.);
   DCHECK(threshold_ < 1.);
@@ -25,6 +22,15 @@ void DropoutLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
 }
 
 template <typename Dtype>
+void DropoutLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  NeuronLayer<Dtype>::Reshape(bottom, top);
+  // Set up the cache for random number generation
+  rand_vec_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+}
+
+template <typename Dtype>
 void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->cpu_data();
index eff7948..c28e36e 100644 (file)
@@ -6,13 +6,9 @@
 namespace caffe {
 
 template <typename Dtype>
-void NeuronLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+void NeuronLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
-  // NeuronLayer allows in-place computations. If the computation is not
-  // in-place, we will need to initialize the top blob.
-  if ((*top)[0] != bottom[0]) {
-    (*top)[0]->ReshapeLike(*bottom[0]);
-  }
+  (*top)[0]->ReshapeLike(*bottom[0]);
 }
 
 INSTANTIATE_CLASS(NeuronLayer);