upgrade net parameter data transformation fields automagically
authorEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 22 Aug 2014 04:22:44 +0000 (21:22 -0700)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 22 Aug 2014 05:20:21 +0000 (22:20 -0700)
Convert DataParameter and ImageDataParameter data transformation fields
into a TransformationParameter.

include/caffe/util/upgrade_proto.hpp
src/caffe/util/upgrade_proto.cpp
tools/upgrade_net_proto_text.cpp

index 3b10624..4548368 100644 (file)
@@ -29,6 +29,13 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
 
 LayerParameter_LayerType UpgradeV0LayerType(const string& type);
 
+// Return true iff any layer contains deprecated data transformation parameters.
+bool NetNeedsDataUpgrade(const NetParameter& net_param);
+
+// Perform all necessary transformations to upgrade old transformation fields
+// into a TransformationParameter.
+void UpgradeNetDataTransformation(NetParameter* net_param);
+
 // Convert a NetParameter to NetParameterPrettyPrint used for dumping to
 // proto text files.
 void NetParameterToPrettyPrint(const NetParameter& param,
index c9c57a2..9045abb 100644 (file)
@@ -547,6 +547,75 @@ LayerParameter_LayerType UpgradeV0LayerType(const string& type) {
   }
 }
 
+bool NetNeedsDataUpgrade(const NetParameter& net_param) {
+  for (int i = 0; i < net_param.layers_size(); ++i) {
+    if (net_param.layers(i).type() == LayerParameter_LayerType_DATA) {
+      DataParameter layer_param = net_param.layers(i).data_param();
+      if (layer_param.has_scale()) { return true; }
+      if (layer_param.has_mean_file()) { return true; }
+      if (layer_param.has_crop_size()) { return true; }
+      if (layer_param.has_mirror()) { return true; }
+    }
+    if (net_param.layers(i).type() == LayerParameter_LayerType_IMAGE_DATA) {
+      ImageDataParameter layer_param = net_param.layers(i).image_data_param();
+      if (layer_param.has_scale()) { return true; }
+      if (layer_param.has_mean_file()) { return true; }
+      if (layer_param.has_crop_size()) { return true; }
+      if (layer_param.has_mirror()) { return true; }
+    }
+  }
+  return false;
+}
+
+void UpgradeNetDataTransformation(NetParameter* net_param) {
+  for (int i = 0; i < net_param->layers_size(); ++i) {
+    if (net_param->layers(i).type() == LayerParameter_LayerType_DATA) {
+      DataParameter* layer_param =
+          net_param->mutable_layers(i)->mutable_data_param();
+      TransformationParameter* transform_param =
+          layer_param->mutable_transform_param();
+      if (layer_param->has_scale()) {
+        transform_param->set_scale(layer_param->scale());
+        layer_param->clear_scale();
+      }
+      if (layer_param->has_mean_file()) {
+        transform_param->set_mean_file(layer_param->mean_file());
+        layer_param->clear_mean_file();
+      }
+      if (layer_param->has_crop_size()) {
+        transform_param->set_crop_size(layer_param->crop_size());
+        layer_param->clear_crop_size();
+      }
+      if (layer_param->has_mirror()) {
+        transform_param->set_mirror(layer_param->mirror());
+        layer_param->clear_mirror();
+      }
+    }
+    if (net_param->layers(i).type() == LayerParameter_LayerType_IMAGE_DATA) {
+      ImageDataParameter* layer_param =
+          net_param->mutable_layers(i)->mutable_image_data_param();
+      TransformationParameter* transform_param =
+          layer_param->mutable_transform_param();
+      if (layer_param->has_scale()) {
+        transform_param->set_scale(layer_param->scale());
+        layer_param->clear_scale();
+      }
+      if (layer_param->has_mean_file()) {
+        transform_param->set_mean_file(layer_param->mean_file());
+        layer_param->clear_mean_file();
+      }
+      if (layer_param->has_crop_size()) {
+        transform_param->set_crop_size(layer_param->crop_size());
+        layer_param->clear_crop_size();
+      }
+      if (layer_param->has_mirror()) {
+        transform_param->set_mirror(layer_param->mirror());
+        layer_param->clear_mirror();
+      }
+    }
+  }
+}
+
 void NetParameterToPrettyPrint(const NetParameter& param,
                                NetParameterPrettyPrint* pretty_param) {
   pretty_param->Clear();
@@ -586,6 +655,16 @@ void UpgradeNetAsNeeded(const string& param_file, NetParameter* param) {
         << "prototxt and ./build/tools/upgrade_net_proto_binary for model "
         << "weights upgrade this and any other net protos to the new format.";
   }
+  // NetParameter uses old style data transformation fields; try to upgrade it.
+  if (NetNeedsDataUpgrade(*param)) {
+    LOG(ERROR) << "Attempting to upgrade input file specified using deprecated "
+               << "transformation parameters: " << param_file;
+    UpgradeNetDataTransformation(param);
+    LOG(INFO) << "Successfully upgraded file specified using deprecated "
+              << "data transformation parameters.";
+    LOG(ERROR) << "Note that future Caffe releases will only support "
+               << "transform_param messages for transformation fields.";
+  }
 }
 
 void ReadNetParamsFromTextFileOrDie(const string& param_file,
index 1176585..2f290fc 100644 (file)
@@ -29,6 +29,7 @@ int main(int argc, char** argv) {
     return 2;
   }
   bool need_upgrade = NetNeedsUpgrade(net_param);
+  bool need_data_upgrade = NetNeedsDataUpgrade(net_param);
   bool success = true;
   if (need_upgrade) {
     NetParameter v0_net_param(net_param);
@@ -37,6 +38,10 @@ int main(int argc, char** argv) {
     LOG(ERROR) << "File already in V1 proto format: " << argv[1];
   }
 
+  if (need_data_upgrade) {
+    UpgradeNetDataTransformation(&net_param);
+  }
+
   // Convert to a NetParameterPrettyPrint to print fields in desired
   // order.
   NetParameterPrettyPrint net_param_pretty;