From 1240327c5c84d817d845bf6ec19a28f7c8819a8f Mon Sep 17 00:00:00 2001
From: Spandan Tiwari
Date: Fri, 29 Mar 2019 15:17:14 -0700
Subject: [PATCH] Refactoring serialization of ONNX initializers to be
 name-based (Resubmission) (#17830)

Summary:
houseroad - this is the resubmission of https://github.com/pytorch/pytorch/pull/17420, as suggested.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17830

Reviewed By: zrphercule

Differential Revision: D14398714

Pulled By: houseroad

fbshipit-source-id: bda475f1ae8a5273ebdb0f6883fc66036c29d326
---
 .../expect/TestOperators.test_batchnorm.expect     | 16 +++----
 .../expect/TestOperators.test_batchnorm_1d.expect  | 16 +++----
 .../TestOperators.test_batchnorm_noaffine.expect   | 10 ++---
 .../TestOperators.test_batchnorm_training.expect   | 18 ++++----
 test/onnx/expect/TestOperators.test_linear.expect  | 10 ++---
 test/onnx/test_pytorch_onnx_caffe2.py              | 52 ++++++++++++++++++++++
 torch/csrc/jit/export.cpp                          | 37 ++++++++-------
 torch/csrc/jit/export.h                            |  4 +-
 torch/csrc/jit/python_ir.cpp                       |  4 +-
 torch/onnx/utils.py                                | 21 +++++----
 10 files changed, 124 insertions(+), 64 deletions(-)

diff --git a/test/onnx/expect/TestOperators.test_batchnorm.expect b/test/onnx/expect/TestOperators.test_batchnorm.expect
index 916e87f..e9edb45 100644
--- a/test/onnx/expect/TestOperators.test_batchnorm.expect
+++ b/test/onnx/expect/TestOperators.test_batchnorm.expect
@@ -25,13 +25,12 @@ graph {
   initializer {
     dims: 2
     data_type: 1
-    name: "weight"
-    raw_data: "\340e\355<\246\305\315>"
+    name: "bias"
+    raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
-    dims: 2
-    data_type: 1
-    name: "bias"
+    data_type: 7
+    name: "num_batches_tracked"
     raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
@@ -47,9 +46,10 @@ graph {
     raw_data: "\000\000\200?\000\000\200?"
   }
   initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\000\000\000\000\000\000\000\000"
+    dims: 2
+    data_type: 1
+    name: "weight"
+    raw_data: "\340e\355<\246\305\315>"
   }
   input {
     name: "input"
diff --git a/test/onnx/expect/TestOperators.test_batchnorm_1d.expect b/test/onnx/expect/TestOperators.test_batchnorm_1d.expect
index 3291043..f3dac32 100644
--- a/test/onnx/expect/TestOperators.test_batchnorm_1d.expect
+++ b/test/onnx/expect/TestOperators.test_batchnorm_1d.expect
@@ -45,13 +45,12 @@ graph {
   initializer {
     dims: 2
     data_type: 1
-    name: "weight"
-    raw_data: "\340e\355<\246\305\315>"
+    name: "bias"
+    raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
-    dims: 2
-    data_type: 1
-    name: "bias"
+    data_type: 7
+    name: "num_batches_tracked"
     raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
@@ -67,9 +66,10 @@ graph {
     raw_data: "\000\000\200?\000\000\200?"
   }
   initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\000\000\000\000\000\000\000\000"
+    dims: 2
+    data_type: 1
+    name: "weight"
+    raw_data: "\340e\355<\246\305\315>"
   }
   input {
     name: "input"
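All of the expect-file hunks in this patch (including the ones below) are the same mechanical change: initializers are no longer written in graph-input order but in the lexicographic name order produced by iterating the new std::map (bias, num_batches_tracked, running_mean, running_var, weight). A minimal sketch of the property these files now encode, assuming the standard onnx package is available (the module and input shape are illustrative):

    import io

    import onnx
    import torch

    model = torch.nn.BatchNorm2d(2)
    f = io.BytesIO()
    torch.onnx.export(model, torch.randn(1, 2, 4, 4), f)

    # After this change the initializer list is deterministic and name-sorted,
    # e.g. ['bias', 'num_batches_tracked', 'running_mean', 'running_var', 'weight'].
    names = [init.name for init in onnx.load_from_string(f.getvalue()).graph.initializer]
    assert names == sorted(names)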
diff --git a/test/onnx/expect/TestOperators.test_batchnorm_noaffine.expect b/test/onnx/expect/TestOperators.test_batchnorm_noaffine.expect
index 6528fca..6e7b9e7 100644
--- a/test/onnx/expect/TestOperators.test_batchnorm_noaffine.expect
+++ b/test/onnx/expect/TestOperators.test_batchnorm_noaffine.expect
@@ -49,6 +49,11 @@ graph {
   }
   name: "torch-jit-export"
   initializer {
+    data_type: 7
+    name: "num_batches_tracked"
+    raw_data: "\000\000\000\000\000\000\000\000"
+  }
+  initializer {
     dims: 128
     data_type: 1
     name: "running_mean"
@@ -60,11 +65,6 @@ graph {
     name: "running_var"
     raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
   }
-  initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\000\000\000\000\000\000\000\000"
-  }
   input {
     name: "input"
     type {
diff --git a/test/onnx/expect/TestOperators.test_batchnorm_training.expect b/test/onnx/expect/TestOperators.test_batchnorm_training.expect
index f6a3e2f..e4f0e6e 100644
--- a/test/onnx/expect/TestOperators.test_batchnorm_training.expect
+++ b/test/onnx/expect/TestOperators.test_batchnorm_training.expect
@@ -29,16 +29,15 @@ graph {
   initializer {
     dims: 2
     data_type: 1
-    name: "weight"
-    raw_data: "\340e\355<\246\305\315>"
-  }
-  initializer {
-    dims: 2
-    data_type: 1
     name: "bias"
     raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
+    data_type: 7
+    name: "num_batches_tracked"
+    raw_data: "\001\000\000\000\000\000\000\000"
+  }
+  initializer {
     dims: 2
     data_type: 1
     name: "running_mean"
@@ -51,9 +50,10 @@ graph {
     raw_data: "fff?fff?"
   }
   initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\001\000\000\000\000\000\000\000"
+    dims: 2
+    data_type: 1
+    name: "weight"
+    raw_data: "\340e\355<\246\305\315>"
   }
   input {
     name: "input"
diff --git a/test/onnx/expect/TestOperators.test_linear.expect b/test/onnx/expect/TestOperators.test_linear.expect
index ecb3e6d..71a81f5 100644
--- a/test/onnx/expect/TestOperators.test_linear.expect
+++ b/test/onnx/expect/TestOperators.test_linear.expect
@@ -27,16 +27,16 @@ graph {
   name: "torch-jit-export"
   initializer {
     dims: 5
-    dims: 4
     data_type: 1
-    name: "weight"
-    raw_data: "\212\332\356>@\265u>p\303E\275 \320\306\274\354\201\221>\004\354\261\276\2746*>8\247)\276\340\035\224>\024\2446\276\200\211\312<\224\344,>D\356\257>\320\202\226\275\364\213\351>z\226\330\276\310\250\266\275\352F\377\276\000\250)=\244K\021>"
+    name: "bias"
+    raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>"
   }
   initializer {
     dims: 5
+    dims: 4
     data_type: 1
-    name: "bias"
-    raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>"
+    name: "weight"
+    raw_data: "\212\332\356>@\265u>p\303E\275 \320\306\274\354\201\221>\004\354\261\276\2746*>8\247)\276\340\035\224>\024\2446\276\200\211\312<\224\344,>D\356\257>\320\202\226\275\364\213\351>z\226\330\276\310\250\266\275\352F\377\276\000\250)=\244K\021>"
   }
   input {
     name: "input"
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index ecc3e4f..45d243d 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -11,6 +11,7 @@ import itertools
 
 import torch.onnx
 import torch.onnx.operators
+from torch.onnx import ExportTypes
 from torch import nn
 from torch.autograd import Variable, function
 import torch.utils.model_zoo as model_zoo
@@ -207,6 +208,57 @@ class TestCaffe2Backend(unittest.TestCase):
         input = torch.randn(3, 4, requires_grad=True)
         self.run_model_test(model, train=False, batch_size=0, input=input)
 
+    def test_onnx_export_with_parameter_renaming(self):
+        class SimpleFcNet(nn.Module):
+            def __init__(self):
+                super(SimpleFcNet, self).__init__()
+                self.fc1 = nn.Linear(5, 10)
+
+            def forward(self, input):
+                return self.fc1(input)
+
+        model = SimpleFcNet()
+        input = torch.randn(7, 5)
+        output = model(input)
+
+        f = io.BytesIO()
+        # Note that the export call explicitly sets the names of not just the input,
+        # but also the parameters. This test checks that the model can be loaded and
+        # executed in the Caffe2 backend correctly.
+        torch.onnx._export(model, input, f, verbose=True, export_type=ExportTypes.ZIP_ARCHIVE,
+                           input_names=['input1', 'parameter1', 'parameter2'])
+
+        f.seek(0)
+        model_c2 = c2.prepare_zip_archive(f)
+        result = model_c2.run(input.numpy())
+        np.testing.assert_almost_equal(output.data.cpu().numpy(), result[0], decimal=3)
+
+    def test_onnx_export_param_name_duplication(self):
+        class SimpleFcNet(nn.Module):
+            def __init__(self):
+                super(SimpleFcNet, self).__init__()
+                self.fc1 = nn.Linear(5, 10)
+
+            def forward(self, input):
+                return self.fc1(input)
+
+        model = SimpleFcNet()
+        input = torch.randn(7, 5)
+        output = model(input)
+
+        f = io.BytesIO()
+        # The export call explicitly sets the names of the input and of the first
+        # parameter, but the name chosen for the first parameter is the same as the
+        # name of the second parameter. This test checks that, given this edge case,
+        # the model can still be loaded and executed in the Caffe2 backend correctly.
+        torch.onnx._export(model, input, f, verbose=True, export_type=ExportTypes.ZIP_ARCHIVE,
+                           input_names=['input1', 'fc1.bias'], _retain_param_name=False)
+
+        f.seek(0)
+        model_c2 = c2.prepare_zip_archive(f)
+        result = model_c2.run(input.numpy())
+        np.testing.assert_almost_equal(output.data.cpu().numpy(), result[0], decimal=3)
+
     def test_lstm_cell(self):
         model = nn.LSTMCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE)
         input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
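The two tests above exercise the new name-based path end to end: extra entries in input_names rename the model parameters, and initializers are then matched to graph inputs by those names rather than by position. For reference, a standalone sketch of the same behavior outside the Caffe2 harness, assuming the standard onnx package (names and shapes are illustrative; torch.onnx._export is the internal helper the test suite itself uses):

    import io

    import onnx
    import torch
    import torch.nn as nn

    model = nn.Linear(5, 10)
    f = io.BytesIO()
    # input_names entries beyond the actual inputs rename the parameters.
    torch.onnx._export(model, torch.randn(7, 5), f,
                       input_names=['input1', 'parameter1', 'parameter2'])

    graph = onnx.load_from_string(f.getvalue()).graph
    assert sorted(i.name for i in graph.initializer) == ['parameter1', 'parameter2']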
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index fe96902..4112aa9 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -116,15 +116,23 @@ class EncoderBase {
   }
 
  protected:
+  // Using std::map instead of std::unordered_map for initializers in
+  // EncodeGraph so that the order in which initializers get written
+  // to the ONNX graph is always deterministic and predictable. While
+  // this is not an ONNX requirement, it is needed for testing purposes
+  // in tests that use _export_to_pretty_string() for validating
+  // ONNX graphs.
   void EncodeGraph(
       onnx::GraphProto* graph_proto,
       const std::shared_ptr<Graph>& graph,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers =
+          std::map<std::string, at::Tensor>());
 
   void EncodeBlock(
       onnx::GraphProto* graph_proto,
       const Block* block,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers =
+          std::map<std::string, at::Tensor>());
 
   virtual void EncodeTensor(
       onnx::TensorProto* tensor_proto,
@@ -210,14 +218,14 @@ void EncoderBase::EncodeValueInfo(
 void EncoderBase::EncodeGraph(
     onnx::GraphProto* graph_proto,
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   EncodeBlock(graph_proto, graph->block(), initializers);
 }
 
 void EncoderBase::EncodeBlock(
     onnx::GraphProto* graph_proto,
     const Block* block,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   AT_ASSERT(graph_proto != nullptr);
   std::string block_name = "torch-jit-export";
   if (num_blocks_) {
@@ -304,16 +312,11 @@ void EncoderBase::EncodeBlock(
       EncodeBlock(false_g, node->blocks()[1]);
     }
   }
-  auto num_initializers = initializers.size();
-  AT_ASSERT(block->inputs().size() >= num_initializers);
-  size_t inputs_count = block->inputs().size() - num_initializers;
-  for (auto& tensor : initializers) {
-    // TODO: stop using positions to determine which initializers
-    // match to which inputs
-    std::string name = graph_proto->input(inputs_count++).name();
+  AT_ASSERT(block->inputs().size() >= initializers.size());
+  for (auto& name_tensor_pair : initializers) {
     auto p = graph_proto->add_initializer();
-    p->set_name(name);
-    EncodeTensor(p, tensor, name);
+    p->set_name(name_tensor_pair.first);
+    EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
   }
 }
 
@@ -387,7 +390,7 @@ class GraphEncoder : public EncoderBase {
       const std::shared_ptr<Graph>& graph,
       int64_t onnx_opset_version,
       onnx_torch::OperatorExportTypes operator_export_type,
-      const std::vector<at::Tensor>& initializers,
+      const std::map<std::string, at::Tensor>& initializers,
       bool defer_weight_export,
       bool strip_doc);
 
@@ -409,7 +412,7 @@ GraphEncoder::GraphEncoder(
     const std::shared_ptr<Graph>& graph,
     int64_t onnx_opset_version,
     onnx_torch::OperatorExportTypes operator_export_type,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     bool defer_weight_export,
     bool strip_doc)
     : EncoderBase(operator_export_type, strip_doc),
@@ -959,7 +962,7 @@ std::string prettyPrint(const onnx::ModelProto& model) {
 
 std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type,
@@ -984,7 +987,7 @@ std::string pretty_print_onnx(
 // libtorch will be able to import the IR and play it back.
 std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type) {
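The heart of the change is in EncodeBlock above: the encoder no longer pairs the trailing block inputs with the initializer tensors by position (the removed TODO), but receives an explicit name-to-tensor map and writes each initializer under its own key. A rough Python mirror of the two matching strategies, with hypothetical helper names for illustration only:

    def encode_initializers_positional(graph_input_names, tensors):
        # Old approach: assume the last len(tensors) graph inputs are the
        # initializers and pair names with tensors by position.
        offset = len(graph_input_names) - len(tensors)
        return [(graph_input_names[offset + i], t) for i, t in enumerate(tensors)]

    def encode_initializers_by_name(initializers):
        # New approach: the caller supplies a name -> tensor mapping; iterating
        # it in sorted key order mirrors std::map's deterministic iteration.
        return sorted(initializers.items())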
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index ae49245..1904723 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -21,7 +21,7 @@ using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
 
 TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export = false,
     ::torch::onnx::OperatorExportTypes operator_export_type =
@@ -30,7 +30,7 @@ TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
 // For testing purposes
 TORCH_API std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type =
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 95c06fe..a6a900b 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -221,7 +221,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "_export_onnx",
           [](const std::shared_ptr<Graph> g,
-             const std::vector<at::Tensor>& initializers,
+             const std::map<std::string, at::Tensor>& initializers,
              int64_t onnx_opset_version,
              bool defer_weight_export,
              ::torch::onnx::OperatorExportTypes operator_export_type) {
@@ -255,7 +255,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "_pretty_print_onnx",
          [](const std::shared_ptr<Graph> g,
-            const std::vector<at::Tensor>& initializers,
+            const std::map<std::string, at::Tensor>& initializers,
             int64_t onnx_opset_version,
             bool defer_weight_export,
             ::torch::onnx::OperatorExportTypes operator_export_type,
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 179e054..b6249b2 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -255,10 +255,15 @@ def _model_to_graph(model, args, f, verbose=False, training=False,
             output.inferTypeFrom(tensor)
         _set_input_and_output_names(graph, input_names, output_names)
+
+    input_and_param_names = [val.uniqueName() for val in graph.inputs()]
+    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
+    params_dict = dict(zip(param_names, params))
+
     if verbose:
         print(graph)
 
-    return graph, params, torch_out
+    return graph, params_dict, torch_out
 
 
 def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
@@ -306,18 +311,18 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
     if opset_version is None:
         opset_version = _default_onnx_opset_version
     _set_opset_version(opset_version)
-    graph, params, torch_out = _model_to_graph(model, args, f, verbose,
-                                               training, input_names,
-                                               output_names, operator_export_type,
-                                               example_outputs, propagate,
-                                               _retain_param_name)
+    graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
+                                                    training, input_names,
+                                                    output_names, operator_export_type,
+                                                    example_outputs, propagate,
+                                                    _retain_param_name)
     # TODO: Don't allocate an in-memory string for the protobuf
     defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
     if export_params:
-        proto, export_map = graph._export_onnx(params, opset_version, defer_weight_export,
-                                               operator_export_type)
+        proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export,
+                                               operator_export_type)
     else:
-        proto, export_map = graph._export_onnx([], opset_version, False, operator_export_type)
+        proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
 
     if export_type == ExportTypes.PROTOBUF_FILE:
         assert(len(export_map) == 0)
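In _model_to_graph above, the name-to-tensor mapping is built on the Python side by assuming that the last len(params) graph inputs are the parameters, in order, so zipping their unique names with the tensors yields the dictionary the encoder now expects. A toy trace of that construction (the names and placeholder values are illustrative):

    # Stand-ins for [val.uniqueName() for val in graph.inputs()] and for params.
    input_and_param_names = ['input1', 'fc1.weight', 'fc1.bias']
    params = ['<weight tensor>', '<bias tensor>']

    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
    params_dict = dict(zip(param_names, params))
    assert params_dict == {'fc1.weight': '<weight tensor>', 'fc1.bias': '<bias tensor>'}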
-- 
2.7.4