inline const vector<Dtype>& blob_loss_weights() const {
return blob_loss_weights_;
}
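+ /// @brief returns whether each layer needs backward computation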
+ inline const vector<bool>& layer_need_backward() const {
+ return layer_need_backward_;
+ }
/// @brief returns the parameters
inline const vector<shared_ptr<Blob<Dtype> > >& params() const {
return params_;
}
// Setup layer.
const LayerParameter& layer_param = param.layer(layer_id);
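+ // If propagate_down is specified, it must provide exactly one flag per
+ // bottom blob of this layer; otherwise it may be left empty.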
+ if (layer_param.propagate_down_size() > 0) {
+ CHECK_EQ(layer_param.propagate_down_size(),
+ layer_param.bottom_size())
+ << "propagate_down param must be specified "
+ << "either 0 or bottom_size times ";
+ }
layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
layer_names_.push_back(layer_param.name());
LOG(INFO) << "Creating Layer " << layer_param.name();
bool need_backward = false;
+
// Figure out this layer's input and output
for (int bottom_id = 0; bottom_id < layer_param.bottom_size();
++bottom_id) {
// Go through the net backwards to determine which blobs contribute to the
// loss. We can skip backward computation for blobs that don't contribute
// to the loss.
+ // Also checks whether all bottom blobs don't need backward computation
+ // (possible because of the propagate_down param), in which case we can
+ // skip backward computation for the entire layer.
set<string> blobs_under_loss;
+ set<string> blobs_skip_backp;
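+ // Bottom blobs found not to need backward computation; a layer may skip
+ // backward entirely if all of its top blobs end up in this set.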
for (int layer_id = layers_.size() - 1; layer_id >= 0; --layer_id) {
bool layer_contributes_loss = false;
+ bool layer_skip_propagate_down = true;
for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
if (layers_[layer_id]->loss(top_id) ||
(blobs_under_loss.find(blob_name) != blobs_under_loss.end())) {
layer_contributes_loss = true;
+ }
+ if (blobs_skip_backp.find(blob_name) == blobs_skip_backp.end()) {
+ layer_skip_propagate_down = false;
+ }
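+ // Once the layer both contributes to the loss and cannot skip
+ // backpropagation, neither flag can change; stop scanning the tops.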
+ if (layer_contributes_loss && !layer_skip_propagate_down)
break;
+ }
+ // If this layer can skip backward computation, then none of its bottom
+ // blobs need backpropagation either.
+ if (layer_need_backward_[layer_id] && layer_skip_propagate_down) {
+ layer_need_backward_[layer_id] = false;
+ for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
+ ++bottom_id) {
+ bottom_need_backward_[layer_id][bottom_id] = false;
}
}
if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
} else {
bottom_need_backward_[layer_id][bottom_id] = false;
}
+ if (!bottom_need_backward_[layer_id][bottom_id]) {
+ const string& blob_name =
+ blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
+ blobs_skip_backp.insert(blob_name);
+ }
}
}
// Handle force_backward if needed.
// Helper for Net::Init: add a new bottom blob to the net.
template <typename Dtype>
-int Net<Dtype>::AppendBottom(const NetParameter& param,
- const int layer_id, const int bottom_id,
- set<string>* available_blobs, map<string, int>* blob_name_to_idx) {
+int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
+ const int bottom_id, set<string>* available_blobs,
+ map<string, int>* blob_name_to_idx) {
const LayerParameter& layer_param = param.layer(layer_id);
const string& blob_name = layer_param.bottom(bottom_id);
if (available_blobs->find(blob_name) == available_blobs->end()) {
bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
bottom_id_vecs_[layer_id].push_back(blob_id);
available_blobs->erase(blob_name);
- const bool need_backward = blob_need_backward_[blob_id];
+ bool propagate_down = true;
+ // Check whether backpropagation on this bottom blob should be skipped.
+ if (layer_param.propagate_down_size() > 0)
+ propagate_down = layer_param.propagate_down(bottom_id);
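+ // A bottom blob needs backward only if it already required gradients
+ // and propagate_down was not explicitly disabled for this input.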
+ const bool need_backward = blob_need_backward_[blob_id] &&
+ propagate_down;
bottom_need_backward_[layer_id].push_back(need_backward);
return blob_id;
}
// The blobs containing the numeric parameters of the layer.
repeated BlobProto blobs = 7;
+
+ // Specifies for which bottom blobs backpropagation should be skipped.
+ // The size must be either 0 or equal to the number of bottoms.
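+ // For example, a layer with bottoms 'pred' and 'label' (illustrative
+ // names) that should not backpropagate to 'label' would set:
+ //   propagate_down: [true, false]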
+ repeated bool propagate_down = 11;
// Rules controlling whether and when a layer is included in the network,
// based on the current NetState. You may specify a non-zero number of rules
InitNetFromProtoString(proto);
}
+ virtual void InitSkipPropNet(bool test_skip_true) {
+ string proto =
+ "name: 'SkipPropTestNetwork' "
+ "layer { "
+ " name: 'data' "
+ " type: 'DummyData' "
+ " dummy_data_param { "
+ " shape { "
+ " dim: 5 "
+ " dim: 2 "
+ " dim: 3 "
+ " dim: 4 "
+ " } "
+ " data_filler { "
+ " type: 'gaussian' "
+ " std: 0.01 "
+ " } "
+ " shape { "
+ " dim: 5 "
+ " } "
+ " data_filler { "
+ " type: 'constant' "
+ " value: 0 "
+ " } "
+ " } "
+ " top: 'data' "
+ " top: 'label' "
+ "} "
+ "layer { "
+ " name: 'silence' "
+ " bottom: 'label' "
+ " type: 'Silence' "
+ "} "
+ "layer { "
+ " name: 'innerproduct' "
+ " type: 'InnerProduct' "
+ " inner_product_param { "
+ " num_output: 1 "
+ " weight_filler { "
+ " type: 'gaussian' "
+ " std: 0.01 "
+ " } "
+ " bias_filler { "
+ " type: 'constant' "
+ " value: 0 "
+ " } "
+ " } "
+ " param { "
+ " lr_mult: 1 "
+ " decay_mult: 1 "
+ " } "
+ " param { "
+ " lr_mult: 2 "
+ " decay_mult: 0 "
+ " } "
+ " bottom: 'data' "
+ " top: 'innerproduct' "
+ "} "
+ "layer { "
+ " name: 'ip_fake_labels' "
+ " type: 'InnerProduct' "
+ " inner_product_param { "
+ " num_output: 1 "
+ " weight_filler { "
+ " type: 'gaussian' "
+ " std: 0.01 "
+ " } "
+ " bias_filler { "
+ " type: 'constant' "
+ " value: 0 "
+ " } "
+ " } "
+ " bottom: 'data' "
+ " top: 'fake_labels' "
+ "} "
+ "layer { "
+ " name: 'argmax' "
+ " bottom: 'fake_labels' "
+ " top: 'label_argmax' "
+ " type: 'ArgMax' "
+ "} "
+ "layer { "
+ " name: 'loss' "
+ " bottom: 'innerproduct' "
+ " bottom: 'label_argmax' ";
+ if (test_skip_true)
+ proto += " propagate_down: [true, false] ";
+ else
+ proto += " propagate_down: [true, true] ";
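+ // The flags map positionally to the loss layer's bottoms declared
+ // above: index 0 is 'innerproduct', index 1 is 'label_argmax'.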
+ proto +=
+ " top: 'cross_entropy_loss' "
+ " type: 'SigmoidCrossEntropyLoss' "
+ " loss_weight: 0.1 "
+ "} ";
+ InitNetFromProtoString(proto);
+ }
+
int seed_;
shared_ptr<Net<Dtype> > net_;
};
}
}
+TYPED_TEST(NetTest, TestSkipPropagateDown) {
+ // check bottom_need_backward when propagate_down is true
+ this->InitSkipPropNet(false);
+ vector<bool> vec_layer_need_backward = this->net_->layer_need_backward();
+ for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+ string layer_name = this->net_->layer_names()[layer_id];
+ if (layer_name == "loss") {
+ // access the bottom_need_backward entry corresponding to the label blob
+ bool need_back = this->net_->bottom_need_backward()[layer_id][1];
+ // if propagate_down is true, the loss layer will try to
+ // backpropagate to the labels
+ EXPECT_TRUE(need_back) << "bottom_need_backward should be True";
+ }
+ // layer_need_backward should be True except for data and silence layers
+ if (layer_name.find("data") != std::string::npos ||
+ layer_name == "silence") {
+ EXPECT_FALSE(vec_layer_need_backward[layer_id])
+ << "layer_need_backward for " << layer_name << " should be False";
+ } else {
+ EXPECT_TRUE(vec_layer_need_backward[layer_id])
+ << "layer_need_backward for " << layer_name << " should be True";
+ }
+ }
+ // check bottom_need_backward when propagate_down is false
+ this->InitSkipPropNet(true);
+ vec_layer_need_backward = this->net_->layer_need_backward();
+ for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+ string layer_name = this->net_->layer_names()[layer_id];
+ if (layer_name == "loss") {
+ // access the bottom_need_backward entry corresponding to the label blob
+ bool need_back = this->net_->bottom_need_backward()[layer_id][1];
+ // if propagate_down is false, the loss layer will not try to
+ // backpropagate to the labels
+ EXPECT_FALSE(need_back) << "bottom_need_backward should be False";
+ }
+ // layer_need_backward should be False except for innerproduct and
+ // loss layers
+ if (layer_name == "innerproduct" || layer_name == "loss") {
+ EXPECT_TRUE(vec_layer_need_backward[layer_id])
+ << "layer_need_backward for " << layer_name << " should be True";
+ } else {
+ EXPECT_FALSE(vec_layer_need_backward[layer_id])
+ << "layer_need_backward for " << layer_name << " should be False";
+ }
+ }
+}
+
} // namespace caffe