From 74443591d62feb2755d98ed87fb4aecc22d78297 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Tue, 10 Jun 2014 12:38:59 -0700 Subject: [PATCH] change Backward interface: propagate_down is a vector -- use to fix long-standing issue with how this is handled in loss layers (esp. EuclideanLossLayer) --- include/caffe/data_layers.hpp | 28 ++-- include/caffe/layer.hpp | 8 +- include/caffe/loss_layers.hpp | 27 ++-- include/caffe/net.hpp | 4 + include/caffe/neuron_layers.hpp | 26 ++-- include/caffe/vision_layers.hpp | 48 +++---- src/caffe/layers/bnll_layer.cpp | 4 +- src/caffe/layers/bnll_layer.cu | 4 +- src/caffe/layers/concat_layer.cpp | 22 ++-- src/caffe/layers/concat_layer.cu | 22 ++-- src/caffe/layers/conv_layer.cpp | 4 +- src/caffe/layers/conv_layer.cu | 4 +- src/caffe/layers/dropout_layer.cpp | 4 +- src/caffe/layers/dropout_layer.cu | 4 +- src/caffe/layers/eltwise_layer.cpp | 12 +- src/caffe/layers/eltwise_layer.cu | 12 +- src/caffe/layers/euclidean_loss_layer.cpp | 24 ++-- src/caffe/layers/euclidean_loss_layer.cu | 45 +++++++ src/caffe/layers/flatten_layer.cpp | 2 +- src/caffe/layers/flatten_layer.cu | 2 +- src/caffe/layers/hdf5_data_layer.cpp | 5 - src/caffe/layers/hdf5_data_layer.cu | 5 - src/caffe/layers/hdf5_output_layer.cpp | 2 +- src/caffe/layers/hdf5_output_layer.cu | 2 +- src/caffe/layers/hinge_loss_layer.cpp | 44 ++++--- src/caffe/layers/im2col_layer.cpp | 2 +- src/caffe/layers/im2col_layer.cu | 2 +- src/caffe/layers/infogain_loss_layer.cpp | 32 +++-- src/caffe/layers/inner_product_layer.cpp | 4 +- src/caffe/layers/inner_product_layer.cu | 4 +- src/caffe/layers/lrn_layer.cpp | 20 +-- src/caffe/layers/lrn_layer.cu | 4 +- .../layers/multinomial_logistic_loss_layer.cpp | 28 ++-- src/caffe/layers/pooling_layer.cpp | 4 +- src/caffe/layers/pooling_layer.cu | 4 +- src/caffe/layers/power_layer.cpp | 4 +- src/caffe/layers/power_layer.cu | 4 +- src/caffe/layers/relu_layer.cpp | 4 +- src/caffe/layers/relu_layer.cu | 4 +- .../layers/sigmoid_cross_entropy_loss_layer.cpp | 28 ++-- .../layers/sigmoid_cross_entropy_loss_layer.cu | 28 ++-- src/caffe/layers/sigmoid_layer.cpp | 4 +- src/caffe/layers/sigmoid_layer.cu | 4 +- src/caffe/layers/softmax_layer.cpp | 2 +- src/caffe/layers/softmax_layer.cu | 2 +- src/caffe/layers/softmax_loss_layer.cpp | 29 +++-- src/caffe/layers/softmax_loss_layer.cu | 2 +- src/caffe/layers/split_layer.cpp | 4 +- src/caffe/layers/split_layer.cu | 4 +- src/caffe/layers/tanh_layer.cpp | 4 +- src/caffe/layers/tanh_layer.cu | 4 +- src/caffe/net.cpp | 6 +- src/caffe/test/test_euclidean_loss_layer.cpp | 14 +- src/caffe/test/test_gradient_check_util.hpp | 6 +- src/caffe/test/test_lrn_layer.cpp | 10 +- src/caffe/test/test_maxpool_dropout_layers.cpp | 20 ++- src/caffe/test/test_net.cpp | 143 ++++++++++++++++++++- tools/net_speed_benchmark.cpp | 5 +- 58 files changed, 522 insertions(+), 281 deletions(-) create mode 100644 src/caffe/layers/euclidean_loss_layer.cu diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index fc93daa..2c6be55 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -47,9 +47,9 @@ class HDF5OutputLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void SaveBlobs(); std::string 
file_name_; @@ -80,9 +80,9 @@ class HDF5DataLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom) {} virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom) {} virtual void LoadHDF5FileData(const char* filename); std::vector hdf_filenames_; @@ -125,9 +125,9 @@ class DataLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void CreatePrefetchThread(); virtual void JoinPrefetchThread(); @@ -175,9 +175,9 @@ class DummyDataLayer : public Layer { virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} vector > > fillers_; vector refill_; @@ -211,9 +211,9 @@ class ImageDataLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void ShuffleImages(); @@ -263,9 +263,9 @@ class MemoryDataLayer : public Layer { virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} Dtype* data_; Dtype* labels_; @@ -306,9 +306,9 @@ class WindowDataLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { return; } + const vector& propagate_down, vector*>* bottom) {} virtual void CreatePrefetchThread(); virtual void JoinPrefetchThread(); diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp index af3d544..12e7610 100644 --- a/include/caffe/layer.hpp +++ b/include/caffe/layer.hpp @@ -46,7 +46,7 @@ class Layer { inline Dtype Forward(const vector*>& bottom, vector*>* top); inline void Backward(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom); // Returns the vector of blobs. @@ -104,10 +104,10 @@ class Layer { // Backward functions: compute the gradients for any parameters and // for the bottom blobs if propagate_down is true. 
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down,
+      const vector<bool>& propagate_down,
       vector<Blob<Dtype>*>* bottom) = 0;
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down,
+      const vector<bool>& propagate_down,
       vector<Blob<Dtype>*>* bottom) {
     // LOG(WARNING) << "Using CPU code as backup.";
     Backward_cpu(top, propagate_down, bottom);
@@ -172,7 +172,7 @@ inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 inline void Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
   switch (Caffe::mode()) {
   case Caffe::CPU:
diff --git a/include/caffe/loss_layers.hpp b/include/caffe/loss_layers.hpp
index db7c63e..b8adc99 100644
--- a/include/caffe/loss_layers.hpp
+++ b/include/caffe/loss_layers.hpp
@@ -62,7 +62,6 @@ class SoftmaxWithLossLayer : public Layer<Dtype> {
   virtual inline LayerParameter_LayerType type() const {
     return LayerParameter_LayerType_SOFTMAX_LOSS;
   }
-  virtual inline int ExactNumBottomBlobs() const { return 2; }
   virtual inline int MaxTopBlobs() const { return 2; }
 
  protected:
@@ -71,9 +70,9 @@ class SoftmaxWithLossLayer : public Layer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   shared_ptr<SoftmaxLayer<Dtype> > softmax_layer_;
   // prob stores the output probability of the layer.
@@ -105,9 +104,9 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top);
   virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
-      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
 
   shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
   // sigmoid_output stores the output of the sigmoid layer.
@@ -138,8 +137,12 @@ class EuclideanLossLayer : public LossLayer { protected: virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); + virtual Dtype Forward_gpu(const vector*>& bottom, + vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, vector*>* bottom); Blob diff_; }; @@ -162,7 +165,7 @@ class InfogainLossLayer : public LossLayer { virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); Blob infogain_; }; @@ -183,7 +186,7 @@ class HingeLossLayer : public LossLayer { virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); }; /* MultinomialLogisticLossLayer @@ -204,7 +207,7 @@ class MultinomialLogisticLossLayer : public LossLayer { virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); }; /* AccuracyLayer @@ -230,15 +233,11 @@ class AccuracyLayer : public Layer { virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { NOT_IMPLEMENTED; } }; -/* Also see -- SoftmaxWithLossLayer in vision_layers.hpp -*/ - } // namespace caffe #endif // CAFFE_LOSS_LAYERS_HPP_ diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp index 5b5dd22..d05ca09 100644 --- a/include/caffe/net.hpp +++ b/include/caffe/net.hpp @@ -81,6 +81,9 @@ class Net { // this unless you do per-layer checks such as gradients. inline vector*> >& bottom_vecs() { return bottom_vecs_; } inline vector*> >& top_vecs() { return top_vecs_; } + inline vector >& bottom_need_backward() { + return bottom_need_backward_; + } // returns the parameters inline vector > >& params() { return params_; } // returns the parameter learning rate multipliers @@ -131,6 +134,7 @@ class Net { // pointers. 
vector*> > bottom_vecs_; vector > bottom_id_vecs_; + vector > bottom_need_backward_; // top_vecs stores the vectors containing the output for each layer vector*> > top_vecs_; vector > top_id_vecs_; diff --git a/include/caffe/neuron_layers.hpp b/include/caffe/neuron_layers.hpp index ed664df..e52e395 100644 --- a/include/caffe/neuron_layers.hpp +++ b/include/caffe/neuron_layers.hpp @@ -64,9 +64,9 @@ class BNLLLayer : public NeuronLayer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); }; /* DropoutLayer @@ -97,9 +97,9 @@ class DropoutLayer : public NeuronLayer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); shared_ptr > rand_vec_; Dtype threshold_; @@ -131,9 +131,9 @@ class PowerLayer : public NeuronLayer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); Dtype power_; Dtype scale_; @@ -167,9 +167,9 @@ class ReLULayer : public NeuronLayer { vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); }; /* SigmoidLayer @@ -199,9 +199,9 @@ class SigmoidLayer : public NeuronLayer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); }; /* TanHLayer @@ -227,9 +227,9 @@ class TanHLayer : public NeuronLayer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); }; /* ThresholdLayer @@ -260,7 +260,7 @@ class ThresholdLayer : public NeuronLayer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { NOT_IMPLEMENTED; } diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 3fd7e2f..b68dcbf 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -44,7 +44,7 @@ class ArgMaxLayer : public Layer { virtual Dtype Forward_cpu(const 
vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { NOT_IMPLEMENTED; } bool out_max_val_; @@ -74,9 +74,9 @@ class ConcatLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); Blob col_bob_; int count_; @@ -109,9 +109,9 @@ class ConvolutionLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); int kernel_size_; int stride_; @@ -153,9 +153,9 @@ class EltwiseLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); EltwiseParameter_EltwiseOp op_; vector coeffs_; @@ -183,9 +183,9 @@ class FlattenLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); int count_; }; @@ -212,9 +212,9 @@ class Im2colLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); int kernel_size_; int stride_; @@ -246,9 +246,9 @@ class InnerProductLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); int M_; int K_; @@ -284,9 +284,9 @@ class LRNLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual Dtype CrossChannelForward_cpu(const vector*>& bottom, vector*>* top); @@ -295,11 +295,11 @@ class LRNLayer : public Layer { virtual Dtype WithinChannelForward(const vector*>& bottom, vector*>* top); virtual void CrossChannelBackward_cpu(const vector*>& top, - 
const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void CrossChannelBackward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void WithinChannelBackward(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); int size_; int pre_pad_; @@ -356,9 +356,9 @@ class PoolingLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); int max_top_blobs_; int kernel_size_; @@ -395,9 +395,9 @@ class SoftmaxLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); // sum_multiplier is just used to carry out sum using blas Blob sum_multiplier_; @@ -427,9 +427,9 @@ class SplitLayer : public Layer { virtual Dtype Forward_gpu(const vector*>& bottom, vector*>* top); virtual void Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); virtual void Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom); + const vector& propagate_down, vector*>* bottom); int count_; }; diff --git a/src/caffe/layers/bnll_layer.cpp b/src/caffe/layers/bnll_layer.cpp index d08adc4..95e6bd8 100644 --- a/src/caffe/layers/bnll_layer.cpp +++ b/src/caffe/layers/bnll_layer.cpp @@ -28,9 +28,9 @@ Dtype BNLLLayer::Forward_cpu(const vector*>& bottom, template void BNLLLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { const Dtype* bottom_data = (*bottom)[0]->cpu_data(); const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); diff --git a/src/caffe/layers/bnll_layer.cu b/src/caffe/layers/bnll_layer.cu index 75bea00..7849d11 100644 --- a/src/caffe/layers/bnll_layer.cu +++ b/src/caffe/layers/bnll_layer.cu @@ -45,9 +45,9 @@ __global__ void BNLLBackward(const int n, const Dtype* in_diff, template void BNLLLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { const Dtype* bottom_data = (*bottom)[0]->gpu_data(); const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); diff --git a/src/caffe/layers/concat_layer.cpp b/src/caffe/layers/concat_layer.cpp index 4541ee7..4c894dd 100644 --- a/src/caffe/layers/concat_layer.cpp +++ b/src/caffe/layers/concat_layer.cpp @@ -70,26 +70,30 @@ Dtype ConcatLayer::Forward_cpu(const vector*>& bottom, template void ConcatLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->cpu_diff(); if (concat_dim_ == 0) { int offset_num = 0; for (int i = 0; i < 
bottom->size(); ++i) { Blob* blob = (*bottom)[i]; - Dtype* bottom_diff = blob->mutable_cpu_diff(); - caffe_copy(blob->count(), - top_diff+top[0]->offset(offset_num), bottom_diff); + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_cpu_diff(); + caffe_copy(blob->count(), top_diff + top[0]->offset(offset_num), + bottom_diff); + } offset_num += blob->num(); } } else if (concat_dim_ == 1) { int offset_channel = 0; for (int i = 0; i < bottom->size(); ++i) { Blob* blob = (*bottom)[i]; - Dtype* bottom_diff = blob->mutable_cpu_diff(); - int num_elem = blob->channels()*blob->height()*blob->width(); - for (int n = 0; n < num_; ++n) { - caffe_copy(num_elem, top_diff+top[0]->offset(n, offset_channel), - bottom_diff+blob->offset(n)); + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_cpu_diff(); + int num_elem = blob->channels()*blob->height()*blob->width(); + for (int n = 0; n < num_; ++n) { + caffe_copy(num_elem, top_diff + top[0]->offset(n, offset_channel), + bottom_diff + blob->offset(n)); + } } offset_channel += blob->channels(); } diff --git a/src/caffe/layers/concat_layer.cu b/src/caffe/layers/concat_layer.cu index 2820bf0..ca0cf0c 100644 --- a/src/caffe/layers/concat_layer.cu +++ b/src/caffe/layers/concat_layer.cu @@ -41,26 +41,30 @@ Dtype ConcatLayer::Forward_gpu(const vector*>& bottom, template void ConcatLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->gpu_diff(); if (concat_dim_ == 0) { int offset_num = 0; for (int i = 0; i < bottom->size(); ++i) { Blob* blob = (*bottom)[i]; - Dtype* bottom_diff = blob->mutable_gpu_diff(); - caffe_gpu_copy(blob->count(), - top_diff + top[0]->offset(offset_num), bottom_diff); + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_gpu_diff(); + caffe_gpu_copy(blob->count(), top_diff + top[0]->offset(offset_num), + bottom_diff); + } offset_num += blob->num(); } } else if (concat_dim_ == 1) { int offset_channel = 0; for (int i = 0; i < bottom->size(); ++i) { Blob* blob = (*bottom)[i]; - Dtype* bottom_diff = blob->mutable_gpu_diff(); - int num_elem = blob->channels()*blob->height()*blob->width(); - for (int n = 0; n < num_; ++n) { - caffe_gpu_copy(num_elem, top_diff + top[0]->offset(n, offset_channel), - bottom_diff + blob->offset(n)); + if (propagate_down[i]) { + Dtype* bottom_diff = blob->mutable_gpu_diff(); + int num_elem = blob->channels()*blob->height()*blob->width(); + for (int n = 0; n < num_; ++n) { + caffe_gpu_copy(num_elem, top_diff + top[0]->offset(n, offset_channel), + bottom_diff + blob->offset(n)); + } } offset_channel += blob->channels(); } diff --git a/src/caffe/layers/conv_layer.cpp b/src/caffe/layers/conv_layer.cpp index 880b83a..9ec8da4 100644 --- a/src/caffe/layers/conv_layer.cpp +++ b/src/caffe/layers/conv_layer.cpp @@ -109,7 +109,7 @@ Dtype ConvolutionLayer::Forward_cpu(const vector*>& bottom, template void ConvolutionLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->cpu_diff(); const Dtype* weight = this->blobs_[0]->cpu_data(); Dtype* weight_diff = this->blobs_[0]->mutable_cpu_diff(); @@ -148,7 +148,7 @@ void ConvolutionLayer::Backward_cpu(const vector*>& top, weight_diff + weight_offset * g); } // gradient w.r.t. 
bottom data, if necessary - if (propagate_down) { + if (propagate_down[0]) { for (int g = 0; g < group_; ++g) { caffe_cpu_gemm(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1., weight + weight_offset * g, diff --git a/src/caffe/layers/conv_layer.cu b/src/caffe/layers/conv_layer.cu index 51f5d15..85f95fd 100644 --- a/src/caffe/layers/conv_layer.cu +++ b/src/caffe/layers/conv_layer.cu @@ -43,7 +43,7 @@ Dtype ConvolutionLayer::Forward_gpu(const vector*>& bottom, template void ConvolutionLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* weight = this->blobs_[0]->gpu_data(); Dtype* weight_diff = this->blobs_[0]->mutable_gpu_diff(); @@ -84,7 +84,7 @@ void ConvolutionLayer::Backward_gpu(const vector*>& top, weight_diff + weight_offset * g); } // gradient w.r.t. bottom data, if necessary - if (propagate_down) { + if (propagate_down[0]) { for (int g = 0; g < group_; ++g) { caffe_gpu_gemm(CblasTrans, CblasNoTrans, K_, N_, M_, (Dtype)1., weight + weight_offset * g, diff --git a/src/caffe/layers/dropout_layer.cpp b/src/caffe/layers/dropout_layer.cpp index f1e5411..e9a1a52 100644 --- a/src/caffe/layers/dropout_layer.cpp +++ b/src/caffe/layers/dropout_layer.cpp @@ -47,10 +47,10 @@ Dtype DropoutLayer::Forward_cpu(const vector*>& bottom, template void DropoutLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { CHECK(Caffe::phase() == Caffe::TRAIN); - if (propagate_down) { + if (propagate_down[0]) { const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); const unsigned int* mask = rand_vec_->cpu_data(); diff --git a/src/caffe/layers/dropout_layer.cu b/src/caffe/layers/dropout_layer.cu index 3c25d6a..2c72264 100644 --- a/src/caffe/layers/dropout_layer.cu +++ b/src/caffe/layers/dropout_layer.cu @@ -56,10 +56,10 @@ __global__ void DropoutBackward(const int n, const Dtype* in_diff, template void DropoutLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { CHECK(Caffe::phase() == Caffe::TRAIN); - if (propagate_down) { + if (propagate_down[0]) { const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); const unsigned int* mask = diff --git a/src/caffe/layers/eltwise_layer.cpp b/src/caffe/layers/eltwise_layer.cpp index 5e5d760..2c265f6 100644 --- a/src/caffe/layers/eltwise_layer.cpp +++ b/src/caffe/layers/eltwise_layer.cpp @@ -67,12 +67,12 @@ Dtype EltwiseLayer::Forward_cpu( template void EltwiseLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { - if (propagate_down) { - const int count = top[0]->count(); - const Dtype* top_data = top[0]->cpu_data(); - const Dtype* top_diff = top[0]->cpu_diff(); - for (int i = 0; i < bottom->size(); ++i) { + const vector& propagate_down, vector*>* bottom) { + const int count = top[0]->count(); + const Dtype* top_data = top[0]->cpu_data(); + const Dtype* top_diff = top[0]->cpu_diff(); + for (int i = 0; i < bottom->size(); ++i) { + if (propagate_down[i]) { const Dtype* bottom_data = (*bottom)[i]->cpu_data(); Dtype* bottom_diff = (*bottom)[i]->mutable_cpu_diff(); switch (op_) { diff --git a/src/caffe/layers/eltwise_layer.cu b/src/caffe/layers/eltwise_layer.cu index 75827ba..3860944 100644 --- a/src/caffe/layers/eltwise_layer.cu +++ b/src/caffe/layers/eltwise_layer.cu @@ -36,12 
+36,12 @@ Dtype EltwiseLayer<Dtype>::Forward_gpu(
 
 template <typename Dtype>
 void EltwiseLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
-    const int count = top[0]->count();
-    const Dtype* top_data = top[0]->gpu_data();
-    const Dtype* top_diff = top[0]->gpu_diff();
-    for (int i = 0; i < bottom->size(); ++i) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const int count = top[0]->count();
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  for (int i = 0; i < bottom->size(); ++i) {
+    if (propagate_down[i]) {
       const Dtype* bottom_data = (*bottom)[i]->gpu_data();
       Dtype* bottom_diff = (*bottom)[i]->mutable_gpu_diff();
       switch (op_) {
diff --git a/src/caffe/layers/euclidean_loss_layer.cpp b/src/caffe/layers/euclidean_loss_layer.cpp
index 7662949..2478a51 100644
--- a/src/caffe/layers/euclidean_loss_layer.cpp
+++ b/src/caffe/layers/euclidean_loss_layer.cpp
@@ -1,8 +1,5 @@
 // Copyright 2014 BVLC and contributors.
 
-#include
-#include
-#include
 #include <vector>
 
 #include "caffe/layer.hpp"
@@ -10,8 +7,6 @@
 #include "caffe/util/math_functions.hpp"
 #include "caffe/util/io.hpp"
 
-using std::max;
-
 namespace caffe {
 
 template <typename Dtype>
@@ -43,13 +38,18 @@ Dtype EuclideanLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void EuclideanLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-  caffe_cpu_axpby(
-      (*bottom)[0]->count(),              // count
-      Dtype(1) / (*bottom)[0]->num(),     // alpha
-      diff_.cpu_data(),                   // a
-      Dtype(0),                           // beta
-      (*bottom)[0]->mutable_cpu_diff());  // b
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  for (int i = 0; i < 2; ++i) {
+    if (propagate_down[i]) {
+      const Dtype sign = (i == 0) ? 1 : -1;
+      caffe_cpu_axpby(
+          (*bottom)[i]->count(),              // count
+          sign / (*bottom)[i]->num(),         // alpha
+          diff_.cpu_data(),                   // a
+          Dtype(0),                           // beta
+          (*bottom)[i]->mutable_cpu_diff());  // b
+    }
+  }
 }
 
 INSTANTIATE_CLASS(EuclideanLossLayer);
diff --git a/src/caffe/layers/euclidean_loss_layer.cu b/src/caffe/layers/euclidean_loss_layer.cu
new file mode 100644
index 0000000..b070ea9
--- /dev/null
+++ b/src/caffe/layers/euclidean_loss_layer.cu
@@ -0,0 +1,45 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/util/io.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+Dtype EuclideanLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  int count = bottom[0]->count();
+  caffe_gpu_sub(
+      count,
+      bottom[0]->gpu_data(),
+      bottom[1]->gpu_data(),
+      diff_.mutable_gpu_data());
+  Dtype dot;
+  caffe_gpu_dot(count, diff_.gpu_data(), diff_.gpu_data(), &dot);
+  Dtype loss = dot / bottom[0]->num() / Dtype(2);
+  return loss;
+}
+
+template <typename Dtype>
+void EuclideanLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  for (int i = 0; i < 2; ++i) {
+    if (propagate_down[i]) {
+      const Dtype sign = (i == 0) ? 1 : -1;
+      caffe_gpu_axpby(
+          (*bottom)[i]->count(),              // count
+          sign / (*bottom)[i]->num(),         // alpha
+          diff_.gpu_data(),                   // a
+          Dtype(0),                           // beta
+          (*bottom)[i]->mutable_gpu_diff());  // b
+    }
+  }
+}
+
+INSTANTIATE_CLASS(EuclideanLossLayer);
+
+}  // namespace caffe
diff --git a/src/caffe/layers/flatten_layer.cpp b/src/caffe/layers/flatten_layer.cpp
index 95f4859..9494da9 100644
--- a/src/caffe/layers/flatten_layer.cpp
+++ b/src/caffe/layers/flatten_layer.cpp
@@ -29,7 +29,7 @@ Dtype FlattenLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void FlattenLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   (*bottom)[0]->ShareDiff(*top[0]);
 }
 
diff --git a/src/caffe/layers/flatten_layer.cu b/src/caffe/layers/flatten_layer.cu
index 157eeb1..68add38 100644
--- a/src/caffe/layers/flatten_layer.cu
+++ b/src/caffe/layers/flatten_layer.cu
@@ -17,7 +17,7 @@ Dtype FlattenLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void FlattenLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   (*bottom)[0]->ShareDiff(*top[0]);
 }
 
diff --git a/src/caffe/layers/hdf5_data_layer.cpp b/src/caffe/layers/hdf5_data_layer.cpp
index 50863b8..2ba7fa7 100644
--- a/src/caffe/layers/hdf5_data_layer.cpp
+++ b/src/caffe/layers/hdf5_data_layer.cpp
@@ -114,11 +114,6 @@ Dtype HDF5DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   return Dtype(0.);
 }
 
-// The backward operations are dummy - they do not carry any computation.
-template <typename Dtype>
-void HDF5DataLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) { }
-
 INSTANTIATE_CLASS(HDF5DataLayer);
 
 }  // namespace caffe
diff --git a/src/caffe/layers/hdf5_data_layer.cu b/src/caffe/layers/hdf5_data_layer.cu
index 9c5bb5a..b2b09ef 100644
--- a/src/caffe/layers/hdf5_data_layer.cu
+++ b/src/caffe/layers/hdf5_data_layer.cu
@@ -54,11 +54,6 @@ Dtype HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   return Dtype(0.);
 }
 
-template <typename Dtype>
-void HDF5DataLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-}
-
 INSTANTIATE_CLASS(HDF5DataLayer);
 
 }  // namespace caffe
diff --git a/src/caffe/layers/hdf5_output_layer.cpp b/src/caffe/layers/hdf5_output_layer.cpp
index 0961b9b..3a513b9 100644
--- a/src/caffe/layers/hdf5_output_layer.cpp
+++ b/src/caffe/layers/hdf5_output_layer.cpp
@@ -67,7 +67,7 @@ Dtype HDF5OutputLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void HDF5OutputLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   return;
 }
 
diff --git a/src/caffe/layers/hdf5_output_layer.cu b/src/caffe/layers/hdf5_output_layer.cu
index b994825..59505ee 100644
--- a/src/caffe/layers/hdf5_output_layer.cu
+++ b/src/caffe/layers/hdf5_output_layer.cu
@@ -40,7 +40,7 @@ Dtype HDF5OutputLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   return;
 }
 
diff --git a/src/caffe/layers/hinge_loss_layer.cpp b/src/caffe/layers/hinge_loss_layer.cpp
index fd00d8c..8097761 100644
--- a/src/caffe/layers/hinge_loss_layer.cpp
+++ b/src/caffe/layers/hinge_loss_layer.cpp
@@ -45,27 +45,33 @@ Dtype HingeLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void HingeLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
-
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); - const Dtype* label = (*bottom)[1]->cpu_data(); - int num = (*bottom)[0]->num(); - int count = (*bottom)[0]->count(); - int dim = count / num; - - for (int i = 0; i < num; ++i) { - bottom_diff[i * dim + static_cast(label[i])] *= -1; + const vector& propagate_down, vector*>* bottom) { + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; } + if (propagate_down[0]) { + Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); + const Dtype* label = (*bottom)[1]->cpu_data(); + int num = (*bottom)[0]->num(); + int count = (*bottom)[0]->count(); + int dim = count / num; - switch (this->layer_param_.hinge_loss_param().norm()) { - case HingeLossParameter_Norm_L1: - caffe_cpu_sign(count, bottom_diff, bottom_diff); - caffe_scal(count, Dtype(1. / num), bottom_diff); - break; - case HingeLossParameter_Norm_L2: - caffe_scal(count, Dtype(2. / num), bottom_diff); - break; - default: - LOG(FATAL) << "Unknown Norm"; + for (int i = 0; i < num; ++i) { + bottom_diff[i * dim + static_cast(label[i])] *= -1; + } + + switch (this->layer_param_.hinge_loss_param().norm()) { + case HingeLossParameter_Norm_L1: + caffe_cpu_sign(count, bottom_diff, bottom_diff); + caffe_scal(count, Dtype(1. / num), bottom_diff); + break; + case HingeLossParameter_Norm_L2: + caffe_scal(count, Dtype(2. / num), bottom_diff); + break; + default: + LOG(FATAL) << "Unknown Norm"; + } } } diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp index f0c26c9..e047dfb 100644 --- a/src/caffe/layers/im2col_layer.cpp +++ b/src/caffe/layers/im2col_layer.cpp @@ -38,7 +38,7 @@ Dtype Im2colLayer::Forward_cpu(const vector*>& bottom, template void Im2colLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); for (int n = 0; n < top[0]->num(); ++n) { diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu index 26bc1b9..9cfb74e 100644 --- a/src/caffe/layers/im2col_layer.cu +++ b/src/caffe/layers/im2col_layer.cu @@ -23,7 +23,7 @@ Dtype Im2colLayer::Forward_gpu(const vector*>& bottom, template void Im2colLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); for (int n = 0; n < top[0]->num(); ++n) { diff --git a/src/caffe/layers/infogain_loss_layer.cpp b/src/caffe/layers/infogain_loss_layer.cpp index 3e7fc4f..a72874e 100644 --- a/src/caffe/layers/infogain_loss_layer.cpp +++ b/src/caffe/layers/infogain_loss_layer.cpp @@ -56,20 +56,26 @@ Dtype InfogainLossLayer::Forward_cpu(const vector*>& bottom, template void InfogainLossLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - const Dtype* bottom_data = (*bottom)[0]->cpu_data(); - const Dtype* bottom_label = (*bottom)[1]->cpu_data(); - const Dtype* infogain_mat = infogain_.cpu_data(); - Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); - int num = (*bottom)[0]->num(); - int dim = (*bottom)[0]->count() / (*bottom)[0]->num(); - CHECK_EQ(infogain_.height(), dim); - for (int i = 0; i < num; ++i) { - int label = static_cast(bottom_label[i]); - for (int j = 0; j < dim; ++j) { - Dtype prob = max(bottom_data[i * dim + 
j], Dtype(kLOG_THRESHOLD)); - bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num; + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + const Dtype* bottom_data = (*bottom)[0]->cpu_data(); + const Dtype* bottom_label = (*bottom)[1]->cpu_data(); + const Dtype* infogain_mat = infogain_.cpu_data(); + Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); + int num = (*bottom)[0]->num(); + int dim = (*bottom)[0]->count() / (*bottom)[0]->num(); + CHECK_EQ(infogain_.height(), dim); + for (int i = 0; i < num; ++i) { + int label = static_cast(bottom_label[i]); + for (int j = 0; j < dim; ++j) { + Dtype prob = max(bottom_data[i * dim + j], Dtype(kLOG_THRESHOLD)); + bottom_diff[i * dim + j] = - infogain_mat[label * dim + j] / prob / num; + } } } } diff --git a/src/caffe/layers/inner_product_layer.cpp b/src/caffe/layers/inner_product_layer.cpp index 971254c..ddf55e4 100644 --- a/src/caffe/layers/inner_product_layer.cpp +++ b/src/caffe/layers/inner_product_layer.cpp @@ -74,7 +74,7 @@ Dtype InnerProductLayer::Forward_cpu(const vector*>& bottom, template void InnerProductLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->cpu_diff(); const Dtype* bottom_data = (*bottom)[0]->cpu_data(); @@ -87,7 +87,7 @@ void InnerProductLayer::Backward_cpu(const vector*>& top, reinterpret_cast(bias_multiplier_->cpu_data()), (Dtype)0., this->blobs_[1]->mutable_cpu_diff()); } - if (propagate_down) { + if (propagate_down[0]) { // Gradient with respect to bottom data caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1., top_diff, this->blobs_[0]->cpu_data(), (Dtype)0., diff --git a/src/caffe/layers/inner_product_layer.cu b/src/caffe/layers/inner_product_layer.cu index f139c23..5b95a57 100644 --- a/src/caffe/layers/inner_product_layer.cu +++ b/src/caffe/layers/inner_product_layer.cu @@ -31,7 +31,7 @@ Dtype InnerProductLayer::Forward_gpu(const vector*>& bottom, template void InnerProductLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* bottom_data = (*bottom)[0]->gpu_data(); @@ -44,7 +44,7 @@ void InnerProductLayer::Backward_gpu(const vector*>& top, reinterpret_cast(bias_multiplier_->gpu_data()), (Dtype)0., this->blobs_[1]->mutable_gpu_diff()); } - if (propagate_down) { + if (propagate_down[0]) { // Gradient with respect to bottom data caffe_gpu_gemm(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1., top_diff, this->blobs_[0]->gpu_data(), (Dtype)0., diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp index 071e719..a86c1d4 100644 --- a/src/caffe/layers/lrn_layer.cpp +++ b/src/caffe/layers/lrn_layer.cpp @@ -173,7 +173,7 @@ Dtype LRNLayer::WithinChannelForward( template void LRNLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { switch (this->layer_param_.lrn_param().norm_region()) { case LRNParameter_NormRegion_ACROSS_CHANNELS: CrossChannelBackward_cpu(top, propagate_down, bottom); @@ -188,7 +188,7 @@ void LRNLayer::Backward_cpu(const vector*>& top, template void LRNLayer::CrossChannelBackward_cpu( - const vector*>& top, const bool propagate_down, + const vector*>& top, const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->cpu_diff(); 
const Dtype* top_data = top[0]->cpu_data(); @@ -243,14 +243,16 @@ void LRNLayer::CrossChannelBackward_cpu( template void LRNLayer::WithinChannelBackward( - const vector*>& top, const bool propagate_down, + const vector*>& top, const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { - product_layer_->Backward(top, true, &product_bottom_vec_); - power_layer_->Backward(power_top_vec_, true, &pool_top_vec_); - pool_layer_->Backward(pool_top_vec_, true, &square_top_vec_); - square_layer_->Backward(square_top_vec_, true, &square_bottom_vec_); - split_layer_->Backward(split_top_vec_, true, bottom); + if (propagate_down[0]) { + vector product_propagate_down(2, true); + product_layer_->Backward(top, product_propagate_down, &product_bottom_vec_); + power_layer_->Backward(power_top_vec_, propagate_down, &pool_top_vec_); + pool_layer_->Backward(pool_top_vec_, propagate_down, &square_top_vec_); + square_layer_->Backward(square_top_vec_, propagate_down, + &square_bottom_vec_); + split_layer_->Backward(split_top_vec_, propagate_down, bottom); } } diff --git a/src/caffe/layers/lrn_layer.cu b/src/caffe/layers/lrn_layer.cu index b2097eb..237a42b 100644 --- a/src/caffe/layers/lrn_layer.cu +++ b/src/caffe/layers/lrn_layer.cu @@ -104,7 +104,7 @@ Dtype LRNLayer::CrossChannelForward_gpu( template void LRNLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { switch (this->layer_param_.lrn_param().norm_region()) { case LRNParameter_NormRegion_ACROSS_CHANNELS: CrossChannelBackward_gpu(top, propagate_down, bottom); @@ -180,7 +180,7 @@ __global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data, template void LRNLayer::CrossChannelBackward_gpu( - const vector*>& top, const bool propagate_down, + const vector*>& top, const vector& propagate_down, vector*>* bottom) { int n_threads = num_ * height_ * width_; // NOLINT_NEXT_LINE(whitespace/operators) diff --git a/src/caffe/layers/multinomial_logistic_loss_layer.cpp b/src/caffe/layers/multinomial_logistic_loss_layer.cpp index 5a40879..013d403 100644 --- a/src/caffe/layers/multinomial_logistic_loss_layer.cpp +++ b/src/caffe/layers/multinomial_logistic_loss_layer.cpp @@ -43,18 +43,24 @@ Dtype MultinomialLogisticLossLayer::Forward_cpu( template void MultinomialLogisticLossLayer::Backward_cpu( - const vector*>& top, const bool propagate_down, + const vector*>& top, const vector& propagate_down, vector*>* bottom) { - const Dtype* bottom_data = (*bottom)[0]->cpu_data(); - const Dtype* bottom_label = (*bottom)[1]->cpu_data(); - Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); - int num = (*bottom)[0]->num(); - int dim = (*bottom)[0]->count() / (*bottom)[0]->num(); - memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count()); - for (int i = 0; i < num; ++i) { - int label = static_cast(bottom_label[i]); - Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD)); - bottom_diff[i * dim + label] = -1. 
/ prob / num; + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + const Dtype* bottom_data = (*bottom)[0]->cpu_data(); + const Dtype* bottom_label = (*bottom)[1]->cpu_data(); + Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); + int num = (*bottom)[0]->num(); + int dim = (*bottom)[0]->count() / (*bottom)[0]->num(); + memset(bottom_diff, 0, sizeof(Dtype) * (*bottom)[0]->count()); + for (int i = 0; i < num; ++i) { + int label = static_cast(bottom_label[i]); + Dtype prob = max(bottom_data[i * dim + label], Dtype(kLOG_THRESHOLD)); + bottom_diff[i * dim + label] = -1. / prob / num; + } } } diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp index ba84ede..bc00207 100644 --- a/src/caffe/layers/pooling_layer.cpp +++ b/src/caffe/layers/pooling_layer.cpp @@ -184,8 +184,8 @@ Dtype PoolingLayer::Forward_cpu(const vector*>& bottom, template void PoolingLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { - if (!propagate_down) { + const vector& propagate_down, vector*>* bottom) { + if (!propagate_down[0]) { return; } const Dtype* top_diff = top[0]->cpu_diff(); diff --git a/src/caffe/layers/pooling_layer.cu b/src/caffe/layers/pooling_layer.cu index abba625..f07fe3c 100644 --- a/src/caffe/layers/pooling_layer.cu +++ b/src/caffe/layers/pooling_layer.cu @@ -322,8 +322,8 @@ __global__ void StoPoolBackward(const int nthreads, template void PoolingLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { - if (!propagate_down) { + const vector& propagate_down, vector*>* bottom) { + if (!propagate_down[0]) { return; } const Dtype* top_diff = top[0]->gpu_diff(); diff --git a/src/caffe/layers/power_layer.cpp b/src/caffe/layers/power_layer.cpp index 85c8442..5ff3392 100644 --- a/src/caffe/layers/power_layer.cpp +++ b/src/caffe/layers/power_layer.cpp @@ -49,9 +49,9 @@ Dtype PowerLayer::Forward_cpu(const vector*>& bottom, template void PowerLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); const int count = (*bottom)[0]->count(); const Dtype* top_diff = top[0]->cpu_diff(); diff --git a/src/caffe/layers/power_layer.cu b/src/caffe/layers/power_layer.cu index 9a25de7..6d69963 100644 --- a/src/caffe/layers/power_layer.cu +++ b/src/caffe/layers/power_layer.cu @@ -38,9 +38,9 @@ Dtype PowerLayer::Forward_gpu(const vector*>& bottom, template void PowerLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); const int count = (*bottom)[0]->count(); const Dtype* top_diff = top[0]->gpu_diff(); diff --git a/src/caffe/layers/relu_layer.cpp b/src/caffe/layers/relu_layer.cpp index 7a33e55..d7a8509 100644 --- a/src/caffe/layers/relu_layer.cpp +++ b/src/caffe/layers/relu_layer.cpp @@ -24,9 +24,9 @@ Dtype ReLULayer::Forward_cpu(const vector*>& bottom, template void ReLULayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { const Dtype* bottom_data = (*bottom)[0]->cpu_data(); const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); 
diff --git a/src/caffe/layers/relu_layer.cu b/src/caffe/layers/relu_layer.cu index 51e5ef2..a5b24af 100644 --- a/src/caffe/layers/relu_layer.cu +++ b/src/caffe/layers/relu_layer.cu @@ -45,9 +45,9 @@ __global__ void ReLUBackward(const int n, const Dtype* in_diff, template void ReLULayer::Backward_gpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { const Dtype* bottom_data = (*bottom)[0]->gpu_data(); const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp index 955581d..8cb830f 100644 --- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp +++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp @@ -16,7 +16,7 @@ template void SigmoidCrossEntropyLossLayer::FurtherSetUp( const vector*>& bottom, vector*>* top) { CHECK_EQ(bottom[0]->count(), bottom[1]->count()) << - "SigmoidCrossEntropyLoss Layer inputs must have same count."; + "SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count."; sigmoid_bottom_vec_.clear(); sigmoid_bottom_vec_.push_back(bottom[0]); sigmoid_top_vec_.clear(); @@ -49,17 +49,23 @@ Dtype SigmoidCrossEntropyLossLayer::Forward_cpu( template void SigmoidCrossEntropyLossLayer::Backward_cpu( - const vector*>& top, const bool propagate_down, + const vector*>& top, const vector& propagate_down, vector*>* bottom) { - // First, compute the diff - const int count = (*bottom)[0]->count(); - const int num = (*bottom)[0]->num(); - const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data(); - const Dtype* target = (*bottom)[1]->cpu_data(); - Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); - caffe_sub(count, sigmoid_output_data, target, bottom_diff); - // Scale down gradient - caffe_scal(count, Dtype(1) / num, bottom_diff); + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + // First, compute the diff + const int count = (*bottom)[0]->count(); + const int num = (*bottom)[0]->num(); + const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data(); + const Dtype* target = (*bottom)[1]->cpu_data(); + Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); + caffe_sub(count, sigmoid_output_data, target, bottom_diff); + // Scale down gradient + caffe_scal(count, Dtype(1) / num, bottom_diff); + } } INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer); diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu index 0caed2b..8f72758 100644 --- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu +++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu @@ -37,18 +37,24 @@ Dtype SigmoidCrossEntropyLossLayer::Forward_gpu( template void SigmoidCrossEntropyLossLayer::Backward_gpu( - const vector*>& top, const bool propagate_down, + const vector*>& top, const vector& propagate_down, vector*>* bottom) { - // First, compute the diff - const int count = (*bottom)[0]->count(); - const int num = (*bottom)[0]->num(); - const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data(); - const Dtype* target = (*bottom)[1]->gpu_data(); - Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); - caffe_gpu_copy(count, sigmoid_output_data, bottom_diff); - caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff); - // Scale down gradient - caffe_gpu_scal(count, Dtype(1) / 
num, bottom_diff); + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + // First, compute the diff + const int count = (*bottom)[0]->count(); + const int num = (*bottom)[0]->num(); + const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data(); + const Dtype* target = (*bottom)[1]->gpu_data(); + Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); + caffe_gpu_copy(count, sigmoid_output_data, bottom_diff); + caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff); + // Scale down gradient + caffe_gpu_scal(count, Dtype(1) / num, bottom_diff); + } } INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer); diff --git a/src/caffe/layers/sigmoid_layer.cpp b/src/caffe/layers/sigmoid_layer.cpp index 88a7920..50139d8 100644 --- a/src/caffe/layers/sigmoid_layer.cpp +++ b/src/caffe/layers/sigmoid_layer.cpp @@ -28,9 +28,9 @@ Dtype SigmoidLayer::Forward_cpu(const vector*>& bottom, template void SigmoidLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { const Dtype* top_data = top[0]->cpu_data(); const Dtype* top_diff = top[0]->cpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); diff --git a/src/caffe/layers/sigmoid_layer.cu b/src/caffe/layers/sigmoid_layer.cu index aa8568a..d55968e 100644 --- a/src/caffe/layers/sigmoid_layer.cu +++ b/src/caffe/layers/sigmoid_layer.cu @@ -47,9 +47,9 @@ __global__ void SigmoidBackward(const int n, const Dtype* in_diff, template void SigmoidLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - if (propagate_down) { + if (propagate_down[0]) { const Dtype* top_data = top[0]->gpu_data(); const Dtype* top_diff = top[0]->gpu_diff(); Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp index dbe16da..57847d0 100644 --- a/src/caffe/layers/softmax_layer.cpp +++ b/src/caffe/layers/softmax_layer.cpp @@ -60,7 +60,7 @@ Dtype SoftmaxLayer::Forward_cpu(const vector*>& bottom, template void SoftmaxLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->cpu_diff(); const Dtype* top_data = top[0]->cpu_data(); diff --git a/src/caffe/layers/softmax_layer.cu b/src/caffe/layers/softmax_layer.cu index a264a81..f53883c 100644 --- a/src/caffe/layers/softmax_layer.cu +++ b/src/caffe/layers/softmax_layer.cu @@ -79,7 +79,7 @@ Dtype SoftmaxLayer::Forward_gpu(const vector*>& bottom, // TODO(Yangqing): implement the GPU version of softmax. 
template void SoftmaxLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { const Dtype* top_diff = top[0]->gpu_diff(); const Dtype* top_data = top[0]->gpu_data(); Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff(); diff --git a/src/caffe/layers/softmax_loss_layer.cpp b/src/caffe/layers/softmax_loss_layer.cpp index bdb3272..1a3601a 100644 --- a/src/caffe/layers/softmax_loss_layer.cpp +++ b/src/caffe/layers/softmax_loss_layer.cpp @@ -57,20 +57,25 @@ Dtype SoftmaxWithLossLayer::Forward_cpu( template void SoftmaxWithLossLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, + const vector& propagate_down, vector*>* bottom) { - // Compute the diff - Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); - const Dtype* prob_data = prob_.cpu_data(); - memcpy(bottom_diff, prob_data, sizeof(Dtype) * prob_.count()); - const Dtype* label = (*bottom)[1]->cpu_data(); - int num = prob_.num(); - int dim = prob_.count() / num; - for (int i = 0; i < num; ++i) { - bottom_diff[i * dim + static_cast(label[i])] -= 1; + if (propagate_down[1]) { + LOG(FATAL) << this->type_name() + << " Layer cannot backpropagate to label inputs."; + } + if (propagate_down[0]) { + Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); + const Dtype* prob_data = prob_.cpu_data(); + memcpy(bottom_diff, prob_data, sizeof(Dtype) * prob_.count()); + const Dtype* label = (*bottom)[1]->cpu_data(); + int num = prob_.num(); + int dim = prob_.count() / num; + for (int i = 0; i < num; ++i) { + bottom_diff[i * dim + static_cast(label[i])] -= 1; + } + // Scale down gradient + caffe_scal(prob_.count(), Dtype(1) / num, bottom_diff); } - // Scale down gradient - caffe_scal(prob_.count(), Dtype(1) / num, bottom_diff); } diff --git a/src/caffe/layers/softmax_loss_layer.cu b/src/caffe/layers/softmax_loss_layer.cu index 24a3c38..e46be6b 100644 --- a/src/caffe/layers/softmax_loss_layer.cu +++ b/src/caffe/layers/softmax_loss_layer.cu @@ -21,7 +21,7 @@ Dtype SoftmaxWithLossLayer::Forward_gpu( template void SoftmaxWithLossLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { + const vector& propagate_down, vector*>* bottom) { // TODO(Yangqing): implement the GPU version of softmax. Backward_cpu(top, propagate_down, bottom); } diff --git a/src/caffe/layers/split_layer.cpp b/src/caffe/layers/split_layer.cpp index 2f99ca1..28abd95 100644 --- a/src/caffe/layers/split_layer.cpp +++ b/src/caffe/layers/split_layer.cpp @@ -37,8 +37,8 @@ Dtype SplitLayer::Forward_cpu(const vector*>& bottom, template void SplitLayer::Backward_cpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { - if (propagate_down) { + const vector& propagate_down, vector*>* bottom) { + if (propagate_down[0]) { (*bottom)[0]->ShareDiff(*top[0]); // Add remaining top blob diffs. Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff(); diff --git a/src/caffe/layers/split_layer.cu b/src/caffe/layers/split_layer.cu index e2269b8..4c921d3 100644 --- a/src/caffe/layers/split_layer.cu +++ b/src/caffe/layers/split_layer.cu @@ -19,8 +19,8 @@ Dtype SplitLayer::Forward_gpu(const vector*>& bottom, template void SplitLayer::Backward_gpu(const vector*>& top, - const bool propagate_down, vector*>* bottom) { - if (propagate_down) { + const vector& propagate_down, vector*>* bottom) { + if (propagate_down[0]) { (*bottom)[0]->ShareDiff(*top[0]); // Add remaining top blob diffs. 
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
diff --git a/src/caffe/layers/tanh_layer.cpp b/src/caffe/layers/tanh_layer.cpp
index 77b4410..6b5166d 100644
--- a/src/caffe/layers/tanh_layer.cpp
+++ b/src/caffe/layers/tanh_layer.cpp
@@ -26,9 +26,9 @@ Dtype TanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 void TanHLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_data = top[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
diff --git a/src/caffe/layers/tanh_layer.cu b/src/caffe/layers/tanh_layer.cu
index aa822d8..9307cab 100644
--- a/src/caffe/layers/tanh_layer.cu
+++ b/src/caffe/layers/tanh_layer.cu
@@ -42,9 +42,9 @@ __global__ void TanHBackward(const int n, const Dtype* in_diff,
 
 template <typename Dtype>
 void TanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
-    const bool propagate_down,
+    const vector<bool>& propagate_down,
     vector<Blob<Dtype>*>* bottom) {
-  if (propagate_down) {
+  if (propagate_down[0]) {
     const Dtype* top_data = top[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index d76d284..8687d00 100644
--- a/src/caffe/net.cpp
+++ b/src/caffe/net.cpp
@@ -56,6 +56,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   top_vecs_.resize(param.layers_size());
   bottom_id_vecs_.resize(param.layers_size());
   top_id_vecs_.resize(param.layers_size());
+  bottom_need_backward_.resize(param.layers_size());
   for (int layer_id = 0; layer_id < param.layers_size(); ++layer_id) {
     const LayerParameter& layer_param = param.layers(layer_id);
     layers_.push_back(shared_ptr<Layer<Dtype> >(GetLayer<Dtype>(layer_param)));
@@ -196,6 +197,8 @@ int Net<Dtype>::AppendBottom(const NetParameter& param,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
+  const bool need_backward = param.force_backward() || blob_need_backward_[blob_id];
+  bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
 
@@ -288,7 +291,8 @@ template <typename Dtype>
 void Net<Dtype>::Backward() {
   for (int i = layers_.size() - 1; i >= 0; --i) {
     if (layer_need_backward_[i]) {
-      layers_[i]->Backward(top_vecs_[i], true, &bottom_vecs_[i]);
+      layers_[i]->Backward(
+          top_vecs_[i], bottom_need_backward_[i], &bottom_vecs_[i]);
     }
   }
 }
diff --git a/src/caffe/test/test_euclidean_loss_layer.cpp b/src/caffe/test/test_euclidean_loss_layer.cpp
index d5e4107..02945b6 100644
--- a/src/caffe/test/test_euclidean_loss_layer.cpp
+++ b/src/caffe/test/test_euclidean_loss_layer.cpp
@@ -47,13 +47,23 @@ typedef ::testing::Types<float, double> Dtypes;
 TYPED_TEST_CASE(EuclideanLossLayerTest, Dtypes);
 
 TYPED_TEST(EuclideanLossLayerTest, TestGradientCPU) {
-  LayerParameter layer_param;
   Caffe::set_mode(Caffe::CPU);
+  LayerParameter layer_param;
+  EuclideanLossLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
+  GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
+  checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_), -1, -1, -1);
+}
+
+TYPED_TEST(EuclideanLossLayerTest, TestGradientGPU) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
   EuclideanLossLayer<TypeParam> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
   GradientChecker<TypeParam> checker(1e-2, 1e-2, 1701);
   checker.CheckGradientSingle(&layer, &(this->blob_bottom_vec_),
-      &(this->blob_top_vec_), 0, -1, -1);
+      &(this->blob_top_vec_), -1, -1, -1);
 }
 
 }  // namespace caffe
diff --git a/src/caffe/test/test_gradient_check_util.hpp b/src/caffe/test/test_gradient_check_util.hpp
index bcf0397..ff104b9 100644
--- a/src/caffe/test/test_gradient_check_util.hpp
+++ b/src/caffe/test/test_gradient_check_util.hpp
@@ -86,6 +86,7 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
   }
   // First, figure out what blobs we need to check against.
   vector<Blob<Dtype>*> blobs_to_check;
+  vector<bool> propagate_down(bottom->size(), check_bottom < 0);
   for (int i = 0; i < layer->blobs().size(); ++i) {
     blobs_to_check.push_back(layer->blobs()[i].get());
   }
@@ -94,8 +95,9 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
       blobs_to_check.push_back((*bottom)[i]);
     }
   } else {
-    CHECK(check_bottom < bottom->size());
+    CHECK_LT(check_bottom, bottom->size());
     blobs_to_check.push_back((*bottom)[check_bottom]);
+    propagate_down[check_bottom] = true;
   }
   // Compute the gradient analytically using Backward
   Caffe::set_random_seed(seed_);
@@ -103,7 +105,7 @@ void GradientChecker<Dtype>::CheckGradientSingle(Layer<Dtype>* layer,
   Dtype computed_objective = layer->Forward(*bottom, top);
   // Get additional loss from the objective
   computed_objective += GetObjAndGradient(top, top_id, top_data_id);
-  layer->Backward(*top, true, bottom);
+  layer->Backward(*top, propagate_down, bottom);
   // Store computed gradients for all checked blobs
   vector<shared_ptr<Blob<Dtype> > >
       computed_gradient_blobs(blobs_to_check.size());
diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp
index 7f0a211..6996a23 100644
--- a/src/caffe/test/test_lrn_layer.cpp
+++ b/src/caffe/test/test_lrn_layer.cpp
@@ -164,7 +164,9 @@ TYPED_TEST(LRNLayerTest, TestCPUGradientAcrossChannels) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_bottom_vec_));
   // for (int i = 0; i < this->blob_bottom_->count(); ++i) {
   //   std::cout << "CPU diff " << this->blob_bottom_->cpu_diff()[i]
   //       << std::endl;
@@ -183,7 +185,9 @@ TYPED_TEST(LRNLayerTest, TestGPUGradientAcrossChannels) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_bottom_vec_));
   // for (int i = 0; i < this->blob_bottom_->count(); ++i) {
   //   std::cout << "GPU diff " << this->blob_bottom_->cpu_diff()[i]
   //       << std::endl;
@@ -254,7 +258,6 @@ TYPED_TEST(LRNLayerTest, TestCPUGradientWithinChannel) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
   checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
 }
@@ -272,7 +275,6 @@ TYPED_TEST(LRNLayerTest, TestGPUGradientWithinChannel) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
   checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
       &(this->blob_top_vec_));
 }
diff --git a/src/caffe/test/test_maxpool_dropout_layers.cpp b/src/caffe/test/test_maxpool_dropout_layers.cpp
index 3862e12..ac6f86f 100644
--- a/src/caffe/test/test_maxpool_dropout_layers.cpp
+++ b/src/caffe/test/test_maxpool_dropout_layers.cpp
@@ -130,7 +130,9 @@ TYPED_TEST(MaxPoolingDropoutTest, CPUBackward) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_bottom_vec_));
   const TypeParam* bottom_diff = this->blob_bottom_->cpu_diff();
   TypeParam sum = 0.;
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
@@ -141,8 +143,10 @@ TYPED_TEST(MaxPoolingDropoutTest, CPUBackward) {
   DropoutLayer<TypeParam> dropout_layer(layer_param);
   dropout_layer.SetUp(this->blob_top_vec_, &(this->blob_top_vec_));
   dropout_layer.Forward(this->blob_top_vec_, &(this->blob_top_vec_));
-  dropout_layer.Backward(this->blob_top_vec_, true, &(this->blob_top_vec_));
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  dropout_layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_top_vec_));
+  layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_bottom_vec_));
   TypeParam sum_with_dropout = 0.;
   bottom_diff = this->blob_bottom_->cpu_diff();
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
@@ -164,7 +168,9 @@ TYPED_TEST(MaxPoolingDropoutTest, GPUBackward) {
   for (int i = 0; i < this->blob_top_->count(); ++i) {
     this->blob_top_->mutable_cpu_diff()[i] = 1.;
   }
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_bottom_vec_));
   const TypeParam* bottom_diff = this->blob_bottom_->cpu_diff();
   TypeParam sum = 0.;
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
@@ -175,8 +181,10 @@ TYPED_TEST(MaxPoolingDropoutTest, GPUBackward) {
   DropoutLayer<TypeParam> dropout_layer(layer_param);
   dropout_layer.SetUp(this->blob_top_vec_, &(this->blob_top_vec_));
   dropout_layer.Forward(this->blob_top_vec_, &(this->blob_top_vec_));
-  dropout_layer.Backward(this->blob_top_vec_, true, &(this->blob_top_vec_));
-  layer.Backward(this->blob_top_vec_, true, &(this->blob_bottom_vec_));
+  dropout_layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_top_vec_));
+  layer.Backward(this->blob_top_vec_, propagate_down,
+      &(this->blob_bottom_vec_));
   TypeParam sum_with_dropout = 0.;
   bottom_diff = this->blob_bottom_->cpu_diff();
   for (int i = 0; i < this->blob_bottom_->count(); ++i) {
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index 7f2f67b..eb368cd 100644
--- a/src/caffe/test/test_net.cpp
+++ b/src/caffe/test/test_net.cpp
@@ -1,7 +1,9 @@
 // Copyright 2014 BVLC and contributors.
-#include
 #include
+
+#include "google/protobuf/text_format.h"
 #include "gtest/gtest.h"
 
 #include "caffe/common.hpp"
@@ -15,9 +17,15 @@ namespace caffe {
 template <typename Dtype>
 class NetTest : public ::testing::Test {
  protected:
-  virtual void SetUp() {
-    const string& proto =
-        "name: 'TestNetwork' "
+  virtual void InitNetFromProtoString(const string& proto) {
+    NetParameter param;
+    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
+    net_.reset(new Net<Dtype>(param));
+  }
+
+  virtual void InitTinyNet(const bool force_backward = false) {
+    string proto =
+        "name: 'TinyTestNetwork' "
         "layers: { "
         "  name: 'data' "
         "  type: DUMMY_DATA "
@@ -66,9 +74,84 @@ class NetTest : public ::testing::Test {
         "  bottom: 'label' "
         "  top: 'top_loss' "
         "} ";
-    NetParameter param;
-    CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
-    net_.reset(new Net<Dtype>(param));
+    if (force_backward) {
+      proto += "force_backward: true ";
+    }
+    InitNetFromProtoString(proto);
+  }
+
+  virtual void InitTrickyNet() {
+    const string& proto =
+        "name: 'TrickyTestNetwork' "
+        "layers: { "
+        "  name: 'data' "
+        "  type: DUMMY_DATA "
+        "  dummy_data_param { "
+        "    num: 5 "
+        "    channels: 2 "
+        "    height: 3 "
+        "    width: 4 "
+        "    num: 5 "
+        "    channels: 1 "
+        "    height: 1 "
+        "    width: 1 "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "  } "
+        "  top: 'data' "
+        "  top: 'label' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 1000 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    bias_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "  } "
+        "  blobs_lr: 1. "
+        "  blobs_lr: 2. "
+        "  weight_decay: 1. "
+        "  weight_decay: 0. "
+        "  bottom: 'data' "
+        "  top: 'transformed_data' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 1 "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "    bias_filler { "
+        "      type: 'constant' "
+        "      value: 0 "
+        "    } "
+        "  } "
+        "  blobs_lr: 1. "
+        "  blobs_lr: 2. "
+        "  weight_decay: 1. "
+        "  weight_decay: 0. "
" + " bottom: 'label' " + " top: 'transformed_label' " + "} " + "layers: { " + " name: 'loss' " + " type: SOFTMAX_LOSS " + " bottom: 'transformed_data' " + " bottom: 'transformed_label' " + "} "; + InitNetFromProtoString(proto); } shared_ptr > net_; @@ -78,6 +161,7 @@ typedef ::testing::Types Dtypes; TYPED_TEST_CASE(NetTest, Dtypes); TYPED_TEST(NetTest, TestHasBlob) { + this->InitTinyNet(); EXPECT_TRUE(this->net_->has_blob("data")); EXPECT_TRUE(this->net_->has_blob("label")); EXPECT_TRUE(this->net_->has_blob("innerproduct")); @@ -86,6 +170,7 @@ TYPED_TEST(NetTest, TestHasBlob) { } TYPED_TEST(NetTest, TestGetBlob) { + this->InitTinyNet(); EXPECT_EQ(this->net_->blob_by_name("data"), this->net_->blobs()[0]); EXPECT_EQ(this->net_->blob_by_name("label"), this->net_->blobs()[1]); EXPECT_EQ(this->net_->blob_by_name("innerproduct"), this->net_->blobs()[2]); @@ -94,6 +179,7 @@ TYPED_TEST(NetTest, TestGetBlob) { } TYPED_TEST(NetTest, TestHasLayer) { + this->InitTinyNet(); EXPECT_TRUE(this->net_->has_layer("data")); EXPECT_TRUE(this->net_->has_layer("innerproduct")); EXPECT_TRUE(this->net_->has_layer("loss")); @@ -101,10 +187,53 @@ TYPED_TEST(NetTest, TestHasLayer) { } TYPED_TEST(NetTest, TestGetLayerByName) { + this->InitTinyNet(); EXPECT_EQ(this->net_->layer_by_name("data"), this->net_->layers()[0]); EXPECT_EQ(this->net_->layer_by_name("innerproduct"), this->net_->layers()[1]); EXPECT_EQ(this->net_->layer_by_name("loss"), this->net_->layers()[2]); EXPECT_FALSE(this->net_->layer_by_name("label")); } +TYPED_TEST(NetTest, TestBottomNeedBackward) { + this->InitTinyNet(); + const vector >& bottom_need_backward = + this->net_->bottom_need_backward(); + EXPECT_EQ(3, bottom_need_backward.size()); + EXPECT_EQ(0, bottom_need_backward[0].size()); + EXPECT_EQ(1, bottom_need_backward[1].size()); + EXPECT_EQ(false, bottom_need_backward[1][0]); + EXPECT_EQ(2, bottom_need_backward[2].size()); + EXPECT_EQ(true, bottom_need_backward[2][0]); + EXPECT_EQ(false, bottom_need_backward[2][1]); +} + +TYPED_TEST(NetTest, TestBottomNeedBackwardForce) { + const bool force_backward = true; + this->InitTinyNet(force_backward); + const vector >& bottom_need_backward = + this->net_->bottom_need_backward(); + EXPECT_EQ(3, bottom_need_backward.size()); + EXPECT_EQ(0, bottom_need_backward[0].size()); + EXPECT_EQ(1, bottom_need_backward[1].size()); + EXPECT_EQ(true, bottom_need_backward[1][0]); + EXPECT_EQ(2, bottom_need_backward[2].size()); + EXPECT_EQ(true, bottom_need_backward[2][0]); + EXPECT_EQ(true, bottom_need_backward[2][1]); +} + +TYPED_TEST(NetTest, TestBottomNeedBackwardTricky) { + this->InitTrickyNet(); + const vector >& bottom_need_backward = + this->net_->bottom_need_backward(); + EXPECT_EQ(4, bottom_need_backward.size()); + EXPECT_EQ(0, bottom_need_backward[0].size()); + EXPECT_EQ(1, bottom_need_backward[1].size()); + EXPECT_EQ(false, bottom_need_backward[1][0]); + EXPECT_EQ(1, bottom_need_backward[2].size()); + EXPECT_EQ(false, bottom_need_backward[2][0]); + EXPECT_EQ(2, bottom_need_backward[3].size()); + EXPECT_EQ(true, bottom_need_backward[3][0]); + EXPECT_EQ(true, bottom_need_backward[3][1]); +} + } // namespace caffe diff --git a/tools/net_speed_benchmark.cpp b/tools/net_speed_benchmark.cpp index 36a0077..4b65cfd 100644 --- a/tools/net_speed_benchmark.cpp +++ b/tools/net_speed_benchmark.cpp @@ -64,6 +64,8 @@ int main(int argc, char** argv) { const vector > >& layers = caffe_net.layers(); vector*> >& bottom_vecs = caffe_net.bottom_vecs(); vector*> >& top_vecs = caffe_net.top_vecs(); + const vector >& 
+      caffe_net.bottom_need_backward();
   LOG(ERROR) << "*** Benchmark begins ***";
   Timer total_timer;
   total_timer.Start();
@@ -87,7 +89,8 @@ int main(int argc, char** argv) {
     const string& layername = layers[i]->layer_param().name();
     timer.Start();
     for (int j = 0; j < total_iter; ++j) {
-      layers[i]->Backward(top_vecs[i], true, &bottom_vecs[i]);
+      layers[i]->Backward(top_vecs[i], bottom_need_backward[i],
+          &bottom_vecs[i]);
     }
     LOG(ERROR) << layername << "\tbackward: " << timer.MilliSeconds()
         << " milli seconds.";
-- 
2.7.4
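
Usage note (editor's addition, not part of the patch): after this change every call site supplies one flag per bottom blob instead of a single bool, and Net precomputes those flags in bottom_need_backward_ from force_backward and blob_need_backward_. The sketch below mirrors the updated call sites in the tests and net_speed_benchmark; the helper name RunBackwardOnAllBottoms and the variable names are illustrative, not taken from the patch, and an already set-up layer and blob vectors are assumed.

// Editor's sketch: caller-side use of the new Backward signature.
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"

template <typename Dtype>
void RunBackwardOnAllBottoms(caffe::Layer<Dtype>* layer,
    const std::vector<caffe::Blob<Dtype>*>& top_vec,
    std::vector<caffe::Blob<Dtype>*>* bottom_vec) {
  // One flag per bottom blob; request gradients w.r.t. all of them.
  std::vector<bool> propagate_down(bottom_vec->size(), true);
  // To skip a particular bottom (e.g. a label input), clear its entry:
  //   propagate_down[1] = false;
  layer->Backward(top_vec, propagate_down, bottom_vec);
}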
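A second sketch (also an editor's addition, hypothetical code) shows the convention the patch establishes for loss-style layers with a label bottom, as implemented above in EuclideanLossLayer, SoftmaxWithLossLayer, and SigmoidCrossEntropyLossLayer: refuse to backpropagate to the label (bottom index 1) and only write bottom[0]'s diff when asked. ToyLossLayer does not exist in Caffe, the gradient is a placeholder, and the code assumes the Layer API as it stands after this patch.

// Editor's sketch: the loss-layer Backward convention under the new interface.
#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

template <typename Dtype>
class ToyLossLayer : public Layer<Dtype> {
 public:
  explicit ToyLossLayer(const LayerParameter& param) : Layer<Dtype>(param) {}
  virtual void SetUp(const std::vector<Blob<Dtype>*>& bottom,
      std::vector<Blob<Dtype>*>* top) {}

 protected:
  virtual Dtype Forward_cpu(const std::vector<Blob<Dtype>*>& bottom,
      std::vector<Blob<Dtype>*>* top) {
    return Dtype(0);  // loss computation omitted; the sketch is about Backward
  }

  virtual void Backward_cpu(const std::vector<Blob<Dtype>*>& top,
      const std::vector<bool>& propagate_down,
      std::vector<Blob<Dtype>*>* bottom) {
    // Labels (bottom index 1) are inputs, not variables: refuse to backprop.
    if (propagate_down[1]) {
      LOG(FATAL) << this->type_name()
                 << " Layer cannot backpropagate to label inputs.";
    }
    // Only touch bottom[0]'s diff when the caller asked for it.
    if (propagate_down[0]) {
      Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
      for (int i = 0; i < (*bottom)[0]->count(); ++i) {
        bottom_diff[i] = Dtype(0);  // placeholder gradient
      }
    }
  }
};

}  // namespace caffe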