}
}
+TYPED_TEST(NetTest, TestSharedWeightsResume) {
+ typedef typename TypeParam::Dtype Dtype;
+
+ // Create a net with weight sharing; Update it once.
+ Caffe::set_random_seed(this->seed_);
+ this->InitDiffDataSharedWeightsNet();
+ vector<Blob<Dtype>*> bottom;
+ EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
+ EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
+ Blob<Dtype>* ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+ Blob<Dtype>* ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+ // Check that data blobs of shared weights share the same location in memory.
+ EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+ // Check that diff blobs of shared weights are at different locations in
+ // memory. (The diffs should be accumulated at update time.)
+ EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+ this->net_->ForwardBackward(bottom);
+ this->net_->Update();
+ Blob<Dtype> shared_params;
+ const bool kReshape = true;
+ const bool kCopyDiff = false;
+ shared_params.CopyFrom(*ip1_weights, kCopyDiff, kReshape);
+ const int count = ip1_weights->count();
+
+ // Write the net to a NetParameter, as in Solver::Snapshot.
+ NetParameter net_param;
+ this->net_->ToProto(&net_param);
+
+ // Reinitialize the net and copy parameters from net_param, as in
+ // Solver::Restore.
+ Caffe::set_random_seed(this->seed_);
+ this->InitDiffDataSharedWeightsNet();
+ this->net_->CopyTrainedLayersFrom(net_param);
+ ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+ ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+ ASSERT_FALSE(NULL == ip1_weights);
+ ASSERT_FALSE(NULL == ip2_weights);
+ EXPECT_NE(ip1_weights, ip2_weights);
+ // Check that data blobs of shared weights share the same location in memory.
+ EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+ for (int i = 0; i < count; ++i) {
+ EXPECT_FLOAT_EQ(shared_params.cpu_data()[i], ip1_weights->cpu_data()[i]);
+ }
+ // Check that diff blobs of shared weights are at different locations in
+ // memory. (The diffs should be accumulated at update time.)
+ EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+}
+
TYPED_TEST(NetTest, TestParamPropagateDown) {
typedef typename TypeParam::Dtype Dtype;
vector<Blob<Dtype>*> bottom;