From c14c8be5f123ab650922750976f380d461a6d5cd Mon Sep 17 00:00:00 2001 From: Evan Shelhamer Date: Thu, 21 Aug 2014 21:22:44 -0700 Subject: [PATCH] upgrade net parameter data transformation fields automagically Convert DataParameter and ImageDataParameter data transformation fields into a TransformationParameter. --- include/caffe/util/upgrade_proto.hpp | 7 ++++ src/caffe/util/upgrade_proto.cpp | 79 ++++++++++++++++++++++++++++++++++++ tools/upgrade_net_proto_text.cpp | 5 +++ 3 files changed, 91 insertions(+) diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp index 3b10624..4548368 100644 --- a/include/caffe/util/upgrade_proto.hpp +++ b/include/caffe/util/upgrade_proto.hpp @@ -29,6 +29,13 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection, LayerParameter_LayerType UpgradeV0LayerType(const string& type); +// Return true iff any layer contains deprecated data transformation parameters. +bool NetNeedsDataUpgrade(const NetParameter& net_param); + +// Perform all necessary transformations to upgrade old transformation fields +// into a TransformationParameter. +void UpgradeNetDataTransformation(NetParameter* net_param); + // Convert a NetParameter to NetParameterPrettyPrint used for dumping to // proto text files. void NetParameterToPrettyPrint(const NetParameter& param, diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index c9c57a2..9045abb 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -547,6 +547,75 @@ LayerParameter_LayerType UpgradeV0LayerType(const string& type) { } } +bool NetNeedsDataUpgrade(const NetParameter& net_param) { + for (int i = 0; i < net_param.layers_size(); ++i) { + if (net_param.layers(i).type() == LayerParameter_LayerType_DATA) { + DataParameter layer_param = net_param.layers(i).data_param(); + if (layer_param.has_scale()) { return true; } + if (layer_param.has_mean_file()) { return true; } + if (layer_param.has_crop_size()) { return true; } + if (layer_param.has_mirror()) { return true; } + } + if (net_param.layers(i).type() == LayerParameter_LayerType_IMAGE_DATA) { + ImageDataParameter layer_param = net_param.layers(i).image_data_param(); + if (layer_param.has_scale()) { return true; } + if (layer_param.has_mean_file()) { return true; } + if (layer_param.has_crop_size()) { return true; } + if (layer_param.has_mirror()) { return true; } + } + } + return false; +} + +void UpgradeNetDataTransformation(NetParameter* net_param) { + for (int i = 0; i < net_param->layers_size(); ++i) { + if (net_param->layers(i).type() == LayerParameter_LayerType_DATA) { + DataParameter* layer_param = + net_param->mutable_layers(i)->mutable_data_param(); + TransformationParameter* transform_param = + layer_param->mutable_transform_param(); + if (layer_param->has_scale()) { + transform_param->set_scale(layer_param->scale()); + layer_param->clear_scale(); + } + if (layer_param->has_mean_file()) { + transform_param->set_mean_file(layer_param->mean_file()); + layer_param->clear_mean_file(); + } + if (layer_param->has_crop_size()) { + transform_param->set_crop_size(layer_param->crop_size()); + layer_param->clear_crop_size(); + } + if (layer_param->has_mirror()) { + transform_param->set_mirror(layer_param->mirror()); + layer_param->clear_mirror(); + } + } + if (net_param->layers(i).type() == LayerParameter_LayerType_IMAGE_DATA) { + ImageDataParameter* layer_param = + net_param->mutable_layers(i)->mutable_image_data_param(); + TransformationParameter* transform_param = + layer_param->mutable_transform_param(); + if (layer_param->has_scale()) { + transform_param->set_scale(layer_param->scale()); + layer_param->clear_scale(); + } + if (layer_param->has_mean_file()) { + transform_param->set_mean_file(layer_param->mean_file()); + layer_param->clear_mean_file(); + } + if (layer_param->has_crop_size()) { + transform_param->set_crop_size(layer_param->crop_size()); + layer_param->clear_crop_size(); + } + if (layer_param->has_mirror()) { + transform_param->set_mirror(layer_param->mirror()); + layer_param->clear_mirror(); + } + } + } +} + void NetParameterToPrettyPrint(const NetParameter& param, NetParameterPrettyPrint* pretty_param) { pretty_param->Clear(); @@ -586,6 +655,16 @@ void UpgradeNetAsNeeded(const string& param_file, NetParameter* param) { << "prototxt and ./build/tools/upgrade_net_proto_binary for model " << "weights upgrade this and any other net protos to the new format."; } + // NetParameter uses old style data transformation fields; try to upgrade it. + if (NetNeedsDataUpgrade(*param)) { + LOG(ERROR) << "Attempting to upgrade input file specified using deprecated " + << "transformation parameters: " << param_file; + UpgradeNetDataTransformation(param); + LOG(INFO) << "Successfully upgraded file specified using deprecated " + << "data transformation parameters."; + LOG(ERROR) << "Note that future Caffe releases will only support " + << "transform_param messages for transformation fields."; + } } void ReadNetParamsFromTextFileOrDie(const string& param_file, diff --git a/tools/upgrade_net_proto_text.cpp b/tools/upgrade_net_proto_text.cpp index 1176585..2f290fc 100644 --- a/tools/upgrade_net_proto_text.cpp +++ b/tools/upgrade_net_proto_text.cpp @@ -29,6 +29,7 @@ int main(int argc, char** argv) { return 2; } bool need_upgrade = NetNeedsUpgrade(net_param); + bool need_data_upgrade = NetNeedsDataUpgrade(net_param); bool success = true; if (need_upgrade) { NetParameter v0_net_param(net_param); @@ -37,6 +38,10 @@ int main(int argc, char** argv) { LOG(ERROR) << "File already in V1 proto format: " << argv[1]; } + if (need_data_upgrade) { + UpgradeNetDataTransformation(&net_param); + } + // Convert to a NetParameterPrettyPrint to print fields in desired // order. NetParameterPrettyPrint net_param_pretty; -- 2.7.4