save/restore shared weights unit test
authorJeff Donahue <jeff.donahue@gmail.com>
Fri, 3 Oct 2014 00:11:38 +0000 (17:11 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Fri, 3 Oct 2014 00:11:38 +0000 (17:11 -0700)
src/caffe/test/test_net.cpp

index d48a60e..930486d 100644 (file)
@@ -1068,6 +1068,54 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) {
   }
 }
 
+TYPED_TEST(NetTest, TestSharedWeightsResume) {
+  typedef typename TypeParam::Dtype Dtype;
+
+  // Create a net with weight sharing; Update it once.
+  Caffe::set_random_seed(this->seed_);
+  this->InitDiffDataSharedWeightsNet();
+  vector<Blob<Dtype>*> bottom;
+  EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
+  EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
+  Blob<Dtype>* ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+  Blob<Dtype>* ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+  // Check that data blobs of shared weights share the same location in memory.
+  EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+  // Check that diff blobs of shared weights are at different locations in
+  // memory.  (The diffs should be accumulated at update time.)
+  EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+  this->net_->ForwardBackward(bottom);
+  this->net_->Update();
+  Blob<Dtype> shared_params;
+  const bool kReshape = true;
+  const bool kCopyDiff = false;
+  shared_params.CopyFrom(*ip1_weights, kCopyDiff, kReshape);
+  const int count = ip1_weights->count();
+
+  // Write the net to a NetParameter, as in Solver::Snapshot.
+  NetParameter net_param;
+  this->net_->ToProto(&net_param);
+
+  // Reinitialize the net and copy parameters from net_param, as in
+  // Solver::Restore.
+  Caffe::set_random_seed(this->seed_);
+  this->InitDiffDataSharedWeightsNet();
+  this->net_->CopyTrainedLayersFrom(net_param);
+  ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+  ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+  ASSERT_FALSE(NULL == ip1_weights);
+  ASSERT_FALSE(NULL == ip2_weights);
+  EXPECT_NE(ip1_weights, ip2_weights);
+  // Check that data blobs of shared weights share the same location in memory.
+  EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+  for (int i = 0; i < count; ++i) {
+    EXPECT_FLOAT_EQ(shared_params.cpu_data()[i], ip1_weights->cpu_data()[i]);
+  }
+  // Check that diff blobs of shared weights are at different locations in
+  // memory.  (The diffs should be accumulated at update time.)
+  EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+}
+
 TYPED_TEST(NetTest, TestParamPropagateDown) {
   typedef typename TypeParam::Dtype Dtype;
   vector<Blob<Dtype>*> bottom;