weight sharing
author     Jeff Donahue <jeff.donahue@gmail.com>
           Mon, 9 Jun 2014 02:53:45 +0000 (19:53 -0700)
committer  Evan Shelhamer <shelhamer@imaginarynumber.net>
           Thu, 26 Jun 2014 19:41:29 +0000 (12:41 -0700)
include/caffe/net.hpp
src/caffe/net.cpp
src/caffe/proto/caffe.proto
src/caffe/test/test_net.cpp

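In short, this change lets layers share parameter blobs by name: LayerParameter gains repeated blob_name and blob_share_mode fields (see the caffe.proto hunk below), Net::Init/AppendParam wires a later layer that repeats an already-seen blob_name to the owner's data, and Net::Update accumulates the shared diffs into the owner before updating. A minimal prototxt sketch of the intended usage, with field values borrowed from the tests added below (layer and blob names are illustrative only; the data layer is omitted):

    layers: {
      name: 'innerproduct1'
      type: INNER_PRODUCT
      inner_product_param { num_output: 10 bias_term: false }
      blob_name: 'sharedweights'   # names this layer's weight blob
      bottom: 'data'
      top: 'innerproduct1'
    }
    layers: {
      name: 'innerproduct2'
      type: INNER_PRODUCT
      inner_product_param { num_output: 10 bias_term: false }
      blob_name: 'sharedweights'   # same name: reuses innerproduct1's weights
      bottom: 'data'
      top: 'innerproduct2'
    }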
diff --git a/include/caffe/net.hpp b/include/caffe/net.hpp
index d05ca09..aa540ed 100644
@@ -6,6 +6,7 @@
 #include <map>
 #include <set>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "caffe/blob.hpp"
 #include "caffe/proto/caffe.pb.h"
 
 using std::map;
-using std::vector;
+using std::pair;
 using std::set;
 using std::string;
+using std::vector;
 
 namespace caffe {
 
@@ -103,6 +105,7 @@ class Net {
   const shared_ptr<Blob<Dtype> > blob_by_name(const string& blob_name);
   bool has_layer(const string& layer_name);
   const shared_ptr<Layer<Dtype> > layer_by_name(const string& layer_name);
+  const map<string, int>& param_names_index() { return param_names_index_; }
 
  protected:
   // Helpers for Init.
@@ -114,6 +117,8 @@ class Net {
   int AppendBottom(const NetParameter& param, const int layer_id,
                    const int bottom_id, set<string>* available_blobs,
                    map<string, int>* blob_name_to_idx);
+  void AppendParam(const NetParameter& param, const int layer_id,
+                   const int param_id);
   // Function to get misc parameters, e.g. the learning rate multiplier and
   // weight decay.
   void GetLearningRateAndWeightDecay();
@@ -138,6 +143,9 @@ class Net {
   // top_vecs stores the vectors containing the output for each layer
   vector<vector<Blob<Dtype>*> > top_vecs_;
   vector<vector<int> > top_id_vecs_;
+  vector<int> param_owners_;
+  vector<pair<int, int> > param_net_indices_;
+  map<string, int> param_names_index_;
   // blob indices for the input and the output of the net
   vector<int> net_input_blob_indices_;
   vector<int> net_output_blob_indices_;
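The new members above are the sharing bookkeeping: param_net_indices_[i] records which (layer_id, param_id) the i-th net parameter came from, param_names_index_ maps a blob_name to the net-wide index of its owner, and param_owners_[i] is -1 for an owning blob or the owner's net-wide index for a sharing blob. As a worked illustration (not part of the patch), for the SharedWeightsNetwork used in the tests below (a DUMMY_DATA layer with no parameters followed by two bias-free INNER_PRODUCT layers that both declare blob_name: 'sharedweights'), Init() would leave:

    params_.size()      == 2                  // ip1 weights, ip2 weights
    param_net_indices_  == {(1, 0), (2, 0)}   // (layer_id, param_id)
    param_names_index_  == {'sharedweights' -> 0}
    param_owners_       == {-1, 0}            // ip2's blob is owned by net param 0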
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
index a653761..fc532b7 100644
 #include "caffe/net.hpp"
 #include "caffe/util/io.hpp"
 #include "caffe/util/insert_splits.hpp"
+#include "caffe/util/math_functions.hpp"
 #include "caffe/util/upgrade_proto.hpp"
 
-using std::pair;
+using std::make_pair;
 using std::map;
+using std::pair;
 using std::set;
 
 namespace caffe {
@@ -86,8 +88,9 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
     }
     DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype);
     const int blobs_lr_size = layers_[layer_id]->layer_param().blobs_lr_size();
-    CHECK(blobs_lr_size == layers_[layer_id]->blobs().size() ||
-          blobs_lr_size == 0) << "Incorrect blobs lr size: should be either 0 "
+    const int num_param_blobs = layers_[layer_id]->blobs().size();
+    CHECK(blobs_lr_size == num_param_blobs || blobs_lr_size == 0)
+        << "Incorrect blobs lr size: should be either 0 "
         << "or the same as the number of the layer's parameter blobs.";
     if (blobs_lr_size) {
       // Check if this layer needs backward operation itself
@@ -100,6 +103,17 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
       // learning rate to be 1. Thus we will need to perform backward.
       need_backward = true;
     }
+    const int blob_name_size = layer_param.blob_name_size();
+    CHECK(blob_name_size == num_param_blobs || blob_name_size == 0)
+        << "Incorrect blob_name size: should be either 0 or the same as "
+           "the number of the layer's parameter blobs: " << num_param_blobs;
+    const int blob_share_mode_size = layer_param.blob_share_mode_size();
+    CHECK(blob_share_mode_size == num_param_blobs || blob_share_mode_size == 0)
+        << "Incorrect blob_share_mode size: should be either 0 or the same as "
+           "the number of the layer's parameter blobs: " << num_param_blobs;
+    for (int param_id = 0; param_id < num_param_blobs; ++param_id) {
+      AppendParam(param, layer_id, param_id);
+    }
     // Finally, set the backward flag
     layer_need_backward_.push_back(need_backward);
     if (need_backward) {
@@ -218,13 +232,68 @@ int Net<Dtype>::AppendBottom(const NetParameter& param,
 }
 
 template <typename Dtype>
+void Net<Dtype>::AppendParam(const NetParameter& param, const int layer_id,
+                             const int param_id) {
+  const LayerParameter& layer_param = layers_[layer_id]->layer_param();
+  const int blob_name_size = layer_param.blob_name_size();
+  string param_name;
+  if (blob_name_size) {
+    param_name = layer_param.blob_name(param_id);
+  }
+  const int net_param_id = params_.size();
+  params_.push_back(layers_[layer_id]->blobs()[param_id]);
+  param_net_indices_.push_back(make_pair(layer_id, param_id));
+  if (!blob_name_size || !param_name.size() || (param_name.size() &&
+      param_names_index_.find(param_name) == param_names_index_.end())) {
+    // This layer "owns" this parameter blob -- it is either anonymous
+    // (i.e., not given a param_name) or explicitly given a name that we
+    // haven't already seen.
+    param_owners_.push_back(-1);
+    if (blob_name_size) {
+      param_names_index_[param_name] = net_param_id;
+    }
+  } else {
+    // Named param blob with name we've seen before: share params
+    const int owner_net_param_id = param_names_index_[param_name];
+    param_owners_.push_back(owner_net_param_id);
+    const pair<int, int>& owner_index =
+        param_net_indices_[owner_net_param_id];
+    const int owner_layer_id = owner_index.first;
+    const int owner_param_id = owner_index.second;
+    LOG(INFO) << "Sharing parameters '" << param_name << "' owned by "
+              << "layer '" << layer_names_[owner_layer_id] << "', param "
+              << "index " << owner_param_id;
+    Blob<Dtype>* this_blob = layers_[layer_id]->blobs()[param_id].get();
+    Blob<Dtype>* owner_blob =
+        layers_[owner_layer_id]->blobs()[owner_param_id].get();
+    const int blob_share_mode_size = layer_param.blob_share_mode_size();
+    if (blob_share_mode_size > param_id &&
+        (layer_param.blob_share_mode(param_id) ==
+         LayerParameter_DimCheckMode_PERMISSIVE)) {
+      // Permissive dimension checking -- only check counts are the same.
+      CHECK_EQ(this_blob->count(), owner_blob->count())
+          << "Shared parameter blobs must have the same count.";
+    } else {
+      // Strict dimension checking -- all dims must be the same.
+      CHECK_EQ(this_blob->num(), owner_blob->num())
+          << "Shared parameter blobs must have the same num.";
+      CHECK_EQ(this_blob->channels(), owner_blob->channels())
+          << "Shared parameter blobs must have the same channels.";
+      CHECK_EQ(this_blob->height(), owner_blob->height())
+          << "Shared parameter blobs must have the same height.";
+      CHECK_EQ(this_blob->width(), owner_blob->width())
+          << "Shared parameter blobs must have the same width.";
+    }
+    layers_[layer_id]->blobs()[param_id]->ShareData(
+        *layers_[owner_layer_id]->blobs()[owner_param_id]);
+  }
+}
+
+template <typename Dtype>
 void Net<Dtype>::GetLearningRateAndWeightDecay() {
   LOG(INFO) << "Collecting Learning Rate and Weight Decay.";
   for (int i = 0; i < layers_.size(); ++i) {
     vector<shared_ptr<Blob<Dtype> > >& layer_blobs = layers_[i]->blobs();
-    for (int j = 0; j < layer_blobs.size(); ++j) {
-      params_.push_back(layer_blobs[j]);
-    }
     // push the learning rate multipliers
     if (layers_[i]->layer_param().blobs_lr_size()) {
       CHECK_EQ(layers_[i]->layer_param().blobs_lr_size(), layer_blobs.size());
@@ -403,8 +472,36 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
 
 template <typename Dtype>
 void Net<Dtype>::Update() {
+  // First, accumulate the diffs of any shared parameters into their owner's
+  // diff. (Assumes that the learning rate, weight decay, etc. have already been
+  // accounted for in the current diff.)
   for (int i = 0; i < params_.size(); ++i) {
-    params_[i]->Update();
+    if (param_owners_[i] < 0) {
+      continue;
+    }
+    const int count = params_[i]->count();
+    const Dtype* this_diff;
+    Dtype* owner_diff;
+    switch (Caffe::mode()) {
+    case Caffe::CPU:
+      this_diff = params_[i]->cpu_diff();
+      owner_diff = params_[param_owners_[i]]->mutable_cpu_diff();
+      caffe_add(count, this_diff, owner_diff, owner_diff);
+      break;
+    case Caffe::GPU:
+      this_diff = params_[i]->gpu_diff();
+      owner_diff = params_[param_owners_[i]]->mutable_gpu_diff();
+      caffe_gpu_add(count, this_diff, owner_diff, owner_diff);
+      break;
+    default:
+      LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
+    }
+  }
+  // Now, update the owned parameters.
+  for (int i = 0; i < params_.size(); ++i) {
+    if (param_owners_[i] < 0) {
+      params_[i]->Update();
+    }
   }
 }
 
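A sketch of the effective arithmetic in the new Update(), not taken verbatim from the patch: for a parameter w whose owner has diff d_owner and whose single sharer has diff d_sharer (both already scaled by the learning rate and weight decay), the two loops compute

    d_owner := d_owner + d_sharer    // caffe_add / caffe_gpu_add into the owner's diff
    w       := w - d_owner           // Blob::Update() on the owner only

so the one shared data allocation receives a single combined step, and the sharer's own Update() is skipped because its param_owners_ entry is non-negative.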
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index 954f103..76b643d 100644
@@ -170,6 +170,18 @@ message LayerParameter {
 
   // The blobs containing the numeric parameters of the layer
   repeated BlobProto blobs = 6;
+  // The names of the parameter blobs -- useful for sharing parameters among
+  // layers (but never required).
+  repeated string blob_name = 1001;
+  // Whether to require shared weights to have the same shape, or just the same
+  // count -- defaults to STRICT if unspecified.
+  repeated DimCheckMode blob_share_mode = 1002;
+  enum DimCheckMode {
+    // STRICT (default) requires that num, channels, height, width each match.
+    STRICT = 0;
+    // PERMISSIVE requires only the count (num*channels*height*width) to match.
+    PERMISSIVE = 1;
+  }
   // The ratio that is multiplied on the global learning rate. If you want to
   // set the learning ratio for one blob, you need to set it for all blobs.
   repeated float blobs_lr = 7;
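To make the DimCheckMode distinction concrete (shapes chosen for illustration only): under STRICT a 10x5x1x1 blob may share only with another 10x5x1x1 blob, while under PERMISSIVE it may also share with, say, a 1x1x10x5 blob, since only the total count (50) must match. Permissive checking is requested per blob on the layer doing the sharing, e.g.:

    blob_name: 'sharedweights'
    blob_share_mode: PERMISSIVE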
diff --git a/src/caffe/test/test_net.cpp b/src/caffe/test/test_net.cpp
index 8fb02fc..1eb7b07 100644
@@ -1,6 +1,7 @@
 // Copyright 2014 BVLC and contributors.
 
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "google/protobuf/text_format.h"
@@ -8,6 +9,7 @@
 #include "gtest/gtest.h"
 #include "caffe/common.hpp"
 #include "caffe/net.hpp"
+#include "caffe/util/math_functions.hpp"
 #include "caffe/test/test_gradient_check_util.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
@@ -17,6 +19,8 @@ namespace caffe {
 template <typename Dtype>
 class NetTest : public ::testing::Test {
  protected:
+  NetTest() : seed_(1701) {}
+
   virtual void InitNetFromProtoString(const string& proto) {
     NetParameter param;
     CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
@@ -210,6 +214,245 @@ class NetTest : public ::testing::Test {
     InitNetFromProtoString(proto);
   }
 
+  virtual void InitUnsharedWeightsNet() {
+    const string& proto =
+        "name: 'UnsharedWeightsNetwork' "
+        "layers: { "
+        "  name: 'data' "
+        "  type: DUMMY_DATA "
+        "  dummy_data_param { "
+        "    num: 5 "
+        "    channels: 2 "
+        "    height: 3 "
+        "    width: 4 "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "  } "
+        "  top: 'data' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct1' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 10 "
+        "    } "
+        "  } "
+        "  blob_name: 'unsharedweights1' "
+        "  bottom: 'data' "
+        "  top: 'innerproduct1' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct2' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 10 "
+        "    } "
+        "  } "
+        "  blob_name: 'unsharedweights2' "
+        "  bottom: 'data' "
+        "  top: 'innerproduct2' "
+        "} "
+        "layers: { "
+        "  name: 'loss' "
+        "  type: EUCLIDEAN_LOSS "
+        "  bottom: 'innerproduct1' "
+        "  bottom: 'innerproduct2' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
+  virtual void InitSharedWeightsNet() {
+    const string& proto =
+        "name: 'SharedWeightsNetwork' "
+        "layers: { "
+        "  name: 'data' "
+        "  type: DUMMY_DATA "
+        "  dummy_data_param { "
+        "    num: 5 "
+        "    channels: 2 "
+        "    height: 3 "
+        "    width: 4 "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 0.01 "
+        "    } "
+        "  } "
+        "  top: 'data' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct1' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 10 "
+        "    } "
+        "  } "
+        "  blob_name: 'sharedweights' "
+        "  bottom: 'data' "
+        "  top: 'innerproduct1' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct2' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'gaussian' "
+        "      std: 10 "
+        "    } "
+        "  } "
+        "  blob_name: 'sharedweights' "
+        "  bottom: 'data' "
+        "  top: 'innerproduct2' "
+        "} "
+        "layers: { "
+        "  name: 'loss' "
+        "  type: EUCLIDEAN_LOSS "
+        "  bottom: 'innerproduct1' "
+        "  bottom: 'innerproduct2' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
+  virtual void InitDiffDataUnsharedWeightsNet() {
+    const string& proto =
+        "name: 'DiffDataUnsharedWeightsNetwork' "
+        "layers: { "
+        "  name: 'data' "
+        "  type: DUMMY_DATA "
+        "  dummy_data_param { "
+        "    num: 10 "
+        "    channels: 10 "
+        "    height: 1 "
+        "    width: 1 "
+        "    num: 10 "
+        "    channels: 10 "
+        "    height: 1 "
+        "    width: 1 "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 10 "
+        "    } "
+        "  } "
+        "  top: 'data1' "
+        "  top: 'data2' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct1' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'constant' "
+        "      value: 0.5 "
+        "    } "
+        "  } "
+        "  blob_name: 'unsharedweights1' "
+        "  bottom: 'data1' "
+        "  top: 'innerproduct1' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct2' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'constant' "
+        "      value: 0.5 "
+        "    } "
+        "  } "
+        "  blob_name: 'unsharedweights2' "
+        "  bottom: 'innerproduct1' "
+        "  top: 'innerproduct2' "
+        "} "
+        "layers: { "
+        "  name: 'loss' "
+        "  type: EUCLIDEAN_LOSS "
+        "  bottom: 'data2' "
+        "  bottom: 'innerproduct2' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
+  virtual void InitDiffDataSharedWeightsNet() {
+    const string& proto =
+        "name: 'DiffDataSharedWeightsNetwork' "
+        "layers: { "
+        "  name: 'data' "
+        "  type: DUMMY_DATA "
+        "  dummy_data_param { "
+        "    num: 10 "
+        "    channels: 10 "
+        "    height: 1 "
+        "    width: 1 "
+        "    num: 10 "
+        "    channels: 10 "
+        "    height: 1 "
+        "    width: 1 "
+        "    data_filler { "
+        "      type: 'gaussian' "
+        "      std: 10 "
+        "    } "
+        "  } "
+        "  top: 'data1' "
+        "  top: 'data2' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct1' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'constant' "
+        "      value: 0.5 "
+        "    } "
+        "  } "
+        "  blob_name: 'sharedweights' "
+        "  bottom: 'data1' "
+        "  top: 'innerproduct1' "
+        "} "
+        "layers: { "
+        "  name: 'innerproduct2' "
+        "  type: INNER_PRODUCT "
+        "  inner_product_param { "
+        "    num_output: 10 "
+        "    bias_term: false "
+        "    weight_filler { "
+        "      type: 'constant' "
+        "      value: 0.5 "
+        "    } "
+        "  } "
+        "  blob_name: 'sharedweights' "
+        "  bottom: 'innerproduct1' "
+        "  top: 'innerproduct2' "
+        "} "
+        "layers: { "
+        "  name: 'loss' "
+        "  type: EUCLIDEAN_LOSS "
+        "  bottom: 'data2' "
+        "  bottom: 'innerproduct2' "
+        "} ";
+    InitNetFromProtoString(proto);
+  }
+
+  int seed_;
   shared_ptr<Net<Dtype> > net_;
 };
 
@@ -309,4 +552,230 @@ TYPED_TEST(NetTest, TestBottomNeedBackwardTricky) {
   EXPECT_EQ(true, bottom_need_backward[3][1]);
 }
 
+TYPED_TEST(NetTest, TestUnsharedWeightsDataNet) {
+  this->InitUnsharedWeightsNet();
+  vector<Blob<TypeParam>*> bottom;
+  TypeParam loss;
+  this->net_->Forward(bottom, &loss);
+  EXPECT_GT(loss, 0);
+}
+
+TYPED_TEST(NetTest, TestSharedWeightsDataNet) {
+  this->InitSharedWeightsNet();
+  vector<Blob<TypeParam>*> bottom;
+  TypeParam loss;
+  this->net_->Forward(bottom, &loss);
+  EXPECT_FLOAT_EQ(loss, 0);
+}
+
+TYPED_TEST(NetTest, TestUnsharedWeightsDiffNet) {
+  this->InitUnsharedWeightsNet();
+  vector<Blob<TypeParam>*> bottom;
+  Net<TypeParam>* net = this->net_.get();
+  net->Forward(bottom);
+  net->Backward();
+  Layer<TypeParam>* ip1_layer = net->layer_by_name("innerproduct1").get();
+  Layer<TypeParam>* ip2_layer = net->layer_by_name("innerproduct2").get();
+  const int count = ip1_layer->blobs()[0]->count();
+  const TypeParam* grad1 = ip1_layer->blobs()[0]->cpu_diff();
+  const TypeParam* grad2 = ip2_layer->blobs()[0]->cpu_diff();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_GT(fabs(grad1[i]), 0);
+    EXPECT_FLOAT_EQ(-1 * grad1[i], grad2[i]);
+  }
+}
+
+TYPED_TEST(NetTest, TestSharedWeightsDiffNet) {
+  this->InitSharedWeightsNet();
+  vector<Blob<TypeParam>*> bottom;
+  Net<TypeParam>* net = this->net_.get();
+  TypeParam loss;
+  net->Forward(bottom, &loss);
+  net->Backward();
+  EXPECT_FLOAT_EQ(loss, 0);
+  Layer<TypeParam>* ip1_layer = net->layer_by_name("innerproduct1").get();
+  Layer<TypeParam>* ip2_layer = net->layer_by_name("innerproduct2").get();
+  const int count = ip1_layer->blobs()[0]->count();
+  const TypeParam* grad1 = ip1_layer->blobs()[0]->cpu_diff();
+  const TypeParam* grad2 = ip2_layer->blobs()[0]->cpu_diff();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_FLOAT_EQ(0, grad1[i]);
+    EXPECT_FLOAT_EQ(0, grad2[i]);
+  }
+}
+
+TYPED_TEST(NetTest, TestSharedWeightsUpdateCPU) {
+  Caffe::set_random_seed(this->seed_);
+  Caffe::set_mode(Caffe::CPU);
+  this->InitDiffDataSharedWeightsNet();
+  vector<Blob<TypeParam>*> bottom;
+  EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
+  EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
+  Blob<TypeParam>* ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+  Blob<TypeParam>* ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+  // Check that data blobs of shared weights share the same location in memory.
+  EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+  // Check that diff blobs of shared weights are at different locations in
+        memory.  (The diffs should be accumulated at update time.)
+  EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+  this->net_->Forward(bottom);
+  this->net_->Backward();
+  // Compute the expected update as the data minus the two diffs.
+  Blob<TypeParam> shared_params;
+  const bool reshape = true;
+  const bool copy_diff = false;
+  shared_params.CopyFrom(*ip1_weights, copy_diff, reshape);
+  shared_params.CopyFrom(*ip1_weights, !copy_diff, reshape);
+  const int count = ip1_weights->count();
+  // Make sure the diffs are non-trivial.
+  for (int i = 0; i < count; ++i) {
+    EXPECT_NE(0, ip1_weights->cpu_diff()[i]);
+    EXPECT_NE(0, ip2_weights->cpu_diff()[i]);
+    EXPECT_NE(ip1_weights->cpu_diff()[i], ip2_weights->cpu_diff()[i]);
+  }
+  caffe_axpy(count, TypeParam(1), ip2_weights->cpu_diff(),
+             shared_params.mutable_cpu_diff());
+  caffe_axpy(count, TypeParam(-1), shared_params.cpu_diff(),
+             shared_params.mutable_cpu_data());
+  const TypeParam* expected_updated_params = shared_params.cpu_data();
+  this->net_->Update();
+  const TypeParam* actual_updated_params = ip1_weights->cpu_data();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_EQ(expected_updated_params[i], actual_updated_params[i]);
+  }
+  // Check that data blobs of shared weights STILL point to the same memory
+  // location, i.e. Update() did not unshare them.
+  EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+
+  Caffe::set_random_seed(this->seed_);
+  this->InitDiffDataUnsharedWeightsNet();
+  EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
+  EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
+  ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+  ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+  // Check that data and diff blobs of unshared weights are at different
+  // locations in memory.
+  EXPECT_NE(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+  EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+  this->net_->Forward(bottom);
+  this->net_->Backward();
+  // Compute the expected update.
+  Blob<TypeParam> unshared_params1;
+  unshared_params1.CopyFrom(*ip1_weights, copy_diff, reshape);
+  unshared_params1.CopyFrom(*ip1_weights, !copy_diff, reshape);
+  Blob<TypeParam> unshared_params2;
+  unshared_params2.CopyFrom(*ip2_weights, copy_diff, reshape);
+  unshared_params2.CopyFrom(*ip2_weights, !copy_diff, reshape);
+  // Make sure the diffs are non-trivial and sum to the diff in the shared net.
+  for (int i = 0; i < count; ++i) {
+    EXPECT_NE(0, ip1_weights->cpu_diff()[i]);
+    EXPECT_NE(0, ip2_weights->cpu_diff()[i]);
+    EXPECT_NE(ip1_weights->cpu_diff()[i], ip2_weights->cpu_diff()[i]);
+    EXPECT_EQ(ip1_weights->cpu_diff()[i] + ip2_weights->cpu_diff()[i],
+              shared_params.cpu_diff()[i]);
+  }
+  caffe_axpy(count, TypeParam(-1), ip1_weights->cpu_diff(),
+             unshared_params1.mutable_cpu_data());
+  caffe_axpy(count, TypeParam(-1), ip2_weights->cpu_diff(),
+             unshared_params2.mutable_cpu_data());
+  const TypeParam* expected_updated_params1 = unshared_params1.cpu_data();
+  const TypeParam* expected_updated_params2 = unshared_params2.cpu_data();
+  this->net_->Update();
+  const TypeParam* actual_updated_params1 = ip1_weights->cpu_data();
+  const TypeParam* actual_updated_params2 = ip2_weights->cpu_data();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_EQ(expected_updated_params1[i], actual_updated_params1[i]);
+    EXPECT_EQ(expected_updated_params2[i], actual_updated_params2[i]);
+    EXPECT_NE(actual_updated_params1[i], actual_updated_params2[i]);
+    EXPECT_NE(expected_updated_params, expected_updated_params1);
+  }
+}
+
+TYPED_TEST(NetTest, TestSharedWeightsUpdateGPU) {
+  Caffe::set_random_seed(this->seed_);
+  Caffe::set_mode(Caffe::GPU);
+  this->InitDiffDataSharedWeightsNet();
+  vector<Blob<TypeParam>*> bottom;
+  EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
+  EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
+  Blob<TypeParam>* ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+  Blob<TypeParam>* ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+  // Check that data blobs of shared weights share the same location in memory.
+  EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+  // Check that diff blobs of shared weights are at different locations in
+        memory.  (The diffs should be accumulated at update time.)
+  EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+  this->net_->Forward(bottom);
+  this->net_->Backward();
+  // Compute the expected update as the data minus the two diffs.
+  Blob<TypeParam> shared_params;
+  const bool reshape = true;
+  const bool copy_diff = false;
+  shared_params.CopyFrom(*ip1_weights, copy_diff, reshape);
+  shared_params.CopyFrom(*ip1_weights, !copy_diff, reshape);
+  const int count = ip1_weights->count();
+  // Make sure the diffs are non-trivial.
+  for (int i = 0; i < count; ++i) {
+    EXPECT_NE(0, ip1_weights->cpu_diff()[i]);
+    EXPECT_NE(0, ip2_weights->cpu_diff()[i]);
+    EXPECT_NE(ip1_weights->cpu_diff()[i], ip2_weights->cpu_diff()[i]);
+  }
+  caffe_axpy(count, TypeParam(1), ip2_weights->cpu_diff(),
+             shared_params.mutable_cpu_diff());
+  caffe_axpy(count, TypeParam(-1), shared_params.cpu_diff(),
+             shared_params.mutable_cpu_data());
+  const TypeParam* expected_updated_params = shared_params.cpu_data();
+  this->net_->Update();
+  const TypeParam* actual_updated_params = ip1_weights->cpu_data();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_EQ(expected_updated_params[i], actual_updated_params[i]);
+  }
+  // Check that data blobs of shared weights STILL point to the same memory
+  // location, i.e. Update() did not unshare them.
+  EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+
+  Caffe::set_random_seed(this->seed_);
+  this->InitDiffDataUnsharedWeightsNet();
+  EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
+  EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
+  ip1_weights = this->net_->layers()[1]->blobs()[0].get();
+  ip2_weights = this->net_->layers()[2]->blobs()[0].get();
+  // Check that data and diff blobs of unshared weights are at different
+  // locations in memory.
+  EXPECT_NE(ip1_weights->cpu_data(), ip2_weights->cpu_data());
+  EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
+  this->net_->Forward(bottom);
+  this->net_->Backward();
+  // Compute the expected update.
+  Blob<TypeParam> unshared_params1;
+  unshared_params1.CopyFrom(*ip1_weights, copy_diff, reshape);
+  unshared_params1.CopyFrom(*ip1_weights, !copy_diff, reshape);
+  Blob<TypeParam> unshared_params2;
+  unshared_params2.CopyFrom(*ip2_weights, copy_diff, reshape);
+  unshared_params2.CopyFrom(*ip2_weights, !copy_diff, reshape);
+  // Make sure the diffs are non-trivial and sum to the diff in the shared net.
+  for (int i = 0; i < count; ++i) {
+    EXPECT_NE(0, ip1_weights->cpu_diff()[i]);
+    EXPECT_NE(0, ip2_weights->cpu_diff()[i]);
+    EXPECT_NE(ip1_weights->cpu_diff()[i], ip2_weights->cpu_diff()[i]);
+    EXPECT_EQ(ip1_weights->cpu_diff()[i] + ip2_weights->cpu_diff()[i],
+              shared_params.cpu_diff()[i]);
+  }
+  caffe_axpy(count, TypeParam(-1), ip1_weights->cpu_diff(),
+             unshared_params1.mutable_cpu_data());
+  caffe_axpy(count, TypeParam(-1), ip2_weights->cpu_diff(),
+             unshared_params2.mutable_cpu_data());
+  const TypeParam* expected_updated_params1 = unshared_params1.cpu_data();
+  const TypeParam* expected_updated_params2 = unshared_params2.cpu_data();
+  this->net_->Update();
+  const TypeParam* actual_updated_params1 = ip1_weights->cpu_data();
+  const TypeParam* actual_updated_params2 = ip2_weights->cpu_data();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_EQ(expected_updated_params1[i], actual_updated_params1[i]);
+    EXPECT_EQ(expected_updated_params2[i], actual_updated_params2[i]);
+    EXPECT_NE(actual_updated_params1[i], actual_updated_params2[i]);
+    EXPECT_NE(expected_updated_params, expected_updated_params1);
+  }
+}
+
 }  // namespace caffe