test_net.cpp: add TestForcePropagateDown
authorJeff Donahue <jeff.donahue@gmail.com>
Mon, 4 Apr 2016 18:36:15 +0000 (11:36 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Mon, 4 Apr 2016 18:36:40 +0000 (11:36 -0700)
src/caffe/test/test_net.cpp

index 1e0788e..92fd317 100644 (file)
@@ -716,6 +716,61 @@ class NetTest : public MultiDeviceTest<TypeParam> {
     InitNetFromProtoString(proto);
   }
 
+  virtual void InitForcePropNet(bool test_force_true) {
+    string proto =
+      "name: 'ForcePropTestNetwork' "
+      "layer { "
+      "  name: 'data' "
+      "  type: 'DummyData' "
+      "  dummy_data_param { "
+      "    shape { "
+      "      dim: 5 "
+      "      dim: 2 "
+      "      dim: 3 "
+      "      dim: 4 "
+      "    } "
+      "    data_filler { "
+      "      type: 'gaussian' "
+      "      std: 0.01 "
+      "    } "
+      "    shape { "
+      "      dim: 5 "
+      "    } "
+      "    data_filler { "
+      "      type: 'constant' "
+      "      value: 0 "
+      "    } "
+      "  } "
+      "  top: 'data' "
+      "  top: 'label' "
+      "} "
+      "layer { "
+      "  name: 'innerproduct' "
+      "  type: 'InnerProduct' "
+      "  inner_product_param { "
+      "    num_output: 1 "
+      "    weight_filler { "
+      "      type: 'gaussian' "
+      "      std: 0.01 "
+      "    } "
+      "  } "
+      "  bottom: 'data' "
+      "  top: 'innerproduct' ";
+    if (test_force_true) {
+      proto += "  propagate_down: true ";
+    }
+    proto +=
+      "} "
+      "layer { "
+      "  name: 'loss' "
+      "  bottom: 'innerproduct' "
+      "  bottom: 'label' "
+      "  top: 'cross_entropy_loss' "
+      "  type: 'SigmoidCrossEntropyLoss' "
+      "} ";
+    InitNetFromProtoString(proto);
+  }
+
   int seed_;
   shared_ptr<Net<Dtype> > net_;
 };
@@ -2371,4 +2426,51 @@ TYPED_TEST(NetTest, TestSkipPropagateDown) {
   }
 }
 
+TYPED_TEST(NetTest, TestForcePropagateDown) {
+  this->InitForcePropNet(false);
+  vector<bool> layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_FALSE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);   // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+  this->InitForcePropNet(true);
+  layer_need_backward = this->net_->layer_need_backward();
+  for (int layer_id = 0; layer_id < this->net_->layers().size(); ++layer_id) {
+    const string& layer_name = this->net_->layer_names()[layer_id];
+    const vector<bool> need_backward =
+        this->net_->bottom_need_backward()[layer_id];
+    if (layer_name == "data") {
+      ASSERT_EQ(need_backward.size(), 0);
+      EXPECT_FALSE(layer_need_backward[layer_id]);
+    } else if (layer_name == "innerproduct") {
+      ASSERT_EQ(need_backward.size(), 1);
+      EXPECT_TRUE(need_backward[0]);  // data
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else if (layer_name == "loss") {
+      ASSERT_EQ(need_backward.size(), 2);
+      EXPECT_TRUE(need_backward[0]);   // innerproduct
+      EXPECT_FALSE(need_backward[1]);  // label
+      EXPECT_TRUE(layer_need_backward[layer_id]);
+    } else {
+      LOG(FATAL) << "Unknown layer: " << layer_name;
+    }
+  }
+}
+
 }  // namespace caffe