From e4c9d75008f5a8d5cd44c6b4e24d1e90bb81c598 Mon Sep 17 00:00:00 2001
From: Spandan Tiwari
Date: Thu, 7 Mar 2019 10:06:17 -0800
Subject: [PATCH] - refactoring serialization of ONNX initializers to be
 name-based (#17420)

Summary:
Currently, serialization of model parameters in ONNX export depends on
the order in which they are stored in a container (`list` on Python side
and `std::vector` on C++ side). This has worked fine till now, but if we
need to do any pass on that graph that mutates the parameter list, then
strictly order-based serialization may not work.

This PR is the first in a set to bring in more passes (such as constant
folding) related to ONNX export. This PR lays the groundwork by moving
the serialization in ONNX export from an order-based to a name-based
approach, which is more amenable to some of the passes.

houseroad - As discussed this change uses a map for export, and removes
the code from `export.cpp` that relies on the order to compute
initializer names.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17420

Differential Revision: D14361993

Pulled By: houseroad

fbshipit-source-id: da93e945d55755c126de06641f35df87d1648cc4
---
 torch/csrc/jit/export.cpp    | 37 ++++++++++++++++++++-----------------
 torch/csrc/jit/export.h      |  4 ++--
 torch/csrc/jit/python_ir.cpp |  5 +++--
 torch/onnx/utils.py          | 18 +++++++++++-------
 4 files changed, 36 insertions(+), 28 deletions(-)

diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index 4974929..39417ba 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -115,15 +115,23 @@ class EncoderBase {
   }

  protected:
+  // Using std::map instead of std::unordered_map for initializers
+  // in EncodeGraph constructor so that the order in which initializers
+  // get written to the ONNX graph is always deterministic and
+  // predictable. While this is not an ONNX requirement, it is needed
+  // for testing purposes in tests that use _export_to_pretty_string()
+  // for validating ONNX graphs.
   void EncodeGraph(
       onnx::GraphProto* graph_proto,
       const std::shared_ptr<Graph>& graph,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers =
+          std::map<std::string, at::Tensor>());

   void EncodeBlock(
       onnx::GraphProto* graph_proto,
       const Block* block,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers =
+          std::map<std::string, at::Tensor>());

   virtual void EncodeTensor(
       onnx::TensorProto* tensor_proto,
@@ -209,14 +217,14 @@ void EncoderBase::EncodeValueInfo(
 void EncoderBase::EncodeGraph(
     onnx::GraphProto* graph_proto,
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   EncodeBlock(graph_proto, graph->block(), initializers);
 }

 void EncoderBase::EncodeBlock(
     onnx::GraphProto* graph_proto,
     const Block* block,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   AT_ASSERT(graph_proto != nullptr);
   std::string block_name = "torch-jit-export";
   if (num_blocks_) {
@@ -303,16 +311,11 @@ void EncoderBase::EncodeBlock(
       EncodeBlock(false_g, node->blocks()[1]);
     }
   }
-  auto num_initializers = initializers.size();
-  AT_ASSERT(block->inputs().size() >= num_initializers);
-  size_t inputs_count = block->inputs().size() - num_initializers;
-  for (auto& tensor : initializers) {
-    // TODO: stop using positions to determine which initializers
-    // match to which inputs
-    std::string name = graph_proto->input(inputs_count++).name();
+  AT_ASSERT(block->inputs().size() >= initializers.size());
+  for (auto& name_tensor_pair : initializers) {
     auto p = graph_proto->add_initializer();
-    p->set_name(name);
-    EncodeTensor(p, tensor, name);
+    p->set_name(name_tensor_pair.first);
+    EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
   }
 }

@@ -386,7 +389,7 @@ class GraphEncoder : public EncoderBase {
       const std::shared_ptr<Graph>& graph,
       int64_t onnx_opset_version,
       onnx_torch::OperatorExportTypes operator_export_type,
-      const std::vector<at::Tensor>& initializers,
+      const std::map<std::string, at::Tensor>& initializers,
       bool defer_weight_export,
       bool strip_doc);

@@ -408,7 +411,7 @@ GraphEncoder::GraphEncoder(
     const std::shared_ptr<Graph>& graph,
     int64_t onnx_opset_version,
     onnx_torch::OperatorExportTypes operator_export_type,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     bool defer_weight_export,
     bool strip_doc)
     : EncoderBase(operator_export_type, strip_doc),
@@ -858,7 +861,7 @@ std::string prettyPrint(const onnx::ModelProto& model) {

 std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type,
@@ -883,7 +886,7 @@ std::string pretty_print_onnx(
 // libtorch will be able to import the IR and play it back.
 std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type) {
diff --git a/torch/csrc/jit/export.h b/torch/csrc/jit/export.h
index ae49245..1904723 100644
--- a/torch/csrc/jit/export.h
+++ b/torch/csrc/jit/export.h
@@ -21,7 +21,7 @@ using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;

 TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export = false,
     ::torch::onnx::OperatorExportTypes operator_export_type =
@@ -30,7 +30,7 @@ TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
 // For testing purposes
 TORCH_API std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type =
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 04e60a3..1b204aa 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -14,6 +14,7 @@

 #include <iostream>
 #include <sstream>
+#include <map>

 namespace torch {
 namespace jit {
@@ -221,7 +222,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "_export_onnx",
           [](const std::shared_ptr<Graph> g,
-             const std::vector<at::Tensor>& initializers,
+             const std::map<std::string, at::Tensor>& initializers,
              int64_t onnx_opset_version,
              bool defer_weight_export,
              ::torch::onnx::OperatorExportTypes operator_export_type) {
@@ -255,7 +256,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "_pretty_print_onnx",
           [](const std::shared_ptr<Graph> g,
-             const std::vector<at::Tensor>& initializers,
+             const std::map<std::string, at::Tensor>& initializers,
              int64_t onnx_opset_version,
              bool defer_weight_export,
              ::torch::onnx::OperatorExportTypes operator_export_type,
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 7641aa0..0dea267 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -233,6 +233,10 @@ def _model_to_graph(model, args, f, verbose=False, training=False,
     graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
     params = list(_unique_state_dict(model).values())

+    input_and_param_names = [val.uniqueName() for val in graph.inputs()]
+    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
+    params_dict = dict(zip(param_names, params))
+
     graph = _optimize_graph(graph, operator_export_type)

     # NB: ONNX requires complete information about output types, which might be
@@ -246,7 +250,7 @@ def _model_to_graph(model, args, f, verbose=False, training=False,
     if verbose:
         print(graph)

-    return graph, params, torch_out
+    return graph, params_dict, torch_out


 def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
@@ -294,17 +298,17 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
     if opset_version is None:
         opset_version = _default_onnx_opset_version
     _set_opset_version(opset_version)
-    graph, params, torch_out = _model_to_graph(model, args, f, verbose,
-                                               training, input_names,
-                                               output_names, operator_export_type,
-                                               example_outputs, propagate)
+    graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
+                                                    training, input_names,
+                                                    output_names, operator_export_type,
+                                                    example_outputs, propagate)

     # TODO: Don't allocate a in-memory string for the protobuf
     defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
     if export_params:
-        proto, export_map = graph._export_onnx(params, opset_version, defer_weight_export, operator_export_type)
+        proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
     else:
-        proto, export_map = graph._export_onnx([], opset_version, False, operator_export_type)
+        proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)

     if export_type == ExportTypes.PROTOBUF_FILE:
         assert(len(export_map) == 0)
-- 
2.7.4