drop Net inputs + Forward with bottoms
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Sat, 17 Oct 2015 04:11:32 +0000 (21:11 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 25 Feb 2016 21:32:44 +0000 (13:32 -0800)
Drop special cases for `input` fields, the `Net` input members,
and the `Net` interface for Forward with bottoms along with
Forward() / ForwardPrefilled() distinction.

include/caffe/net.hpp
src/caffe/net.cpp
src/caffe/solver.cpp
src/caffe/test/test_gradient_based_solver.cpp
src/caffe/test/test_net.cpp
src/caffe/test/test_split_layer.cpp
src/caffe/util/insert_splits.cpp
tools/caffe.cpp
tools/extract_features.cpp

index 543133e..43e0a84 100644 (file)
@@ -32,11 +32,10 @@ class Net {
   void Init(const NetParameter& param);
 
   /**
-   * @brief Run Forward with the input Blob%s already fed separately.
+   * @brief Run Forward and return the result.
    *
-   * You can get the input blobs using input_blobs().
    */
-  const vector<Blob<Dtype>*>& ForwardPrefilled(Dtype* loss = NULL);
+  const vector<Blob<Dtype>*>& Forward(Dtype* loss = NULL);
 
   /**
    * The From and To variants of Forward and Backward operate on the
@@ -49,14 +48,6 @@ class Net {
   Dtype ForwardFromTo(int start, int end);
   Dtype ForwardFrom(int start);
   Dtype ForwardTo(int end);
-  /// @brief Run forward using a set of bottom blobs, and return the result.
-  const vector<Blob<Dtype>*>& Forward(const vector<Blob<Dtype>* > & bottom,
-      Dtype* loss = NULL);
-  /**
-   * @brief Run forward using a serialized BlobProtoVector and return the
-   *        result as a serialized BlobProtoVector
-   */
-  string Forward(const string& input_blob_protos, Dtype* loss = NULL);
 
   /**
    * @brief Zeroes out the diffs of all net parameters.
@@ -82,9 +73,9 @@ class Net {
    */
   void Reshape();
 
-  Dtype ForwardBackward(const vector<Blob<Dtype>* > & bottom) {
+  Dtype ForwardBackward() {
     Dtype loss;
-    Forward(bottom, &loss);
+    Forward(&loss);
     Backward();
     return loss;
   }
@@ -194,18 +185,11 @@ class Net {
   inline const vector<string>& param_display_names() const {
     return param_display_names_;
   }
-  /// @brief Input and output blob numbers
-  inline int num_inputs() const { return net_input_blobs_.size(); }
+  /// @brief output blob number
   inline int num_outputs() const { return net_output_blobs_.size(); }
-  inline const vector<Blob<Dtype>*>& input_blobs() const {
-    return net_input_blobs_;
-  }
   inline const vector<Blob<Dtype>*>& output_blobs() const {
     return net_output_blobs_;
   }
-  inline const vector<int>& input_blob_indices() const {
-    return net_input_blob_indices_;
-  }
   inline const vector<int>& output_blob_indices() const {
     return net_output_blob_indices_;
   }
@@ -229,7 +213,7 @@ class Net {
 
  protected:
   // Helpers for Init.
-  /// @brief Append a new input or top blob to the net.
+  /// @brief Append a new top blob to the net.
   void AppendTop(const NetParameter& param, const int layer_id,
                  const int top_id, set<string>* available_blobs,
                  map<string, int>* blob_name_to_idx);
@@ -241,8 +225,6 @@ class Net {
   void AppendParam(const NetParameter& param, const int layer_id,
                    const int param_id);
 
-  /// @brief Helper for displaying debug info in Forward about input Blobs.
-  void InputDebugInfo(const int layer_id);
   /// @brief Helper for displaying debug info in Forward.
   void ForwardDebugInfo(const int layer_id);
   /// @brief Helper for displaying debug info in Backward.
@@ -281,10 +263,8 @@ class Net {
   vector<string> param_display_names_;
   vector<pair<int, int> > param_layer_indices_;
   map<string, int> param_names_index_;
-  /// blob indices for the input and the output of the net
-  vector<int> net_input_blob_indices_;
+  /// blob indices for the output of the net
   vector<int> net_output_blob_indices_;
-  vector<Blob<Dtype>*> net_input_blobs_;
   vector<Blob<Dtype>*> net_output_blobs_;
   /// The parameters in the network.
   vector<shared_ptr<Blob<Dtype> > > params_;
index 05bee79..b7320e9 100644 (file)
@@ -56,22 +56,7 @@ void Net<Dtype>::Init(const NetParameter& in_param) {
   name_ = param.name();
   map<string, int> blob_name_to_idx;
   set<string> available_blobs;
-  CHECK(param.input_dim_size() == 0 || param.input_shape_size() == 0)
-      << "Must specify either input_shape OR deprecated input_dim, not both.";
-  if (param.input_dim_size() > 0) {
-    // Deprecated 4D dimensions.
-    CHECK_EQ(param.input_size() * 4, param.input_dim_size())
-        << "Incorrect input blob dimension specifications.";
-  } else {
-    CHECK_EQ(param.input_size(), param.input_shape_size())
-        << "Exactly one input_shape must be specified per input.";
-  }
   memory_used_ = 0;
-  // set the input blobs
-  for (int input_id = 0; input_id < param.input_size(); ++input_id) {
-    const int layer_id = -1;  // inputs have fake layer ID -1
-    AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx);
-  }
   // For each layer, set up its input and output
   bottom_vecs_.resize(param.layer_size());
   top_vecs_.resize(param.layer_size());
@@ -379,19 +364,17 @@ bool Net<Dtype>::StateMeetsRule(const NetState& state,
   return true;
 }
 
-// Helper for Net::Init: add a new input or top blob to the net.  (Inputs have
-// layer_id == -1, tops have layer_id >= 0.)
+// Helper for Net::Init: add a new top blob to the net.
 template <typename Dtype>
 void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
                            const int top_id, set<string>* available_blobs,
                            map<string, int>* blob_name_to_idx) {
-  shared_ptr<LayerParameter> layer_param((layer_id >= 0) ?
-    (new LayerParameter(param.layer(layer_id))) : NULL);
-  const string& blob_name = layer_param ?
-      (layer_param->top_size() > top_id ?
-          layer_param->top(top_id) : "(automatic)") : param.input(top_id);
+  shared_ptr<LayerParameter> layer_param(
+      new LayerParameter(param.layer(layer_id)));
+  const string& blob_name = (layer_param->top_size() > top_id) ?
+      layer_param->top(top_id) : "(automatic)";
   // Check if we are doing in-place computation
-  if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id &&
+  if (blob_name_to_idx && layer_param->bottom_size() > top_id &&
       blob_name == layer_param->bottom(top_id)) {
     // In-place computation
     LOG_IF(INFO, Caffe::root_solver())
@@ -407,11 +390,7 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
   } else {
     // Normal output.
     if (Caffe::root_solver()) {
-      if (layer_param) {
-        LOG(INFO) << layer_param->name() << " -> " << blob_name;
-      } else {
-        LOG(INFO) << "Input " << top_id << " -> " << blob_name;
-      }
+      LOG(INFO) << layer_param->name() << " -> " << blob_name;
     }
     shared_ptr<Blob<Dtype> > blob_pointer(new Blob<Dtype>());
     const int blob_id = blobs_.size();
@@ -419,22 +398,8 @@ void Net<Dtype>::AppendTop(const NetParameter& param, const int layer_id,
     blob_names_.push_back(blob_name);
     blob_need_backward_.push_back(false);
     if (blob_name_to_idx) { (*blob_name_to_idx)[blob_name] = blob_id; }
-    if (layer_id == -1) {
-      // Set the (explicitly specified) dimensions of the input blob.
-      if (param.input_dim_size() > 0) {
-        blob_pointer->Reshape(param.input_dim(top_id * 4),
-                              param.input_dim(top_id * 4 + 1),
-                              param.input_dim(top_id * 4 + 2),
-                              param.input_dim(top_id * 4 + 3));
-      } else {
-        blob_pointer->Reshape(param.input_shape(top_id));
-      }
-      net_input_blob_indices_.push_back(blob_id);
-      net_input_blobs_.push_back(blob_pointer.get());
-    } else {
-      top_id_vecs_[layer_id].push_back(blob_id);
-      top_vecs_[layer_id].push_back(blob_pointer.get());
-    }
+    top_id_vecs_[layer_id].push_back(blob_id);
+    top_vecs_[layer_id].push_back(blob_pointer.get());
   }
   if (available_blobs) { available_blobs->insert(blob_name); }
 }
@@ -566,11 +531,6 @@ Dtype Net<Dtype>::ForwardFromTo(int start, int end) {
   CHECK_GE(start, 0);
   CHECK_LT(end, layers_.size());
   Dtype loss = 0;
-  if (debug_info_) {
-    for (int i = 0; i < net_input_blobs_.size(); ++i) {
-      InputDebugInfo(i);
-    }
-  }
   for (int i = start; i <= end; ++i) {
     // LOG(ERROR) << "Forwarding " << layer_names_[i];
     Dtype layer_loss = layers_[i]->Forward(bottom_vecs_[i], top_vecs_[i]);
@@ -591,7 +551,7 @@ Dtype Net<Dtype>::ForwardTo(int end) {
 }
 
 template <typename Dtype>
-const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
+const vector<Blob<Dtype>*>& Net<Dtype>::Forward(Dtype* loss) {
   if (loss != NULL) {
     *loss = ForwardFromTo(0, layers_.size() - 1);
   } else {
@@ -601,37 +561,6 @@ const vector<Blob<Dtype>*>& Net<Dtype>::ForwardPrefilled(Dtype* loss) {
 }
 
 template <typename Dtype>
-const vector<Blob<Dtype>*>& Net<Dtype>::Forward(
-    const vector<Blob<Dtype>*> & bottom, Dtype* loss) {
-  // Copy bottom to internal bottom
-  for (int i = 0; i < bottom.size(); ++i) {
-    net_input_blobs_[i]->CopyFrom(*bottom[i]);
-  }
-  return ForwardPrefilled(loss);
-}
-
-template <typename Dtype>
-string Net<Dtype>::Forward(const string& input_blob_protos, Dtype* loss) {
-  BlobProtoVector blob_proto_vec;
-  if (net_input_blobs_.size()) {
-    blob_proto_vec.ParseFromString(input_blob_protos);
-    CHECK_EQ(blob_proto_vec.blobs_size(), net_input_blobs_.size())
-        << "Incorrect input size.";
-    for (int i = 0; i < blob_proto_vec.blobs_size(); ++i) {
-      net_input_blobs_[i]->FromProto(blob_proto_vec.blobs(i));
-    }
-  }
-  ForwardPrefilled(loss);
-  blob_proto_vec.Clear();
-  for (int i = 0; i < net_output_blobs_.size(); ++i) {
-    net_output_blobs_[i]->ToProto(blob_proto_vec.add_blobs());
-  }
-  string output;
-  blob_proto_vec.SerializeToString(&output);
-  return output;
-}
-
-template <typename Dtype>
 void Net<Dtype>::BackwardFromTo(int start, int end) {
   CHECK_GE(end, 0);
   CHECK_LT(start, layers_.size());
@@ -645,16 +574,6 @@ void Net<Dtype>::BackwardFromTo(int start, int end) {
 }
 
 template <typename Dtype>
-void Net<Dtype>::InputDebugInfo(const int input_id) {
-  const Blob<Dtype>& blob = *net_input_blobs_[input_id];
-  const string& blob_name = blob_names_[net_input_blob_indices_[input_id]];
-  const Dtype data_abs_val_mean = blob.asum_data() / blob.count();
-  LOG_IF(INFO, Caffe::root_solver())
-      << "    [Forward] "
-      << "Input " << blob_name << " data: " << data_abs_val_mean;
-}
-
-template <typename Dtype>
 void Net<Dtype>::ForwardDebugInfo(const int layer_id) {
   for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) {
     const Blob<Dtype>& blob = *top_vecs_[layer_id][top_id];
@@ -912,9 +831,6 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
   param->Clear();
   param->set_name(name_);
   // Add bottom and top
-  for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
-    param->add_input(blob_names_[net_input_blob_indices_[i]]);
-  }
   DLOG(INFO) << "Serializing " << layers_.size() << " layers";
   for (int i = 0; i < layers_.size(); ++i) {
     LayerParameter* layer_param = param->add_layer();
index a5ccf9c..ece3913 100644 (file)
@@ -192,7 +192,6 @@ void Solver<Dtype>::InitTestNets() {
 
 template <typename Dtype>
 void Solver<Dtype>::Step(int iters) {
-  vector<Blob<Dtype>*> bottom_vec;
   const int start_iter = iter_;
   const int stop_iter = iter_ + iters;
   int average_loss = this->param_.average_loss();
@@ -220,7 +219,7 @@ void Solver<Dtype>::Step(int iters) {
     // accumulate the loss and gradient
     Dtype loss = 0;
     for (int i = 0; i < param_.iter_size(); ++i) {
-      loss += net_->ForwardBackward(bottom_vec);
+      loss += net_->ForwardBackward();
     }
     loss /= param_.iter_size();
     // average the loss across iterations for smoothed reporting
@@ -311,7 +310,7 @@ void Solver<Dtype>::Solve(const char* resume_file) {
   if (param_.display() && iter_ % param_.display() == 0) {
     int average_loss = this->param_.average_loss();
     Dtype loss;
-    net_->ForwardPrefilled(&loss);
+    net_->Forward(&loss);
 
     UpdateSmoothedLoss(loss, start_iter, average_loss);
 
@@ -341,7 +340,6 @@ void Solver<Dtype>::Test(const int test_net_id) {
       ShareTrainedLayersWith(net_.get());
   vector<Dtype> test_score;
   vector<int> test_score_output_id;
-  vector<Blob<Dtype>*> bottom_vec;
   const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
   Dtype loss = 0;
   for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
@@ -362,7 +360,7 @@ void Solver<Dtype>::Test(const int test_net_id) {
 
     Dtype iter_loss;
     const vector<Blob<Dtype>*>& result =
-        test_net->Forward(bottom_vec, &iter_loss);
+        test_net->Forward(&iter_loss);
     if (param_.test_compute_loss()) {
       loss += iter_loss;
     }
index 84c6747..09ec3a7 100644 (file)
@@ -185,9 +185,8 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     this->InitSolverFromProtoString(proto.str());
     if (from_snapshot != NULL) {
       this->solver_->Restore(from_snapshot);
-      vector<Blob<Dtype>*> empty_bottom_vec;
       for (int i = 0; i < this->solver_->iter(); ++i) {
-        this->solver_->net()->Forward(empty_bottom_vec);
+        this->solver_->net()->Forward();
       }
     }
     if (devices == 1) {
@@ -231,8 +230,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
     // Run a forward pass, and manually compute the update values from the
     // result.
     Net<Dtype>& net = *this->solver_->net();
-    vector<Blob<Dtype>*> empty_bottom_vec;
-    net.Forward(empty_bottom_vec);
+    net.Forward();
     ASSERT_TRUE(net.has_blob("data"));
     const Blob<Dtype>& data = *net.blob_by_name("data");
     ASSERT_TRUE(net.has_blob("targets"));
index ab4afba..1e0788e 100644 (file)
@@ -555,11 +555,14 @@ class NetTest : public MultiDeviceTest<TypeParam> {
   virtual void InitReshapableNet() {
     const string& proto =
         "name: 'ReshapableNetwork' "
-        "input: 'data' "
-        "input_dim: 1 "
-        "input_dim: 3 "
-        "input_dim: 100 "
-        "input_dim: 100 "
+        "layer { "
+        "  name: 'data' "
+        "  type: 'Input' "
+        "  top: 'data' "
+        "  input_param { "
+        "  shape: { dim: 1 dim: 3 dim: 100 dim: 100 } "
+        "  } "
+        "} "
         "layer { "
         "  name: 'conv1' "
         "  type: 'Convolution' "
@@ -821,7 +824,7 @@ TYPED_TEST(NetTest, TestLossWeight) {
   Caffe::set_random_seed(this->seed_);
   const bool kForceBackward = true;
   this->InitUnsharedWeightsNet(NULL, NULL, kForceBackward);
-  const Dtype loss = this->net_->ForwardBackward(bottom);
+  const Dtype loss = this->net_->ForwardBackward();
   const bool kCopyDiff = true;
   vector<shared_ptr<Blob<Dtype> > > blob_grads;
   this->CopyNetBlobs(kCopyDiff, &blob_grads);
@@ -836,7 +839,7 @@ TYPED_TEST(NetTest, TestLossWeight) {
   for (int i = 0; i < kNumLossWeights; ++i) {
     Caffe::set_random_seed(this->seed_);
     this->InitUnsharedWeightsNet(&kLossWeights[i], NULL, kForceBackward);
-    const Dtype weighted_loss = this->net_->ForwardBackward(bottom);
+    const Dtype weighted_loss = this->net_->ForwardBackward();
     const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]);
     EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin)
         << "loss weight = " << kLossWeights[i];
@@ -865,14 +868,13 @@ TYPED_TEST(NetTest, TestLossWeight) {
 
 TYPED_TEST(NetTest, TestLossWeightMidNet) {
   typedef typename TypeParam::Dtype Dtype;
-  vector<Blob<Dtype>*> bottom;
   Caffe::set_random_seed(this->seed_);
   const bool kForceBackward = true;
   Dtype loss_weight = 0;
   Dtype midnet_loss_weight = 1;
   this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight,
                                kForceBackward);
-  const Dtype loss = this->net_->ForwardBackward(bottom);
+  const Dtype loss = this->net_->ForwardBackward();
   const bool kCopyDiff = true;
   const bool kReshape = true;
   Blob<Dtype> data_grad;
@@ -887,7 +889,7 @@ TYPED_TEST(NetTest, TestLossWeightMidNet) {
     Caffe::set_random_seed(this->seed_);
     this->InitUnsharedWeightsNet(&loss_weight, &kLossWeights[i],
                                  kForceBackward);
-    const Dtype weighted_loss = this->net_->ForwardBackward(bottom);
+    const Dtype weighted_loss = this->net_->ForwardBackward();
     const Dtype error_margin = kErrorMargin * fabs(kLossWeights[i]);
     EXPECT_NEAR(loss * kLossWeights[i], weighted_loss, error_margin)
         << "loss weight = " << kLossWeights[i];
@@ -903,7 +905,6 @@ TYPED_TEST(NetTest, TestLossWeightMidNet) {
 
 TYPED_TEST(NetTest, TestComboLossWeight) {
   typedef typename TypeParam::Dtype Dtype;
-  vector<Blob<Dtype>*> bottom;
   Dtype loss_weight;
   Dtype midnet_loss_weight;
   const bool kForceBackward = true;
@@ -916,7 +917,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) {
   Caffe::set_random_seed(this->seed_);
   this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight,
                                kForceBackward);
-  const Dtype loss = this->net_->ForwardBackward(bottom);
+  const Dtype loss = this->net_->ForwardBackward();
   const bool kCopyDiff = true;
   vector<shared_ptr<Blob<Dtype> > > blob_grads;
   this->CopyNetBlobs(kCopyDiff, &blob_grads);
@@ -928,7 +929,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) {
   Caffe::set_random_seed(this->seed_);
   this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight,
                                kForceBackward);
-  const Dtype loss_main_2 = this->net_->ForwardBackward(bottom);
+  const Dtype loss_main_2 = this->net_->ForwardBackward();
   vector<shared_ptr<Blob<Dtype> > > blob_grads_loss_2;
   this->CopyNetBlobs(kCopyDiff, &blob_grads_loss_2);
   vector<shared_ptr<Blob<Dtype> > > param_grads_loss_2;
@@ -939,7 +940,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) {
   Caffe::set_random_seed(this->seed_);
   this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight,
                                kForceBackward);
-  const Dtype loss_main_3 = this->net_->ForwardBackward(bottom);
+  const Dtype loss_main_3 = this->net_->ForwardBackward();
   const vector<shared_ptr<Blob<Dtype> > >& blob_grads_loss_3 =
       this->net_->blobs();
   ASSERT_EQ(blob_grads.size(), blob_grads_loss_3.size());
@@ -974,7 +975,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) {
   Caffe::set_random_seed(this->seed_);
   this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight,
                                kForceBackward);
-  const Dtype loss_midnet_2 = this->net_->ForwardBackward(bottom);
+  const Dtype loss_midnet_2 = this->net_->ForwardBackward();
   this->CopyNetBlobs(kCopyDiff, &blob_grads_loss_2);
   this->CopyNetParams(kCopyDiff, &param_grads_loss_2);
 
@@ -983,7 +984,7 @@ TYPED_TEST(NetTest, TestComboLossWeight) {
   Caffe::set_random_seed(this->seed_);
   this->InitUnsharedWeightsNet(&loss_weight, &midnet_loss_weight,
                                kForceBackward);
-  const Dtype loss_midnet_3 = this->net_->ForwardBackward(bottom);
+  const Dtype loss_midnet_3 = this->net_->ForwardBackward();
   const vector<shared_ptr<Blob<Dtype> > >& blob_grads_midnet_loss_3 =
       this->net_->blobs();
   ASSERT_EQ(blob_grads.size(), blob_grads_midnet_loss_3.size());
@@ -1032,40 +1033,35 @@ TYPED_TEST(NetTest, TestComboLossWeight) {
 }
 
 TYPED_TEST(NetTest, TestBackwardWithAccuracyLayer) {
-  typedef typename TypeParam::Dtype Dtype;
   const bool kForceBackward = false;
   const bool kAccuracyLayer = true;
   this->InitTinyNet(kForceBackward, kAccuracyLayer);
   EXPECT_TRUE(this->net_->has_blob("accuracy"));
-  vector<Blob<Dtype>*> bottom;
   // Test that we can do Backward even though we have an 'Accuracy' layer.
-  this->net_->ForwardBackward(bottom);
+  this->net_->ForwardBackward();
 }
 
 TYPED_TEST(NetTest, TestUnsharedWeightsDataNet) {
   typedef typename TypeParam::Dtype Dtype;
   this->InitUnsharedWeightsNet();
-  vector<Blob<Dtype>*> bottom;
   Dtype loss;
-  this->net_->Forward(bottom, &loss);
+  this->net_->Forward(&loss);
   EXPECT_GT(loss, 0);
 }
 
 TYPED_TEST(NetTest, TestSharedWeightsDataNet) {
   typedef typename TypeParam::Dtype Dtype;
   this->InitSharedWeightsNet();
-  vector<Blob<Dtype>*> bottom;
   Dtype loss;
-  this->net_->Forward(bottom, &loss);
+  this->net_->Forward(&loss);
   EXPECT_FLOAT_EQ(loss, 0);
 }
 
 TYPED_TEST(NetTest, TestUnsharedWeightsDiffNet) {
   typedef typename TypeParam::Dtype Dtype;
   this->InitUnsharedWeightsNet();
-  vector<Blob<Dtype>*> bottom;
   Net<Dtype>* net = this->net_.get();
-  net->Forward(bottom);
+  net->Forward();
   net->Backward();
   Layer<Dtype>* ip1_layer = net->layer_by_name("innerproduct1").get();
   Layer<Dtype>* ip2_layer = net->layer_by_name("innerproduct2").get();
@@ -1081,10 +1077,9 @@ TYPED_TEST(NetTest, TestUnsharedWeightsDiffNet) {
 TYPED_TEST(NetTest, TestSharedWeightsDiffNet) {
   typedef typename TypeParam::Dtype Dtype;
   this->InitSharedWeightsNet();
-  vector<Blob<Dtype>*> bottom;
   Net<Dtype>* net = this->net_.get();
   Dtype loss;
-  net->Forward(bottom, &loss);
+  net->Forward(&loss);
   net->Backward();
   EXPECT_FLOAT_EQ(loss, 0);
   Layer<Dtype>* ip1_layer = net->layer_by_name("innerproduct1").get();
@@ -1102,7 +1097,6 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) {
   typedef typename TypeParam::Dtype Dtype;
   Caffe::set_random_seed(this->seed_);
   this->InitDiffDataSharedWeightsNet();
-  vector<Blob<Dtype>*> bottom;
   EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
   EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
   Blob<Dtype>* ip1_weights = this->net_->layers()[1]->blobs()[0].get();
@@ -1111,7 +1105,7 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) {
   // locations.
   EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
   EXPECT_EQ(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
-  this->net_->Forward(bottom);
+  this->net_->Forward();
   this->net_->Backward();
   // Compute the expected update as the data minus the two diffs.
   Blob<Dtype> shared_params;
@@ -1146,7 +1140,7 @@ TYPED_TEST(NetTest, TestSharedWeightsUpdate) {
   // locations in memory.
   EXPECT_NE(ip1_weights->cpu_data(), ip2_weights->cpu_data());
   EXPECT_NE(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
-  this->net_->Forward(bottom);
+  this->net_->Forward();
   this->net_->Backward();
   // Compute the expected update.
   Blob<Dtype> unshared_params1;
@@ -1186,7 +1180,6 @@ TYPED_TEST(NetTest, TestSharedWeightsResume) {
   // Create a net with weight sharing; Update it once.
   Caffe::set_random_seed(this->seed_);
   this->InitDiffDataSharedWeightsNet();
-  vector<Blob<Dtype>*> bottom;
   EXPECT_EQ(this->net_->layer_names()[1], "innerproduct1");
   EXPECT_EQ(this->net_->layer_names()[2], "innerproduct2");
   Blob<Dtype>* ip1_weights = this->net_->layers()[1]->blobs()[0].get();
@@ -1195,7 +1188,7 @@ TYPED_TEST(NetTest, TestSharedWeightsResume) {
   // locations.
   EXPECT_EQ(ip1_weights->cpu_data(), ip2_weights->cpu_data());
   EXPECT_EQ(ip1_weights->cpu_diff(), ip2_weights->cpu_diff());
-  this->net_->ForwardBackward(bottom);
+  this->net_->ForwardBackward();
   this->net_->Update();
   Blob<Dtype> shared_params;
   const bool kReshape = true;
@@ -1228,7 +1221,6 @@ TYPED_TEST(NetTest, TestSharedWeightsResume) {
 
 TYPED_TEST(NetTest, TestParamPropagateDown) {
   typedef typename TypeParam::Dtype Dtype;
-  vector<Blob<Dtype>*> bottom;
   const bool kBiasTerm = true, kForceBackward = false;
   const Dtype* kLossWeight1 = NULL;
   const Dtype* kLossWeight2 = NULL;
@@ -1238,7 +1230,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
   Dtype blobs_lr_w1 = 1, blobs_lr_w2 = 1, blobs_lr_b1 = 2, blobs_lr_b2 = 2;
   this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
       kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
-  this->net_->Forward(bottom);
+  this->net_->Forward();
   this->net_->Backward();
   const vector<shared_ptr<Blob<Dtype> > >& params = this->net_->params();
   const int num_params = params.size();
@@ -1258,7 +1250,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
   blobs_lr_w1 *= 2, blobs_lr_w2 *= 2, blobs_lr_b1 *= 2, blobs_lr_b2 *= 2;
   this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
       kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
-  this->net_->Forward(bottom);
+  this->net_->Forward();
   this->net_->Backward();
   const vector<shared_ptr<Blob<Dtype> > >& params2 = this->net_->params();
   ASSERT_EQ(num_params, params2.size());
@@ -1274,7 +1266,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
   blobs_lr_w1 = 1, blobs_lr_w2 = 0, blobs_lr_b1 = 0, blobs_lr_b2 = 1;
   this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
       kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
-  this->net_->Forward(bottom);
+  this->net_->Forward();
   this->net_->Backward();
   const vector<shared_ptr<Blob<Dtype> > >& params3 = this->net_->params();
   ASSERT_EQ(num_params, params3.size());
@@ -1293,7 +1285,7 @@ TYPED_TEST(NetTest, TestParamPropagateDown) {
   blobs_lr_w1 = 0, blobs_lr_w2 = 1, blobs_lr_b1 = 1, blobs_lr_b2 = 0;
   this->InitUnsharedWeightsNet(kLossWeight1, kLossWeight2, kForceBackward,
       kBiasTerm, blobs_lr_w1, blobs_lr_w2, blobs_lr_b1, blobs_lr_b2);
-  this->net_->Forward(bottom);
+  this->net_->Forward();
   this->net_->Backward();
   const vector<shared_ptr<Blob<Dtype> > >& params4 = this->net_->params();
   ASSERT_EQ(num_params, params4.size());
@@ -1315,7 +1307,7 @@ TYPED_TEST(NetTest, TestFromTo) {
   // Run Forward and Backward, recording the data diff and loss.
   Blob<Dtype> data;
   data.ReshapeLike(*this->net_->blob_by_name("data"));
-  this->net_->ForwardPrefilled();
+  this->net_->Forward();
   this->net_->Backward();
   data.CopyFrom(*this->net_->blob_by_name("data"), true, true);
   const Dtype *loss_ptr = this->net_->output_blobs()[0]->cpu_data();
@@ -2277,12 +2269,12 @@ TYPED_TEST(NetTest, TestReshape) {
   filler.Fill(&blob2);
 
   this->InitReshapableNet();
-  Blob<Dtype>* input_blob = this->net_->input_blobs()[0];
+  shared_ptr<Blob<Dtype> > input_blob = this->net_->blob_by_name("data");
   Blob<Dtype>* output_blob = this->net_->output_blobs()[0];
   input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(),
       blob1.width());
   caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data());
-  this->net_->ForwardPrefilled();
+  this->net_->Forward();
   // call backward just to make sure it runs
   this->net_->Backward();
   Blob<Dtype> output1(output_blob->num(), output_blob->channels(),
@@ -2293,7 +2285,7 @@ TYPED_TEST(NetTest, TestReshape) {
   input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(),
       blob2.width());
   caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data());
-  this->net_->ForwardPrefilled();
+  this->net_->Forward();
   this->net_->Backward();
   Blob<Dtype> output2(output_blob->num(), output_blob->channels(),
       output_blob->height(), output_blob->width());
@@ -2303,7 +2295,7 @@ TYPED_TEST(NetTest, TestReshape) {
   input_blob->Reshape(blob1.num(), blob1.channels(), blob1.height(),
       blob1.width());
   caffe_copy(blob1.count(), blob1.cpu_data(), input_blob->mutable_cpu_data());
-  this->net_->ForwardPrefilled();
+  this->net_->Forward();
   this->net_->Backward();
   for (int i = 0; i < output1.count(); ++i) {
     EXPECT_FLOAT_EQ(*(output1.cpu_data() + i), *(output_blob->cpu_data() + i));
@@ -2312,7 +2304,7 @@ TYPED_TEST(NetTest, TestReshape) {
   input_blob->Reshape(blob2.num(), blob2.channels(), blob2.height(),
       blob2.width());
   caffe_copy(blob2.count(), blob2.cpu_data(), input_blob->mutable_cpu_data());
-  this->net_->ForwardPrefilled();
+  this->net_->Forward();
   this->net_->Backward();
   for (int i = 0; i < output2.count(); ++i) {
     EXPECT_FLOAT_EQ(*(output2.cpu_data() + i), *(output_blob->cpu_data() + i));
index ba2ccbb..0071421 100644 (file)
@@ -886,67 +886,6 @@ TEST_F(SplitLayerInsertionTest, TestInsertionTwoTop) {
   this->RunInsertionTest(input_proto, expected_output_proto);
 }
 
-TEST_F(SplitLayerInsertionTest, TestInputInsertion) {
-  const string& input_proto =
-      "name: 'TestNetwork' "
-      "input: 'data' "
-      "input_dim: 10 "
-      "input_dim: 3 "
-      "input_dim: 227 "
-      "input_dim: 227 "
-      "layer { "
-      "  name: 'innerprod1' "
-      "  type: 'InnerProduct' "
-      "  bottom: 'data' "
-      "  top: 'innerprod1' "
-      "} "
-      "layer { "
-      "  name: 'innerprod2' "
-      "  type: 'InnerProduct' "
-      "  bottom: 'data' "
-      "  top: 'innerprod2' "
-      "} "
-      "layer { "
-      "  name: 'loss' "
-      "  type: 'EuclideanLoss' "
-      "  bottom: 'innerprod1' "
-      "  bottom: 'innerprod2' "
-      "} ";
-  const string& expected_output_proto =
-      "name: 'TestNetwork' "
-      "input: 'data' "
-      "input_dim: 10 "
-      "input_dim: 3 "
-      "input_dim: 227 "
-      "input_dim: 227 "
-      "layer { "
-      "  name: 'data_input_0_split' "
-      "  type: 'Split' "
-      "  bottom: 'data' "
-      "  top: 'data_input_0_split_0' "
-      "  top: 'data_input_0_split_1' "
-      "} "
-      "layer { "
-      "  name: 'innerprod1' "
-      "  type: 'InnerProduct' "
-      "  bottom: 'data_input_0_split_0' "
-      "  top: 'innerprod1' "
-      "} "
-      "layer { "
-      "  name: 'innerprod2' "
-      "  type: 'InnerProduct' "
-      "  bottom: 'data_input_0_split_1' "
-      "  top: 'innerprod2' "
-      "} "
-      "layer { "
-      "  name: 'loss' "
-      "  type: 'EuclideanLoss' "
-      "  bottom: 'innerprod1' "
-      "  bottom: 'innerprod2' "
-      "} ";
-  this->RunInsertionTest(input_proto, expected_output_proto);
-}
-
 TEST_F(SplitLayerInsertionTest, TestWithInPlace) {
   const string& input_proto =
       "name: 'TestNetwork' "
index 475a2a9..7a899c6 100644 (file)
@@ -19,12 +19,6 @@ void InsertSplits(const NetParameter& param, NetParameter* param_split) {
   map<pair<int, int>, float> top_idx_to_loss_weight;
   map<pair<int, int>, int> top_idx_to_bottom_split_idx;
   map<int, string> layer_idx_to_layer_name;
-  layer_idx_to_layer_name[-1] = "input";
-  // Determine the number of times each blob is used as an input (bottom) blob.
-  for (int i = 0; i < param.input_size(); ++i) {
-    const string& blob_name = param.input(i);
-    blob_name_to_last_top_idx[blob_name] = make_pair(-1, i);
-  }
   for (int i = 0; i < param.layer_size(); ++i) {
     const LayerParameter& layer_param = param.layer(i);
     layer_idx_to_layer_name[i] = layer_param.name();
@@ -45,7 +39,7 @@ void InsertSplits(const NetParameter& param, NetParameter* param_split) {
       blob_name_to_last_top_idx[blob_name] = make_pair(i, j);
     }
     // A use of a top blob as a loss should be handled similarly to the use of
-    // a top blob as an input (bottom) blob to another layer.
+    // a top blob as a bottom blob to another layer.
     const int last_loss =
         std::min(layer_param.loss_weight_size(), layer_param.top_size());
     for (int j = 0; j < last_loss; ++j) {
@@ -57,19 +51,6 @@ void InsertSplits(const NetParameter& param, NetParameter* param_split) {
       }
     }
   }
-  // Create split layer for any input blobs used by other layer as bottom
-  // blobs more than once.
-  for (int i = 0; i < param.input_size(); ++i) {
-    const int split_count = top_idx_to_bottom_count[make_pair(-1, i)];
-    if (split_count > 1) {
-      const string& layer_name = layer_idx_to_layer_name[-1];
-      const string& blob_name = param.input(i);
-      LayerParameter* split_layer_param = param_split->add_layer();
-      const float kZeroLossWeight = 0;
-      ConfigureSplitLayer(layer_name, blob_name, i, split_count,
-          kZeroLossWeight, split_layer_param);
-    }
-  }
   for (int i = 0; i < param.layer_size(); ++i) {
     LayerParameter* layer_param = param_split->add_layer();
     layer_param->CopyFrom(param.layer(i));
index ebe95d6..95b2f82 100644 (file)
@@ -251,14 +251,13 @@ int test() {
   caffe_net.CopyTrainedLayersFrom(FLAGS_weights);
   LOG(INFO) << "Running for " << FLAGS_iterations << " iterations.";
 
-  vector<Blob<float>* > bottom_vec;
   vector<int> test_score_output_id;
   vector<float> test_score;
   float loss = 0;
   for (int i = 0; i < FLAGS_iterations; ++i) {
     float iter_loss;
     const vector<Blob<float>*>& result =
-        caffe_net.Forward(bottom_vec, &iter_loss);
+        caffe_net.Forward(&iter_loss);
     loss += iter_loss;
     int idx = 0;
     for (int j = 0; j < result.size(); ++j) {
@@ -322,7 +321,7 @@ int time() {
   // Note that for the speed benchmark, we will assume that the network does
   // not take any input blobs.
   float initial_loss;
-  caffe_net.Forward(vector<Blob<float>*>(), &initial_loss);
+  caffe_net.Forward(&initial_loss);
   LOG(INFO) << "Initial loss: " << initial_loss;
   LOG(INFO) << "Performing Backward";
   caffe_net.Backward();
index d6562f9..7044672 100644 (file)
@@ -133,10 +133,9 @@ int feature_extraction_pipeline(int argc, char** argv) {
   LOG(ERROR)<< "Extacting Features";
 
   Datum datum;
-  std::vector<Blob<float>*> input_vec;
   std::vector<int> image_indices(num_features, 0);
   for (int batch_index = 0; batch_index < num_mini_batches; ++batch_index) {
-    feature_extraction_net->Forward(input_vec);
+    feature_extraction_net->Forward();
     for (int i = 0; i < num_features; ++i) {
       const boost::shared_ptr<Blob<Dtype> > feature_blob =
         feature_extraction_net->blob_by_name(blob_names[i]);