deprecate input fields and upgrade automagically
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 16 Oct 2015 23:44:45 +0000 (16:44 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 25 Feb 2016 21:20:04 +0000 (13:20 -0800)
include/caffe/util/upgrade_proto.hpp
src/caffe/proto/caffe.proto
src/caffe/util/upgrade_proto.cpp

index c94bb3c..14e1936 100644 (file)
@@ -59,6 +59,12 @@ bool UpgradeV1LayerParameter(const V1LayerParameter& v1_layer_param,
 
 const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type);
 
+// Return true iff the Net contains input fields.
+bool NetNeedsInputUpgrade(const NetParameter& net_param);
+
+// Perform all necessary transformations to upgrade input fields into layers.
+void UpgradeNetInput(NetParameter* net_param);
+
 // Return true iff the solver contains any old solver_type specified as enums
 bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param);
 
index 702ce6b..3b27bbd 100644 (file)
@@ -63,12 +63,12 @@ message FillerParameter {
 
 message NetParameter {
   optional string name = 1; // consider giving the network a name
-  // The input blobs to the network.
+  // DEPRECATED. See InputParameter. The input blobs to the network.
   repeated string input = 3;
-  // The shape of the input blobs.
+  // DEPRECATED. See InputParameter. The shape of the input blobs.
   repeated BlobShape input_shape = 8;
 
-  // 4D input dimensions -- deprecated.  Use "shape" instead.
+  // 4D input dimensions -- deprecated.  Use "input_shape" instead.
   // If specified, for each input blob there should be four
   // values specifying the num, channels, height and width of the input blob.
   // Thus, there should be a total of (4 * #input) numbers.
index ff3f8ff..449975b 100644 (file)
@@ -60,6 +60,16 @@ bool UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
                 << "V1LayerParameter";
     }
   }
+  // NetParameter uses old style input fields; try to upgrade it.
+  if (NetNeedsInputUpgrade(*param)) {
+    LOG(INFO) << "Attempting to upgrade input file specified using deprecated "
+              << "input fields: " << param_file;
+    UpgradeNetInput(param);
+    LOG(INFO) << "Successfully upgraded file specified using deprecated "
+              << "input fields.";
+    LOG(WARNING) << "Note that future Caffe releases will only support "
+                 << "input layers and not input fields.";
+  }
   return success;
 }
 
@@ -937,6 +947,41 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
   }
 }
 
+bool NetNeedsInputUpgrade(const NetParameter& net_param) {
+  return net_param.input_size() > 0;
+}
+
+void UpgradeNetInput(NetParameter* net_param) {
+  LayerParameter* layer_param = net_param->add_layer();
+  layer_param->set_name("input");
+  layer_param->set_type("Input");
+  InputParameter* input_param = layer_param->mutable_input_param();
+  bool has_shape = net_param->input_shape_size() > 0;
+  // Convert input fields into a layer.
+  for (int i = 0; i < net_param->input_size(); ++i) {
+    layer_param->add_top(net_param->input(i));
+    if (has_shape) {
+      input_param->add_shape()->CopyFrom(net_param->input_shape(i));
+    } else {
+      // Turn legacy input dimensions into shape.
+      BlobShape* shape = input_param->add_shape();
+      int first_dim = i*4;
+      int last_dim = first_dim + 4;
+      for (int j = first_dim; j < last_dim; j++) {
+        shape->add_dim(net_param->input_dim(j));
+      }
+    }
+  }
+  // Swap input layer to beginning of net to satisfy layer dependencies.
+  for (int i = net_param->layer_size() - 1; i > 0; --i) {
+    net_param->mutable_layer(i-1)->Swap(net_param->mutable_layer(i));
+  }
+  // Clear inputs.
+  net_param->clear_input();
+  net_param->clear_input_shape();
+  net_param->clear_input_dim();
+}
+
 // Return true iff the solver contains any old solver_type specified as enums
 bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) {
   if (solver_param.has_solver_type()) {