change Backward interface: propagate_down is a vector -- use to fix
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 10 Jun 2014 19:38:59 +0000 (12:38 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 26 Jun 2014 18:34:19 +0000 (11:34 -0700)
long-standing issue with how this is handled in loss layers (esp.
EuclideanLossLayer)

58 files changed:
include/caffe/data_layers.hpp
include/caffe/layer.hpp
include/caffe/loss_layers.hpp
include/caffe/net.hpp
include/caffe/neuron_layers.hpp
include/caffe/vision_layers.hpp
src/caffe/layers/bnll_layer.cpp
src/caffe/layers/bnll_layer.cu
src/caffe/layers/concat_layer.cpp
src/caffe/layers/concat_layer.cu
src/caffe/layers/conv_layer.cpp
src/caffe/layers/conv_layer.cu
src/caffe/layers/dropout_layer.cpp
src/caffe/layers/dropout_layer.cu
src/caffe/layers/eltwise_layer.cpp
src/caffe/layers/eltwise_layer.cu
src/caffe/layers/euclidean_loss_layer.cpp
src/caffe/layers/euclidean_loss_layer.cu [new file with mode: 0644]
src/caffe/layers/flatten_layer.cpp
src/caffe/layers/flatten_layer.cu
src/caffe/layers/hdf5_data_layer.cpp
src/caffe/layers/hdf5_data_layer.cu
src/caffe/layers/hdf5_output_layer.cpp
src/caffe/layers/hdf5_output_layer.cu
src/caffe/layers/hinge_loss_layer.cpp
src/caffe/layers/im2col_layer.cpp
src/caffe/layers/im2col_layer.cu
src/caffe/layers/infogain_loss_layer.cpp
src/caffe/layers/inner_product_layer.cpp
src/caffe/layers/inner_product_layer.cu
src/caffe/layers/lrn_layer.cpp
src/caffe/layers/lrn_layer.cu
src/caffe/layers/multinomial_logistic_loss_layer.cpp
src/caffe/layers/pooling_layer.cpp
src/caffe/layers/pooling_layer.cu
src/caffe/layers/power_layer.cpp
src/caffe/layers/power_layer.cu
src/caffe/layers/relu_layer.cpp
src/caffe/layers/relu_layer.cu
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
src/caffe/layers/sigmoid_layer.cpp
src/caffe/layers/sigmoid_layer.cu
src/caffe/layers/softmax_layer.cpp
src/caffe/layers/softmax_layer.cu
src/caffe/layers/softmax_loss_layer.cpp
src/caffe/layers/softmax_loss_layer.cu
src/caffe/layers/split_layer.cpp
src/caffe/layers/split_layer.cu
src/caffe/layers/tanh_layer.cpp
src/caffe/layers/tanh_layer.cu
src/caffe/net.cpp
src/caffe/test/test_euclidean_loss_layer.cpp
src/caffe/test/test_gradient_check_util.hpp
src/caffe/test/test_lrn_layer.cpp
src/caffe/test/test_maxpool_dropout_layers.cpp
src/caffe/test/test_net.cpp
tools/net_speed_benchmark.cpp

index fc93daa..2c6be55 100644 (file)
@@ -47,9 +47,9 @@ class HDF5OutputLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void SaveBlobs();
 
   std::string file_name_;
@@ -80,9 +80,9 @@ class HDF5DataLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void LoadHDF5FileData(const char* filename);
 
   std::vector<std::string> hdf_filenames_;
@@ -125,9 +125,9 @@ class DataLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
   virtual void CreatePrefetchThread();
   virtual void JoinPrefetchThread();
@@ -175,9 +175,9 @@ class DummyDataLayer : public Layer<Dtype> {
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
   vector<shared_ptr<Filler<Dtype> > > fillers_;
   vector<bool> refill_;
@@ -211,9 +211,9 @@ class ImageDataLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
   virtual void ShuffleImages();
 
@@ -263,9 +263,9 @@ class MemoryDataLayer : public Layer<Dtype> {
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
   Dtype* data_;
   Dtype* labels_;
@@ -306,9 +306,9 @@ class WindowDataLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { return; }
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {}
 
   virtual void CreatePrefetchThread();
   virtual void JoinPrefetchThread();
index af3d544..12e7610 100644 (file)
@@ -46,7 +46,7 @@ class Layer {
   inline Dtype Forward(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   inline void Backward(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down,
+      const vector<bool>& propagate_down,
       vector<Blob<Dtype>*>* bottom);
 
   // Returns the vector of blobs.
@@ -104,10 +104,10 @@ class Layer {
   // Backward functions: compute the gradients for any parameters and
   // for the bottom blobs if propagate_down is true.
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down,
+      const vector<bool>& propagate_down,
       vector<Blob<Dtype>*>* bottom) = 0;
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down,
+      const vector<bool>& propagate_down,
       vector<Blob<Dtype>*>* bottom) {
     // LOG(WARNING) << "Using CPU code as backup.";
     Backward_cpu(top, propagate_down, bottom);
@@ -172,7 +172,7 @@ inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 inline void Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   switch (Caffe::mode()) {
   case Caffe::CPU:
index db7c63e..b8adc99 100644 (file)
@@ -62,7 +62,6 @@ class SoftmaxWithLossLayer : public Layer<Dtype> {
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_SOFTMAX_LOSS;
   }
-  virtual inline int ExactNumBottomBlobs() const { return 2; }
   virtual inline int MaxTopBlobs() const { return 2; }
 
  protected:
@@ -71,9 +70,9 @@ class SoftmaxWithLossLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
   // prob stores the output probability of the layer.
@@ -105,9 +104,9 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
   // sigmoid_output stores the output of the sigmoid layer.
@@ -138,8 +137,12 @@ class EuclideanLossLayer : public LossLayer<Dtype> {
  protected:
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   Blob<Dtype> diff_;
 };
@@ -162,7 +165,7 @@ class InfogainLossLayer : public LossLayer<Dtype> {
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   Blob<Dtype> infogain_;
 };
@@ -183,7 +186,7 @@ class HingeLossLayer : public LossLayer<Dtype> {
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
 /* MultinomialLogisticLossLayer
@@ -204,7 +207,7 @@ class MultinomialLogisticLossLayer : public LossLayer<Dtype> {
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
 /* AccuracyLayer
@@ -230,15 +233,11 @@ class AccuracyLayer : public Layer<Dtype> {
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
     NOT_IMPLEMENTED;
   }
 };
 
-/* Also see
-- SoftmaxWithLossLayer in vision_layers.hpp
-*/
-
 }  // namespace caffe
 
 #endif  // CAFFE_LOSS_LAYERS_HPP_
index 5b5dd22..d05ca09 100644 (file)
@@ -81,6 +81,9 @@ class Net {
   // this unless you do per-layer checks such as gradients.
   inline vector<vector<Blob<Dtype>*> >& bottom_vecs() { return bottom_vecs_; }
   inline vector<vector<Blob<Dtype>*> >& top_vecs() { return top_vecs_; }
+  inline vector<vector<bool> >& bottom_need_backward() {
+    return bottom_need_backward_;
+  }
   // returns the parameters
   inline vector<shared_ptr<Blob<Dtype> > >& params() { return params_; }
   // returns the parameter learning rate multipliers
@@ -131,6 +134,7 @@ class Net {
   // pointers.
   vector<vector<Blob<Dtype>*> > bottom_vecs_;
   vector<vector<int> > bottom_id_vecs_;
+  vector<vector<bool> > bottom_need_backward_;
   // top_vecs stores the vectors containing the output for each layer
   vector<vector<Blob<Dtype>*> > top_vecs_;
   vector<vector<int> > top_id_vecs_;
index ed664df..e52e395 100644 (file)
@@ -64,9 +64,9 @@ class BNLLLayer : public NeuronLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
 /* DropoutLayer
@@ -97,9 +97,9 @@ class DropoutLayer : public NeuronLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   shared_ptr<Blob<unsigned int> > rand_vec_;
   Dtype threshold_;
@@ -131,9 +131,9 @@ class PowerLayer : public NeuronLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   Dtype power_;
   Dtype scale_;
@@ -167,9 +167,9 @@ class ReLULayer : public NeuronLayer<Dtype> {
       vector<Blob<Dtype>*>* top);
 
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
 /* SigmoidLayer
@@ -199,9 +199,9 @@ class SigmoidLayer : public NeuronLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
 /* TanHLayer
@@ -227,9 +227,9 @@ class TanHLayer : public NeuronLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 };
 
 /* ThresholdLayer
@@ -260,7 +260,7 @@ class ThresholdLayer : public NeuronLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
     NOT_IMPLEMENTED;
   }
 
index 3fd7e2f..b68dcbf 100644 (file)
@@ -44,7 +44,7 @@ class ArgMaxLayer : public Layer<Dtype> {
   virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
     NOT_IMPLEMENTED;
   }
   bool out_max_val_;
@@ -74,9 +74,9 @@ class ConcatLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   Blob<Dtype> col_bob_;
   int count_;
@@ -109,9 +109,9 @@ class ConvolutionLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int kernel_size_;
   int stride_;
@@ -153,9 +153,9 @@ class EltwiseLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   EltwiseParameter_EltwiseOp op_;
   vector<Dtype> coeffs_;
@@ -183,9 +183,9 @@ class FlattenLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int count_;
 };
@@ -212,9 +212,9 @@ class Im2colLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int kernel_size_;
   int stride_;
@@ -246,9 +246,9 @@ class InnerProductLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int M_;
   int K_;
@@ -284,9 +284,9 @@ class LRNLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   virtual Dtype CrossChannelForward_cpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
@@ -295,11 +295,11 @@ class LRNLayer : public Layer<Dtype> {
   virtual Dtype WithinChannelForward(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void CrossChannelBackward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void CrossChannelBackward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void WithinChannelBackward(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int size_;
   int pre_pad_;
@@ -356,9 +356,9 @@ class PoolingLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int max_top_blobs_;
   int kernel_size_;
@@ -395,9 +395,9 @@ class SoftmaxLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-     const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+     const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   // sum_multiplier is just used to carry out sum using blas
   Blob<Dtype> sum_multiplier_;
@@ -427,9 +427,9 @@ class SplitLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   int count_;
 };
index d08adc4..95e6bd8 100644 (file)
@@ -28,9 +28,9 @@ Dtype BNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void BNLLLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* bottom_data = (*bottom)[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
index 75bea00..7849d11 100644 (file)
@@ -45,9 +45,9 @@ __global__ void BNLLBackward(const int n, const Dtype* in_diff,
 
 template <typename Dtype>
 void BNLLLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* bottom_data = (*bottom)[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
index 4541ee7..4c894dd 100644 (file)
@@ -70,26 +70,30 @@ Dtype ConcatLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void ConcatLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   if (concat_dim_ == 0) {
     int offset_num = 0;
     for (int i = 0; i < bottom->size(); ++i) {
       Blob<Dtype>* blob = (*bottom)[i];
-      Dtype* bottom_diff = blob->mutable_cpu_diff();
-      caffe_copy(blob->count(),
-        top_diff+top[0]->offset(offset_num), bottom_diff);
+      if (propagate_down[i]) {
+        Dtype* bottom_diff = blob->mutable_cpu_diff();
+        caffe_copy(blob->count(), top_diff + top[0]->offset(offset_num),
+                   bottom_diff);
+      }
       offset_num += blob->num();
     }
   } else if (concat_dim_ == 1) {
     int offset_channel = 0;
     for (int i = 0; i < bottom->size(); ++i) {
       Blob<Dtype>* blob = (*bottom)[i];
-      Dtype* bottom_diff = blob->mutable_cpu_diff();
-      int num_elem = blob->channels()*blob->height()*blob->width();
-      for (int n = 0; n < num_; ++n) {
-        caffe_copy(num_elem, top_diff+top[0]->offset(n, offset_channel),
-          bottom_diff+blob->offset(n));
+      if (propagate_down[i]) {
+        Dtype* bottom_diff = blob->mutable_cpu_diff();
+        int num_elem = blob->channels()*blob->height()*blob->width();
+        for (int n = 0; n < num_; ++n) {
+          caffe_copy(num_elem, top_diff + top[0]->offset(n, offset_channel),
+                     bottom_diff + blob->offset(n));
+        }
       }
       offset_channel += blob->channels();
     }
index 2820bf0..ca0cf0c 100644 (file)
@@ -41,26 +41,30 @@ Dtype ConcatLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void ConcatLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   if (concat_dim_ == 0) {
     int offset_num = 0;
     for (int i = 0; i < bottom->size(); ++i) {
       Blob<Dtype>* blob = (*bottom)[i];
-      Dtype* bottom_diff = blob->mutable_gpu_diff();
-      caffe_gpu_copy(blob->count(),
-        top_diff + top[0]->offset(offset_num), bottom_diff);
+      if (propagate_down[i]) {
+        Dtype* bottom_diff = blob->mutable_gpu_diff();
+        caffe_gpu_copy(blob->count(), top_diff + top[0]->offset(offset_num),
+                       bottom_diff);
+      }
       offset_num += blob->num();
     }
   } else if (concat_dim_ == 1) {
     int offset_channel = 0;
     for (int i = 0; i < bottom->size(); ++i) {
       Blob<Dtype>* blob = (*bottom)[i];
-      Dtype* bottom_diff = blob->mutable_gpu_diff();
-      int num_elem = blob->channels()*blob->height()*blob->width();
-      for (int n = 0; n < num_; ++n) {
-        caffe_gpu_copy(num_elem, top_diff + top[0]->offset(n, offset_channel),
-          bottom_diff + blob->offset(n));
+      if (propagate_down[i]) {
+        Dtype* bottom_diff = blob->mutable_gpu_diff();
+        int num_elem = blob->channels()*blob->height()*blob->width();
+        for (int n = 0; n < num_; ++n) {
+          caffe_gpu_copy(num_elem, top_diff + top[0]->offset(n, offset_channel),
+                         bottom_diff + blob->offset(n));
+        }
       }
       offset_channel += blob->channels();
     }
index 880b83a..9ec8da4 100644 (file)
@@ -109,7 +109,7 @@ Dtype ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   const Dtype* weight = this->blobs_[0]->cpu_data();
   Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff();
@@ -148,7 +148,7 @@ void ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
         weight_diff + weight_offset * g);
     }
     // gradient w.r.t. bottom data, if necessary
-    if (propagate_down) {
+    if (propagate_down[0]) {
       for (int g = 0; g < group_; ++g) {
         caffe_cpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
           (Dtype)1., weight + weight_offset * g,
index 51f5d15..85f95fd 100644 (file)
@@ -43,7 +43,7 @@ Dtype ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* weight = this->blobs_[0]->gpu_data();
   Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff();
@@ -84,7 +84,7 @@ void ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         weight_diff + weight_offset * g);
     }
     // gradient w.r.t. bottom data, if necessary
-    if (propagate_down) {
+    if (propagate_down[0]) {
       for (int g = 0; g < group_; ++g) {
         caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
           (Dtype)1., weight + weight_offset * g,
index f1e5411..e9a1a52 100644 (file)
@@ -47,10 +47,10 @@ Dtype DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   CHECK(Caffe::phase() == Caffe::TRAIN);
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const unsigned int* mask = rand_vec_->cpu_data();
index 3c25d6a..2c72264 100644 (file)
@@ -56,10 +56,10 @@ __global__ void DropoutBackward(const int n, const Dtype* in_diff,
 
 template <typename Dtype>
 void DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   CHECK(Caffe::phase() == Caffe::TRAIN);
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const unsigned int* mask =
index 5e5d760..2c265f6 100644 (file)
@@ -67,12 +67,12 @@ Dtype EltwiseLayer<Dtype>::Forward_cpu(
 
 template <typename Dtype>
 void EltwiseLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
-    const int count = top[0]->count();
-    const Dtype* top_data = top[0]->cpu_data();
-    const Dtype* top_diff = top[0]->cpu_diff();
-    for (int i = 0; i < bottom->size(); ++i) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const int count = top[0]->count();
+  const Dtype* top_data = top[0]->cpu_data();
+  const Dtype* top_diff = top[0]->cpu_diff();
+  for (int i = 0; i < bottom->size(); ++i) {
+    if (propagate_down[i]) {
       const Dtype* bottom_data = (*bottom)[i]->cpu_data();
       Dtype* bottom_diff = (*bottom)[i]->mutable_cpu_diff();
       switch (op_) {
index 75827ba..3860944 100644 (file)
@@ -36,12 +36,12 @@ Dtype EltwiseLayer<Dtype>::Forward_gpu(
 
 template <typename Dtype>
 void EltwiseLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
-    const int count = top[0]->count();
-    const Dtype* top_data = top[0]->gpu_data();
-    const Dtype* top_diff = top[0]->gpu_diff();
-    for (int i = 0; i < bottom->size(); ++i) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const int count = top[0]->count();
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  for (int i = 0; i < bottom->size(); ++i) {
+    if (propagate_down[i]) {
       const Dtype* bottom_data = (*bottom)[i]->gpu_data();
       Dtype* bottom_diff = (*bottom)[i]->mutable_gpu_diff();
       switch (op_) {
index 7662949..2478a51 100644 (file)
@@ -1,8 +1,5 @@
 // Copyright 2014 BVLC and contributors.
 
-#include <algorithm>
-#include <cmath>
-#include <cfloat>
 #include <vector>
 
 #include "caffe/layer.hpp"
@@ -10,8 +7,6 @@
 #include "caffe/util/math_functions.hpp"
 #include "caffe/util/io.hpp"
 
-using std::max;
-
 namespace caffe {
 
 template <typename Dtype>
@@ -43,13 +38,18 @@ Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  caffe_cpu_axpby(
-      (*bottom)[0]->count(),              // count
-      Dtype(1) / (*bottom)[0]->num(),     // alpha
-      diff_.cpu_data(),                   // a
-      Dtype(0),                           // beta
-      (*bottom)[0]->mutable_cpu_diff());  // b
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  for (int i = 0; i < 2; ++i) {
+    if (propagate_down[i]) {
+      const Dtype sign = (i == 0) ? 1 : -1;
+      caffe_cpu_axpby(
+          (*bottom)[i]->count(),              // count
+          sign / (*bottom)[i]->num(),         // alpha
+          diff_.cpu_data(),                   // a
+          Dtype(0),                           // beta
+          (*bottom)[i]->mutable_cpu_diff());  // b
+    }
+  }
 }
 
 INSTANTIATE_CLASS(EuclideanLossLayer);
diff --git a/src/caffe/layers/euclidean_loss_layer.cu b/src/caffe/layers/euclidean_loss_layer.cu
new file mode 100644 (file)
index 0000000..b070ea9
--- /dev/null
@@ -0,0 +1,45 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+Dtype EuclideanLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  int count = bottom[0]->count();
+  caffe_gpu_sub(
+      count,
+      bottom[0]->gpu_data(),
+      bottom[1]->gpu_data(),
+      diff_.mutable_gpu_data());
+  Dtype dot;
+  caffe_gpu_dot(count, diff_.gpu_data(), diff_.gpu_data(), &dot);
+  Dtype loss = dot / bottom[0]->num() / Dtype(2);
+  return loss;
+}
+
+template <typename Dtype>
+void EuclideanLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  for (int i = 0; i < 2; ++i) {
+    if (propagate_down[i]) {
+      const Dtype sign = (i == 0) ? 1 : -1;
+      caffe_gpu_axpby(
+          (*bottom)[i]->count(),              // count
+          sign / (*bottom)[i]->num(),         // alpha
+          diff_.gpu_data(),                   // a
+          Dtype(0),                           // beta
+          (*bottom)[i]->mutable_gpu_diff());  // b
+    }
+  }
+}
+
+INSTANTIATE_CLASS(EuclideanLossLayer);
+
+}  // namespace caffe
index 95f4859..9494da9 100644 (file)
@@ -29,7 +29,7 @@ Dtype FlattenLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void FlattenLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   (*bottom)[0]->ShareDiff(*top[0]);
 }
 
index 157eeb1..68add38 100644 (file)
@@ -17,7 +17,7 @@ Dtype FlattenLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void FlattenLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   (*bottom)[0]->ShareDiff(*top[0]);
 }
 
index 50863b8..2ba7fa7 100644 (file)
@@ -114,11 +114,6 @@ Dtype HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   return Dtype(0.);
 }
 
-// The backward operations are dummy - they do not carry any computation.
-template <typename Dtype>
-void HDF5DataLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) { }
-
 INSTANTIATE_CLASS(HDF5DataLayer);
 
 }  // namespace caffe
index 9c5bb5a..b2b09ef 100644 (file)
@@ -54,11 +54,6 @@ Dtype HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   return Dtype(0.);
 }
 
-template <typename Dtype>
-void HDF5DataLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-}
-
 INSTANTIATE_CLASS(HDF5DataLayer);
 
 }  // namespace caffe
index 0961b9b..3a513b9 100644 (file)
@@ -67,7 +67,7 @@ Dtype HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void HDF5OutputLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   return;
 }
 
index b994825..59505ee 100644 (file)
@@ -40,7 +40,7 @@ Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   return;
 }
 
index fd00d8c..8097761 100644 (file)
@@ -45,27 +45,33 @@ Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  const Dtype* label = (*bottom)[1]->cpu_data();
-  int num = (*bottom)[0]->num();
-  int count = (*bottom)[0]->count();
-  int dim = count / num;
-
-  for (int i = 0; i < num; ++i) {
-    bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down[1]) {
+    LOG(FATAL) << this->type_name()
+               << " Layer cannot backpropagate to label inputs.";
   }
+  if (propagate_down[0]) {
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const Dtype* label = (*bottom)[1]->cpu_data();
+    int num = (*bottom)[0]->num();
+    int count = (*bottom)[0]->count();
+    int dim = count / num;
 
-  switch (this->layer_param_.hinge_loss_param().norm()) {
-  case HingeLossParameter_Norm_L1:
-    caffe_cpu_sign(count, bottom_diff, bottom_diff);
-    caffe_scal(count, Dtype(1. / num), bottom_diff);
-    break;
-  case HingeLossParameter_Norm_L2:
-    caffe_scal(count, Dtype(2. / num), bottom_diff);
-    break;
-  default:
-    LOG(FATAL) << "Unknown Norm";
+    for (int i = 0; i < num; ++i) {
+      bottom_diff[i * dim + static_cast<int>(label[i])] *= -1;
+    }
+
+    switch (this->layer_param_.hinge_loss_param().norm()) {
+    case HingeLossParameter_Norm_L1:
+      caffe_cpu_sign(count, bottom_diff, bottom_diff);
+      caffe_scal(count, Dtype(1. / num), bottom_diff);
+      break;
+    case HingeLossParameter_Norm_L2:
+      caffe_scal(count, Dtype(2. / num), bottom_diff);
+      break;
+    default:
+      LOG(FATAL) << "Unknown Norm";
+    }
   }
 }
 
index f0c26c9..e047dfb 100644 (file)
@@ -38,7 +38,7 @@ Dtype Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
   for (int n = 0; n < top[0]->num(); ++n) {
index 26bc1b9..9cfb74e 100644 (file)
@@ -23,7 +23,7 @@ Dtype Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
   for (int n = 0; n < top[0]->num(); ++n) {
index 3e7fc4f..a72874e 100644 (file)
@@ -56,20 +56,26 @@ Dtype InfogainLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void InfogainLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
-  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
-  const Dtype* infogain_mat = infogain_.cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  int num = (*bottom)[0]->num();
-  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
-  CHECK_EQ(infogain_.height(), dim);
-  for (int i = 0; i < num; ++i) {
-    int label = static_cast<int>(bottom_label[i]);
-    for (int j = 0; j < dim; ++j) {
-      Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
-      bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
+  if (propagate_down[1]) {
+    LOG(FATAL) << this->type_name()
+               << " Layer cannot backpropagate to label inputs.";
+  }
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* bottom_label = (*bottom)[1]->cpu_data();
+    const Dtype* infogain_mat = infogain_.cpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    int num = (*bottom)[0]->num();
+    int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
+    CHECK_EQ(infogain_.height(), dim);
+    for (int i = 0; i < num; ++i) {
+      int label = static_cast<int>(bottom_label[i]);
+      for (int j = 0; j < dim; ++j) {
+        Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD));
+        bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num;
+      }
     }
   }
 }
index 971254c..ddf55e4 100644 (file)
@@ -74,7 +74,7 @@ Dtype InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   const Dtype* bottom_data = (*bottom)[0]->cpu_data();
@@ -87,7 +87,7 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
         reinterpret_cast<const Dtype*>(bias_multiplier_->cpu_data()), (Dtype)0.,
         this->blobs_[1]->mutable_cpu_diff());
   }
-  if (propagate_down) {
+  if (propagate_down[0]) {
     // Gradient with respect to bottom data
     caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
         top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
index f139c23..5b95a57 100644 (file)
@@ -31,7 +31,7 @@ Dtype InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* bottom_data = (*bottom)[0]->gpu_data();
@@ -44,7 +44,7 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
         reinterpret_cast<const Dtype*>(bias_multiplier_->gpu_data()),
         (Dtype)0., this->blobs_[1]->mutable_gpu_diff());
   }
-  if (propagate_down) {
+  if (propagate_down[0]) {
     // Gradient with respect to bottom data
     caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
         top_diff, this->blobs_[0]->gpu_data(), (Dtype)0.,
index 071e719..a86c1d4 100644 (file)
@@ -173,7 +173,7 @@ Dtype LRNLayer<Dtype>::WithinChannelForward(
 
 template <typename Dtype>
 void LRNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   switch (this->layer_param_.lrn_param().norm_region()) {
   case LRNParameter_NormRegion_ACROSS_CHANNELS:
     CrossChannelBackward_cpu(top, propagate_down, bottom);
@@ -188,7 +188,7 @@ void LRNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
 
 template <typename Dtype>
 void LRNLayer<Dtype>::CrossChannelBackward_cpu(
-    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   const Dtype* top_data = top[0]->cpu_data();
@@ -243,14 +243,16 @@ void LRNLayer<Dtype>::CrossChannelBackward_cpu(
 
 template <typename Dtype>
 void LRNLayer<Dtype>::WithinChannelBackward(
-    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
-    product_layer_->Backward(top, true, &product_bottom_vec_);
-    power_layer_->Backward(power_top_vec_, true, &pool_top_vec_);
-    pool_layer_->Backward(pool_top_vec_, true, &square_top_vec_);
-    square_layer_->Backward(square_top_vec_, true, &square_bottom_vec_);
-    split_layer_->Backward(split_top_vec_, true, bottom);
+  if (propagate_down[0]) {
+    vector<bool> product_propagate_down(2, true);
+    product_layer_->Backward(top, product_propagate_down, &product_bottom_vec_);
+    power_layer_->Backward(power_top_vec_, propagate_down, &pool_top_vec_);
+    pool_layer_->Backward(pool_top_vec_, propagate_down, &square_top_vec_);
+    square_layer_->Backward(square_top_vec_, propagate_down,
+                            &square_bottom_vec_);
+    split_layer_->Backward(split_top_vec_, propagate_down, bottom);
   }
 }
 
index b2097eb..237a42b 100644 (file)
@@ -104,7 +104,7 @@ Dtype LRNLayer<Dtype>::CrossChannelForward_gpu(
 
 template <typename Dtype>
 void LRNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   switch (this->layer_param_.lrn_param().norm_region()) {
   case LRNParameter_NormRegion_ACROSS_CHANNELS:
     CrossChannelBackward_gpu(top, propagate_down, bottom);
@@ -180,7 +180,7 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
 
 template <typename Dtype>
 void LRNLayer<Dtype>::CrossChannelBackward_gpu(
-    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   int n_threads = num_ * height_ * width_;
   // NOLINT_NEXT_LINE(whitespace/operators)
index 5a40879..013d403 100644 (file)
@@ -43,18 +43,24 @@ Dtype MultinomialLogisticLossLayer<Dtype>::Forward_cpu(
 
 template <typename Dtype>
 void MultinomialLogisticLossLayer<Dtype>::Backward_cpu(
-    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  const Dtype* bottom_data = (*bottom)[0]->cpu_data();
-  const Dtype* bottom_label = (*bottom)[1]->cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  int num = (*bottom)[0]->num();
-  int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
-  memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
-  for (int i = 0; i < num; ++i) {
-    int label = static_cast<int>(bottom_label[i]);
-    Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
-    bottom_diff[i * dim + label] = -1. / prob / num;
+  if (propagate_down[1]) {
+    LOG(FATAL) << this->type_name()
+               << " Layer cannot backpropagate to label inputs.";
+  }
+  if (propagate_down[0]) {
+    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
+    const Dtype* bottom_label = (*bottom)[1]->cpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    int num = (*bottom)[0]->num();
+    int dim = (*bottom)[0]->count() / (*bottom)[0]->num();
+    memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count());
+    for (int i = 0; i < num; ++i) {
+      int label = static_cast<int>(bottom_label[i]);
+      Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD));
+      bottom_diff[i * dim + label] = -1. / prob / num;
+    }
   }
 }
 
index ba84ede..bc00207 100644 (file)
@@ -184,8 +184,8 @@ Dtype PoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void PoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  if (!propagate_down) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down[0]) {
     return;
   }
   const Dtype* top_diff = top[0]->cpu_diff();
index abba625..f07fe3c 100644 (file)
@@ -322,8 +322,8 @@ __global__ void StoPoolBackward(const int nthreads,
 
 template <typename Dtype>
 void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  if (!propagate_down) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (!propagate_down[0]) {
     return;
   }
   const Dtype* top_diff = top[0]->gpu_diff();
index 85c8442..5ff3392 100644 (file)
@@ -49,9 +49,9 @@ Dtype PowerLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void PowerLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
     const int count = (*bottom)[0]->count();
     const Dtype* top_diff = top[0]->cpu_diff();
index 9a25de7..6d69963 100644 (file)
@@ -38,9 +38,9 @@ Dtype PowerLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void PowerLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
     const Dtype* top_diff = top[0]->gpu_diff();
index 7a33e55..d7a8509 100644 (file)
@@ -24,9 +24,9 @@ Dtype ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* bottom_data = (*bottom)[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
index 51e5ef2..a5b24af 100644 (file)
@@ -45,9 +45,9 @@ __global__ void ReLUBackward(const int n, const Dtype* in_diff,
 
 template <typename Dtype>
 void ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* bottom_data = (*bottom)[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
index 955581d..8cb830f 100644 (file)
@@ -16,7 +16,7 @@ template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::FurtherSetUp(
     const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
   CHECK_EQ(bottom[0]->count(), bottom[1]->count()) <<
-      "SigmoidCrossEntropyLoss Layer inputs must have same count.";
+      "SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count.";
   sigmoid_bottom_vec_.clear();
   sigmoid_bottom_vec_.push_back(bottom[0]);
   sigmoid_top_vec_.clear();
@@ -49,17 +49,23 @@ Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
 
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
-    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  // First, compute the diff
-  const int count = (*bottom)[0]->count();
-  const int num = (*bottom)[0]->num();
-  const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
-  const Dtype* target = (*bottom)[1]->cpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  caffe_sub(count, sigmoid_output_data, target, bottom_diff);
-  // Scale down gradient
-  caffe_scal(count, Dtype(1) / num, bottom_diff);
+  if (propagate_down[1]) {
+    LOG(FATAL) << this->type_name()
+               << " Layer cannot backpropagate to label inputs.";
+  }
+  if (propagate_down[0]) {
+    // First, compute the diff
+    const int count = (*bottom)[0]->count();
+    const int num = (*bottom)[0]->num();
+    const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+    const Dtype* target = (*bottom)[1]->cpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    caffe_sub(count, sigmoid_output_data, target, bottom_diff);
+    // Scale down gradient
+    caffe_scal(count, Dtype(1) / num, bottom_diff);
+  }
 }
 
 INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
index 0caed2b..8f72758 100644 (file)
@@ -37,18 +37,24 @@ Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
 
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
-    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  // First, compute the diff
-  const int count = (*bottom)[0]->count();
-  const int num = (*bottom)[0]->num();
-  const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
-  const Dtype* target = (*bottom)[1]->gpu_data();
-  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
-  caffe_gpu_copy(count, sigmoid_output_data, bottom_diff);
-  caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
-  // Scale down gradient
-  caffe_gpu_scal(count, Dtype(1) / num, bottom_diff);
+  if (propagate_down[1]) {
+    LOG(FATAL) << this->type_name()
+               << " Layer cannot backpropagate to label inputs.";
+  }
+  if (propagate_down[0]) {
+    // First, compute the diff
+    const int count = (*bottom)[0]->count();
+    const int num = (*bottom)[0]->num();
+    const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+    const Dtype* target = (*bottom)[1]->gpu_data();
+    Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    caffe_gpu_copy(count, sigmoid_output_data, bottom_diff);
+    caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
+    // Scale down gradient
+    caffe_gpu_scal(count, Dtype(1) / num, bottom_diff);
+  }
 }
 
 INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
index 88a7920..50139d8 100644 (file)
@@ -28,9 +28,9 @@ Dtype SigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void SigmoidLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_data = top[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
index aa8568a..d55968e 100644 (file)
@@ -47,9 +47,9 @@ __global__ void SigmoidBackward(const int n, const Dtype* in_diff,
 
 template <typename Dtype>
 void SigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_data = top[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
index dbe16da..57847d0 100644 (file)
@@ -60,7 +60,7 @@ Dtype SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->cpu_diff();
   const Dtype* top_data = top[0]->cpu_data();
index a264a81..f53883c 100644 (file)
@@ -79,7 +79,7 @@ Dtype SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 // TODO(Yangqing): implement the GPU version of softmax.
 template <typename Dtype>
 void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* top_data = top[0]->gpu_data();
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
index bdb3272..1a3601a 100644 (file)
@@ -57,20 +57,25 @@ Dtype SoftmaxWithLossLayer<Dtype>::Forward_cpu(
 
 template <typename Dtype>
 void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  // Compute the diff
-  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
-  const Dtype* prob_data = prob_.cpu_data();
-  memcpy(bottom_diff, prob_data, sizeof(Dtype) * prob_.count());
-  const Dtype* label = (*bottom)[1]->cpu_data();
-  int num = prob_.num();
-  int dim = prob_.count() / num;
-  for (int i = 0; i < num; ++i) {
-    bottom_diff[i * dim + static_cast<int>(label[i])] -= 1;
+  if (propagate_down[1]) {
+    LOG(FATAL) << this->type_name()
+               << " Layer cannot backpropagate to label inputs.";
+  }
+  if (propagate_down[0]) {
+    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const Dtype* prob_data = prob_.cpu_data();
+    memcpy(bottom_diff, prob_data, sizeof(Dtype) * prob_.count());
+    const Dtype* label = (*bottom)[1]->cpu_data();
+    int num = prob_.num();
+    int dim = prob_.count() / num;
+    for (int i = 0; i < num; ++i) {
+      bottom_diff[i * dim + static_cast<int>(label[i])] -= 1;
+    }
+    // Scale down gradient
+    caffe_scal(prob_.count(), Dtype(1) / num, bottom_diff);
   }
-  // Scale down gradient
-  caffe_scal(prob_.count(), Dtype(1) / num, bottom_diff);
 }
 
 
index 24a3c38..e46be6b 100644 (file)
@@ -21,7 +21,7 @@ Dtype SoftmaxWithLossLayer<Dtype>::Forward_gpu(
 
 template <typename Dtype>
 void SoftmaxWithLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   // TODO(Yangqing): implement the GPU version of softmax.
   Backward_cpu(top, propagate_down, bottom);
 }
index 2f99ca1..28abd95 100644 (file)
@@ -37,8 +37,8 @@ Dtype SplitLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void SplitLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down[0]) {
     (*bottom)[0]->ShareDiff(*top[0]);
     // Add remaining top blob diffs.
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
index e2269b8..4c921d3 100644 (file)
@@ -19,8 +19,8 @@ Dtype SplitLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void SplitLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  if (propagate_down[0]) {
     (*bottom)[0]->ShareDiff(*top[0]);
     // Add remaining top blob diffs.
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
index 77b4410..6b5166d 100644 (file)
@@ -26,9 +26,9 @@ Dtype TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_data = top[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
index aa822d8..9307cab 100644 (file)
@@ -42,9 +42,9 @@ __global__ void TanHBackward(const int n, const Dtype* in_diff,
 
 template <typename Dtype>
 void TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_data = top[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
index d76d284..8687d00 100644 (file)
@@ -56,6 +56,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   top_vecs_.resize(param.layers_size());
   bottom_id_vecs_.resize(param.layers_size());
   top_id_vecs_.resize(param.layers_size());
+  bottom_need_backward_.resize(param.layers_size());
   for (int layer_id = 0; layer_id < param.layers_size(); ++layer_id) {
     const LayerParameter& layer_param = param.layers(layer_id);
     layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
@@ -196,6 +197,8 @@ int Net<Dtype>::AppendBottom(const NetParameter& param,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
+  const bool need_backward = param.force_backward() || blob_need_backward_[blob_id];
+  bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
 
@@ -288,7 +291,8 @@ template <typename Dtype>
 void Net<Dtype>::Backward() {
   for (int i = layers_.size() - 1; i >= 0; --i) {
     if (layer_need_backward_[i]) {
-      layers_[i]->Backward(top_vecs_[i], true, &bottom_vecs_[i]);
+      layers_[i]->Backward(
+          top_vecs_[i], bottom_need_backward_[i], &bottom_vecs_[i]);
     }
   }
 }
index d5e4107..02945b6 100644 (file)
@@ -47,13 +47,23 @@ typedef ::testing::Types<float, double> Dtypes;
 TYPED_TEST_CASE(EuclideanLossLayerTest, Dtypes);
 
 TYPED_TEST(EuclideanLossLayerTest, TestGradientCPU) {
-  LayerParameter layer_param;
   Caffe::set_mode(Caffe::CPU);
+  LayerParameter layer_param;
+  EuclideanLossLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
+  checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_), -1, -1, -1);
+}
+
+TYPED_TEST(EuclideanLossLayerTest, TestGradientGPU) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
   EuclideanLossLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
   GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
   checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
-      &(this->blob_top_vec_), 0, -1, -1);
+      &(this->blob_top_vec_), -1, -1, -1);
 }
 
 }  // namespace caffe
index bcf0397..ff104b9 100644 (file)
@@ -86,6 +86,7 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
   }
   // First, figure out what blobs we need to check against.
   vector<Blob<Dtype>*> blobs_to_check;
+  vector<bool> propagate_down(bottom->size(), check_bottom < 0);
   for (int i = 0; i < layer->blobs().size(); ++i) {
     blobs_to_check.push_back(layer->blobs()[i].get());
   }
@@ -94,8 +95,9 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
       blobs_to_check.push_back((*bottom)[i]);
     }
   } else {
-    CHECK(check_bottom < bottom->size());
+    CHECK_LT(check_bottom, bottom->size());
     blobs_to_check.push_back((*bottom)[check_bottom]);
+    propagate_down[check_bottom] = true;
   }
   // Compute the gradient analytically using Backward
   Caffe::set_random_seed(seed_);
@@ -103,7 +105,7 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
   Dtype computed_objective = layer->Forward(*bottom, top);
   // Get additional loss from the objective
   computed_objective += GetObjAndGradient(top, top_id, top_data_id);
-  layer->Backward(*top, true, bottom);
+  layer->Backward(*top, propagate_down, bottom);
   // Store computed gradients for all checked blobs
   vector<shared_ptr<Blob<Dtype> > >
       computed_gradient_blobs(blobs_to_check.size());
index 7f0a211..6996a23 100644 (file)
@@ -164,7 +164,9 @@ TYPED_TEST(LRNLayerTest, TestCPUGradientAcrossChannels) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 &(this->blob_bottom_vec_));
   // for (int i = 0; i < this->blob_bottom_->count(); ++i) {
   //   std::cout << "CPU diff " << this->blob_bottom_->cpu_diff()[i]
   //       << std::endl;
@@ -183,7 +185,9 @@ TYPED_TEST(LRNLayerTest, TestGPUGradientAcrossChannels) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 &(this->blob_bottom_vec_));
   // for (int i = 0; i < this->blob_bottom_->count(); ++i) {
   //   std::cout << "GPU diff " << this->blob_bottom_->cpu_diff()[i]
   //       << std::endl;
@@ -254,7 +258,6 @@ TYPED_TEST(LRNLayerTest, TestCPUGradientWithinChannel) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
   checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
 }
@@ -272,7 +275,6 @@ TYPED_TEST(LRNLayerTest, TestGPUGradientWithinChannel) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
   checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
 }
index 3862e12..ac6f86f 100644 (file)
@@ -130,7 +130,9 @@ TYPED_TEST(MaxPoolingDropoutTest, CPUBackward) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 &(this->blob_bottom_vec_));
   const TypeParam* bottom_diff = this->blob_bottom_->cpu_diff();
   TypeParam sum = 0.;
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
@@ -141,8 +143,10 @@ TYPED_TEST(MaxPoolingDropoutTest, CPUBackward) {
   DropoutLayer<TypeParam> dropout_layer(layer_param);
   dropout_layer.SetUp(this->blob_top_vec_, &(this->blob_top_vec_));
   dropout_layer.Forward(this->blob_top_vec_, &(this->blob_top_vec_));
-  dropout_layer.Backward(this->blob_top_vec_, true, &(this->blob_top_vec_));
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  dropout_layer.Backward(this->blob_top_vec_, propagate_down,
+                         &(this->blob_top_vec_));
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 &(this->blob_bottom_vec_));
   TypeParam sum_with_dropout = 0.;
   bottom_diff = this->blob_bottom_->cpu_diff();
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
@@ -164,7 +168,9 @@ TYPED_TEST(MaxPoolingDropoutTest, GPUBackward) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 &(this->blob_bottom_vec_));
   const TypeParam* bottom_diff = this->blob_bottom_->cpu_diff();
   TypeParam sum = 0.;
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
@@ -175,8 +181,10 @@ TYPED_TEST(MaxPoolingDropoutTest, GPUBackward) {
   DropoutLayer<TypeParam> dropout_layer(layer_param);
   dropout_layer.SetUp(this->blob_top_vec_, &(this->blob_top_vec_));
   dropout_layer.Forward(this->blob_top_vec_, &(this->blob_top_vec_));
-  dropout_layer.Backward(this->blob_top_vec_, true, &(this->blob_top_vec_));
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  dropout_layer.Backward(this->blob_top_vec_, propagate_down,
+                         &(this->blob_top_vec_));
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 &(this->blob_bottom_vec_));
   TypeParam sum_with_dropout = 0.;
   bottom_diff = this->blob_bottom_->cpu_diff();
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
index 7f2f67b..eb368cd 100644 (file)
@@ -1,7 +1,9 @@
 // Copyright 2014 BVLC and contributors.
 
-#include <google/protobuf/text_format.h>
 #include <string>
+#include <vector>
+
+#include "google/protobuf/text_format.h"
 
 #include "gtest/gtest.h"
 #include "caffe/common.hpp"
@@ -15,9 +17,15 @@ namespace caffe {
 template <typename Dtype>
 class NetTest : public ::testing::Test {
  protected:
-  virtual void SetUp() {
-    const string& proto =
-        "name: 'TestNetwork' "
+  virtual void InitNetFromProtoString(const string& proto) {
+    NetParameter param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
+    net_.reset(new Net<Dtype>(param));
+  }
+
+  virtual void InitTinyNet(const bool force_backward = false) {
+    string proto =
+        "name: 'TinyTestNetwork' "
         "layers: { "
         "  name: 'data' "
         "  type: DUMMY_DATA "
@@ -66,9 +74,84 @@ class NetTest : public ::testing::Test {
         "  bottom: 'label' "
         "  top: 'top_loss' "
         "} ";
-    NetParameter param;
-    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
-    net_.reset(new Net<Dtype>(param));
+    if (force_backward) {
+      proto += "force_backward: true ";
+    }
+    InitNetFromProtoString(proto);
+  }
+
+  virtual void InitTrickyNet() {
+    const string& proto =
+        "name: 'TrickyTestNetwork' "
+        "layers: { "
+        "  name: 'data' "
+        "  type: DUMMY_DATA "
+        "  dummy_data_param { "
+        "    num: 5 "
+        "    channels: 2 "
+        "    height: 3 "
+        "    width: 4 "
+        "    num: 5 "
+        "    channels: 1 "
+        "    height: 1 "
+        "    width: 1 "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "  } "
+        "  top: 'data' "
+        "  top: 'label' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 1000 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    bias_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "  } "
+        "  blobs_lr: 1. "
+        "  blobs_lr: 2. "
+        "  weight_decay: 1. "
+        "  weight_decay: 0. "
+        "  bottom: 'data' "
+        "  top: 'transformed_data' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 1 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    bias_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "  } "
+        "  blobs_lr: 1. "
+        "  blobs_lr: 2. "
+        "  weight_decay: 1. "
+        "  weight_decay: 0. "
+        "  bottom: 'label' "
+        "  top: 'transformed_label' "
+        "} "
+        "layers: { "
+        "  name: 'loss' "
+        "  type: SOFTMAX_LOSS "
+        "  bottom: 'transformed_data' "
+        "  bottom: 'transformed_label' "
+        "} ";
+    InitNetFromProtoString(proto);
   }
 
   shared_ptr<Net<Dtype> > net_;
@@ -78,6 +161,7 @@ typedef ::testing::Types<float, double> Dtypes;
 TYPED_TEST_CASE(NetTest, Dtypes);
 
 TYPED_TEST(NetTest, TestHasBlob) {
+  this->InitTinyNet();
   EXPECT_TRUE(this->net_->has_blob("data"));
   EXPECT_TRUE(this->net_->has_blob("label"));
   EXPECT_TRUE(this->net_->has_blob("innerproduct"));
@@ -86,6 +170,7 @@ TYPED_TEST(NetTest, TestHasBlob) {
 }
 
 TYPED_TEST(NetTest, TestGetBlob) {
+  this->InitTinyNet();
   EXPECT_EQ(this->net_->blob_by_name("data"), this->net_->blobs()[0]);
   EXPECT_EQ(this->net_->blob_by_name("label"), this->net_->blobs()[1]);
   EXPECT_EQ(this->net_->blob_by_name("innerproduct"), this->net_->blobs()[2]);
@@ -94,6 +179,7 @@ TYPED_TEST(NetTest, TestGetBlob) {
 }
 
 TYPED_TEST(NetTest, TestHasLayer) {
+  this->InitTinyNet();
   EXPECT_TRUE(this->net_->has_layer("data"));
   EXPECT_TRUE(this->net_->has_layer("innerproduct"));
   EXPECT_TRUE(this->net_->has_layer("loss"));
@@ -101,10 +187,53 @@ TYPED_TEST(NetTest, TestHasLayer) {
 }
 
 TYPED_TEST(NetTest, TestGetLayerByName) {
+  this->InitTinyNet();
   EXPECT_EQ(this->net_->layer_by_name("data"), this->net_->layers()[0]);
   EXPECT_EQ(this->net_->layer_by_name("innerproduct"), this->net_->layers()[1]);
   EXPECT_EQ(this->net_->layer_by_name("loss"), this->net_->layers()[2]);
   EXPECT_FALSE(this->net_->layer_by_name("label"));
 }
 
+TYPED_TEST(NetTest, TestBottomNeedBackward) {
+  this->InitTinyNet();
+  const vector<vector<bool> >& bottom_need_backward =
+      this->net_->bottom_need_backward();
+  EXPECT_EQ(3, bottom_need_backward.size());
+  EXPECT_EQ(0, bottom_need_backward[0].size());
+  EXPECT_EQ(1, bottom_need_backward[1].size());
+  EXPECT_EQ(false, bottom_need_backward[1][0]);
+  EXPECT_EQ(2, bottom_need_backward[2].size());
+  EXPECT_EQ(true, bottom_need_backward[2][0]);
+  EXPECT_EQ(false, bottom_need_backward[2][1]);
+}
+
+TYPED_TEST(NetTest, TestBottomNeedBackwardForce) {
+  const bool force_backward = true;
+  this->InitTinyNet(force_backward);
+  const vector<vector<bool> >& bottom_need_backward =
+      this->net_->bottom_need_backward();
+  EXPECT_EQ(3, bottom_need_backward.size());
+  EXPECT_EQ(0, bottom_need_backward[0].size());
+  EXPECT_EQ(1, bottom_need_backward[1].size());
+  EXPECT_EQ(true, bottom_need_backward[1][0]);
+  EXPECT_EQ(2, bottom_need_backward[2].size());
+  EXPECT_EQ(true, bottom_need_backward[2][0]);
+  EXPECT_EQ(true, bottom_need_backward[2][1]);
+}
+
+TYPED_TEST(NetTest, TestBottomNeedBackwardTricky) {
+  this->InitTrickyNet();
+  const vector<vector<bool> >& bottom_need_backward =
+      this->net_->bottom_need_backward();
+  EXPECT_EQ(4, bottom_need_backward.size());
+  EXPECT_EQ(0, bottom_need_backward[0].size());
+  EXPECT_EQ(1, bottom_need_backward[1].size());
+  EXPECT_EQ(false, bottom_need_backward[1][0]);
+  EXPECT_EQ(1, bottom_need_backward[2].size());
+  EXPECT_EQ(false, bottom_need_backward[2][0]);
+  EXPECT_EQ(2, bottom_need_backward[3].size());
+  EXPECT_EQ(true, bottom_need_backward[3][0]);
+  EXPECT_EQ(true, bottom_need_backward[3][1]);
+}
+
 }  // namespace caffe
index 36a0077..4b65cfd 100644 (file)
@@ -64,6 +64,8 @@ int main(int argc, char** argv) {
   const vector<shared_ptr<Layer<float> > >& layers = caffe_net.layers();
   vector<vector<Blob<float>*> >& bottom_vecs = caffe_net.bottom_vecs();
   vector<vector<Blob<float>*> >& top_vecs = caffe_net.top_vecs();
+  const vector<vector<bool> >& bottom_need_backward =
+      caffe_net.bottom_need_backward();
   LOG(ERROR) << "*** Benchmark begins ***";
   Timer total_timer;
   total_timer.Start();
@@ -87,7 +89,8 @@ int main(int argc, char** argv) {
     const string& layername = layers[i]->layer_param().name();
     timer.Start();
     for (int j = 0; j < total_iter; ++j) {
-      layers[i]->Backward(top_vecs[i], true, &bottom_vecs[i]);
+      layers[i]->Backward(top_vecs[i], bottom_need_backward[i],
+                          &bottom_vecs[i]);
     }
     LOG(ERROR) << layername << "\tbackward: "
         << timer.MilliSeconds() << " milli seconds.";