Added "propagate_down" param to LayerParameter
author    manuele <manuele.tamburrano@gmail.com>
          Fri, 15 May 2015 09:17:00 +0000 (11:17 +0200)
committer manuele <manuele.tamburrano@gmail.com>
          Fri, 15 May 2015 09:17:00 +0000 (11:17 +0200)
include/caffe/net.hpp
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/test/test_net.cpp

diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index 075afeb..5665df1 100644
@@ -137,6 +137,9 @@ class Net {
   inline const vector<Dtype>& blob_loss_weights() const {
     return blob_loss_weights_;
   }
+  inline const vector<bool>& layer_need_backward() const {
+    return layer_need_backward_;
+  }
   /// @brief returns the parameters
   inline const vector<shared_ptr<Blob<Dtype> > >& params() const {
     return params_;
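
Not part of the commit: a minimal sketch of how a caller might read the new layer_need_backward() accessor once a net has been constructed. The prototxt path and phase below are illustrative assumptions, not anything this change requires.

#include <iostream>
#include <vector>
#include "caffe/net.hpp"

int main() {
  // "train.prototxt" stands in for any valid network definition (assumption).
  caffe::Net<float> net("train.prototxt", caffe::TRAIN);
  const std::vector<bool>& need_backward = net.layer_need_backward();
  for (size_t i = 0; i < need_backward.size(); ++i) {
    std::cout << net.layer_names()[i] << " needs backward: "
              << (need_backward[i] ? "yes" : "no") << std::endl;
  }
  return 0;
}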
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index fd00b12..482b7c5 100644
@@ -79,10 +79,17 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     }
     // Setup layer.
     const LayerParameter& layer_param = param.layer(layer_id);
+    if (layer_param.propagate_down_size() > 0) {
+      CHECK_EQ(layer_param.propagate_down_size(),
+          layer_param.bottom_size())
+          << "propagate_down param must be specified "
+          << "either 0 or bottom_size times ";
+    }
     layers_.push_back(LayerRegistry<Dtype>::CreateLayer(layer_param));
     layer_names_.push_back(layer_param.name());
     LOG(INFO) << "Creating Layer " << layer_param.name();
     bool need_backward = false;
+
     // Figure out this layer's input and output
     for (int bottom_id = 0; bottom_id < layer_param.bottom_size();
          ++bottom_id) {
@@ -151,15 +158,33 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   // Go through the net backwards to determine which blobs contribute to the
   // loss.  We can skip backward computation for blobs that don't contribute
   // to the loss.
+  // Also check whether all of a layer's top blobs have had backpropagation to
+  // them skipped by their consumers (possible via the propagate_down param);
+  // in that case backward computation can be skipped for the entire layer.
   set<string> blobs_under_loss;
+  set<string> blobs_skip_backp;
   for (int layer_id = layers_.size() - 1; layer_id >= 0; --layer_id) {
     bool layer_contributes_loss = false;
+    bool layer_skip_propagate_down = true;
     for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
       const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]];
       if (layers_[layer_id]->loss(top_id) ||
           (blobs_under_loss.find(blob_name) != blobs_under_loss.end())) {
         layer_contributes_loss = true;
+      }
+      if (blobs_skip_backp.find(blob_name) == blobs_skip_backp.end()) {
+        layer_skip_propagate_down = false;
+      }
+      if (layer_contributes_loss && !layer_skip_propagate_down)
         break;
+    }
+    // If this layer can skip backward computation, none of its bottom blobs
+    // need backpropagation either.
+    if (layer_need_backward_[layer_id] && layer_skip_propagate_down) {
+      layer_need_backward_[layer_id] = false;
+      for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size();
+               ++bottom_id) {
+        bottom_need_backward_[layer_id][bottom_id] = false;
       }
     }
     if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; }
@@ -178,6 +203,11 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
       } else {
         bottom_need_backward_[layer_id][bottom_id] = false;
       }
+      if (!bottom_need_backward_[layer_id][bottom_id]) {
+        const string& blob_name =
+                   blob_names_[bottom_id_vecs_[layer_id][bottom_id]];
+        blobs_skip_backp.insert(blob_name);
+      }
     }
   }
   // Handle force_backward if needed.
@@ -367,9 +397,9 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
 
 // Helper for Net::Init: add a new bottom blob to the net.
 template <typename Dtype>
-int Net<Dtype>::AppendBottom(const NetParameter& param,
-    const int layer_id, const int bottom_id,
-    set<string>* available_blobs, map<string, int>* blob_name_to_idx) {
+int Net<Dtype>::AppendBottom(const NetParameter& param, const int layer_id,
+    const int bottom_id, set<string>* available_blobs,
+    map<string, int>* blob_name_to_idx) {
   const LayerParameter& layer_param = param.layer(layer_id);
   const string& blob_name = layer_param.bottom(bottom_id);
   if (available_blobs->find(blob_name) == available_blobs->end()) {
@@ -381,7 +411,12 @@ int Net<Dtype>::AppendBottom(const NetParameter& param,
   bottom_vecs_[layer_id].push_back(blobs_[blob_id].get());
   bottom_id_vecs_[layer_id].push_back(blob_id);
   available_blobs->erase(blob_name);
-  const bool need_backward = blob_need_backward_[blob_id];
+  bool propagate_down = true;
+  // Check if the backpropagation on bottom_id should be skipped
+  if (layer_param.propagate_down_size() > 0)
+    propagate_down = layer_param.propagate_down(bottom_id);
+  const bool need_backward = blob_need_backward_[blob_id] &&
+                          propagate_down;
   bottom_need_backward_[layer_id].push_back(need_backward);
   return blob_id;
 }
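
Not part of the commit: a small prototxt sketch of what the two hunks above enable. The layer and blob names ("loss", "pred", "label") are placeholders. With a definition like this, AppendBottom forces bottom_need_backward to false for the "label" input, and the backward-skip pass in Init can then disable backward computation for the layers that only feed "label".

layer {
  name: "loss"
  type: "SoftmaxWithLoss"
  bottom: "pred"
  bottom: "label"
  top: "loss"
  # One entry per bottom: backpropagate to "pred", never to "label".
  propagate_down: true
  propagate_down: false
}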
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index e523efa..bc17575 100644
@@ -280,6 +280,10 @@ message LayerParameter {
 
   // The blobs containing the numeric parameters of the layer.
   repeated BlobProto blobs = 7;
+  
+  // Specifies whether backpropagation should be performed to each bottom:
+  // set an entry to false to skip backpropagation to the corresponding bottom.
+  // The size must be either 0 or equal to the number of bottoms.
+  repeated bool propagate_down = 11;
 
   // Rules controlling whether and when a layer is included in the network,
   // based on the current NetState.  You may specify a non-zero number of rules
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index 08106e7..782a96b 100644
@@ -613,6 +613,103 @@ class NetTest : public MultiDeviceTest<TypeParam> {
     InitNetFromProtoString(proto);
   }
 
+  virtual void InitSkipPropNet(bool test_skip_true) {
+    string proto =
+      "name: 'SkipPropTestNetwork' "
+      "layer { "
+      "  name: 'data' "
+      "  type: 'DummyData' "
+      "  dummy_data_param { "
+      "    shape { "
+      "      dim: 5 "
+      "      dim: 2 "
+      "      dim: 3 "
+      "      dim: 4 "
+      "    } "
+      "    data_filler { "
+      "      type: 'gaussian' "
+      "      std: 0.01 "
+      "    } "
+      "    shape { "
+      "      dim: 5 "
+      "    } "
+      "    data_filler { "
+      "      type: 'constant' "
+      "      value: 0 "
+      "    } "
+      "  } "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layer { "
+      "  name: 'silence' "
+      "  bottom: 'label' "
+      "  type: 'Silence' "
+      "} "
+      "layer { "
+      "  name: 'innerproduct' "
+      "  type: 'InnerProduct' "
+      "  inner_product_param { "
+      "    num_output: 1 "
+      "    weight_filler { "
+      "      type: 'gaussian' "
+      "      std: 0.01 "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "      value: 0 "
+      "    } "
+      "  } "
+      "  param { "
+      "    lr_mult: 1 "
+      "    decay_mult: 1 "
+      "  } "
+      "  param { "
+      "    lr_mult: 2 "
+      "    decay_mult: 0 "
+      "  } "
+      "  bottom: 'data' "
+      "  top: 'innerproduct' "
+      "} "
+      "layer { "
+      "  name: 'ip_fake_labels' "
+      "  type: 'InnerProduct' "
+      "  inner_product_param { "
+      "    num_output: 1 "
+      "    weight_filler { "
+      "      type: 'gaussian' "
+      "      std: 0.01 "
+      "    } "
+      "    bias_filler { "
+      "      type: 'constant' "
+      "      value: 0 "
+      "    } "
+      "  } "
+      "  bottom: 'data' "
+      "  top: 'fake_labels' "
+      "} "
+      "layer { "
+      "  name: 'argmax' "
+      "  bottom: 'fake_labels' "
+      "  top: 'label_argmax' "
+      "  type: 'ArgMax' "
+      "} "
+      "layer { "
+      "  name: 'loss' "
+      "  bottom: 'innerproduct' "
+      "  bottom: 'label_argmax' ";
+    if (test_skip_true)
+      proto += "  propagate_down: [true, false] ";
+    else
+      proto += "  propagate_down: [true, true] ";
+    proto +=
+      "  top: 'cross_entropy_loss' "
+      "  type: 'SigmoidCrossEntropyLoss' "
+      "  loss_weight: 0.1 "
+      "} ";
+    InitNetFromProtoString(proto);
+  }
+
   int seed_;
   shared_ptr<Net<Dtype> > net_;
 };
@@ -2224,4 +2321,52 @@ TYPED_TEST(NetTest, TestReshape) {
   }
 }
 
+TYPED_TEST(NetTest, TestSkipPropagateDown) {
+  // check bottom_need_backward if propagate_down is true
+  this->InitSkipPropNet(false);
+  vector<bool> vec_layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    string layer_name = this->net_->layer_names()[layer_id];
+    if (layer_name == "loss") {
+      // access the bottom_need_backward entry corresponding to the label blob
+      bool need_back = this->net_->bottom_need_backward()[layer_id][1];
+      // if propagate_down is true, the loss layer will try to
+      // backpropagate on labels
+      EXPECT_TRUE(need_back) << "bottom_need_backward should be True";
+    }
+    // layer_need_backward should be True except for data and silence layers
+    if (layer_name.find("data") != std::string::npos ||
+          layer_name == "silence") {
+      EXPECT_FALSE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be False";
+    } else {
+      EXPECT_TRUE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be True";
+    }
+  }
+  // check bottom_need_backward if propagate_down is false
+  this->InitSkipPropNet(true);
+  vec_layer_need_backward.clear();
+  vec_layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    string layer_name = this->net_->layer_names()[layer_id];
+    if (layer_name == "loss") {
+      // access the bottom_need_backward entry corresponding to the label blob
+      bool need_back = this->net_->bottom_need_backward()[layer_id][1];
+      // if propagate_down is false, the loss layer will not try to
+      // backpropagate on labels
+      EXPECT_FALSE(need_back) << "bottom_need_backward should be False";
+    }
+    // layer_need_backward should be False except for innerproduct and
+    // loss layers
+    if (layer_name == "innerproduct" || layer_name == "loss") {
+      EXPECT_TRUE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be True";
+    } else {
+      EXPECT_FALSE(vec_layer_need_backward[layer_id])
+          << "layer_need_backward for " << layer_name << " should be False";
+    }
+  }
+}
+
 }  // namespace caffe
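
Not part of the commit: with a standard Caffe build, "make runtest" exercises the whole test suite including the new TestSkipPropagateDown case; to run only this case, googletest's name filter can be passed to the built test binary (its location depends on the build setup), for example --gtest_filter='*TestSkipPropagateDown*'.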