Refactoring serialization of ONNX initializers to be name-based (Resubmission) (...
authorSpandan Tiwari <sptiwari@microsoft.com>
Fri, 29 Mar 2019 22:17:14 +0000 (15:17 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 22:23:29 +0000 (15:23 -0700)
Summary:
houseroad - this is the resubmission of https://github.com/pytorch/pytorch/pull/17420, as suggested.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17830

Reviewed By: zrphercule

Differential Revision: D14398714

Pulled By: houseroad

fbshipit-source-id: bda475f1ae8a5273ebdb0f6883fc66036c29d326

test/onnx/expect/TestOperators.test_batchnorm.expect
test/onnx/expect/TestOperators.test_batchnorm_1d.expect
test/onnx/expect/TestOperators.test_batchnorm_noaffine.expect
test/onnx/expect/TestOperators.test_batchnorm_training.expect
test/onnx/expect/TestOperators.test_linear.expect
test/onnx/test_pytorch_onnx_caffe2.py
torch/csrc/jit/export.cpp
torch/csrc/jit/export.h
torch/csrc/jit/python_ir.cpp
torch/onnx/utils.py

index 916e87f..e9edb45 100644 (file)
@@ -25,13 +25,12 @@ graph {
   initializer {
     dims: 2
     data_type: 1
-    name: "weight"
-    raw_data: "\340e\355<\246\305\315>"
+    name: "bias"
+    raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
-    dims: 2
-    data_type: 1
-    name: "bias"
+    data_type: 7
+    name: "num_batches_tracked"
     raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
@@ -47,9 +46,10 @@ graph {
     raw_data: "\000\000\200?\000\000\200?"
   }
   initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\000\000\000\000\000\000\000\000"
+    dims: 2
+    data_type: 1
+    name: "weight"
+    raw_data: "\340e\355<\246\305\315>"
   }
   input {
     name: "input"
index 3291043..f3dac32 100644 (file)
@@ -45,13 +45,12 @@ graph {
   initializer {
     dims: 2
     data_type: 1
-    name: "weight"
-    raw_data: "\340e\355<\246\305\315>"
+    name: "bias"
+    raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
-    dims: 2
-    data_type: 1
-    name: "bias"
+    data_type: 7
+    name: "num_batches_tracked"
     raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
@@ -67,9 +66,10 @@ graph {
     raw_data: "\000\000\200?\000\000\200?"
   }
   initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\000\000\000\000\000\000\000\000"
+    dims: 2
+    data_type: 1
+    name: "weight"
+    raw_data: "\340e\355<\246\305\315>"
   }
   input {
     name: "input"
index 6528fca..6e7b9e7 100644 (file)
@@ -49,6 +49,11 @@ graph {
   }
   name: "torch-jit-export"
   initializer {
+    data_type: 7
+    name: "num_batches_tracked"
+    raw_data: "\000\000\000\000\000\000\000\000"
+  }
+  initializer {
     dims: 128
     data_type: 1
     name: "running_mean"
@@ -60,11 +65,6 @@ graph {
     name: "running_var"
     raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
   }
-  initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\000\000\000\000\000\000\000\000"
-  }
   input {
     name: "input"
     type {
index f6a3e2f..e4f0e6e 100644 (file)
@@ -29,16 +29,15 @@ graph {
   initializer {
     dims: 2
     data_type: 1
-    name: "weight"
-    raw_data: "\340e\355<\246\305\315>"
-  }
-  initializer {
-    dims: 2
-    data_type: 1
     name: "bias"
     raw_data: "\000\000\000\000\000\000\000\000"
   }
   initializer {
+    data_type: 7
+    name: "num_batches_tracked"
+    raw_data: "\001\000\000\000\000\000\000\000"
+  }
+  initializer {
     dims: 2
     data_type: 1
     name: "running_mean"
@@ -51,9 +50,10 @@ graph {
     raw_data: "fff?fff?"
   }
   initializer {
-    data_type: 7
-    name: "num_batches_tracked"
-    raw_data: "\001\000\000\000\000\000\000\000"
+    dims: 2
+    data_type: 1
+    name: "weight"
+    raw_data: "\340e\355<\246\305\315>"
   }
   input {
     name: "input"
index ecb3e6d..71a81f5 100644 (file)
@@ -27,16 +27,16 @@ graph {
   name: "torch-jit-export"
   initializer {
     dims: 5
-    dims: 4
     data_type: 1
-    name: "weight"
-    raw_data: "\212\332\356>@\265u>p\303E\275 \320\306\274\354\201\221>\004\354\261\276\2746*>8\247)\276\340\035\224>\024\2446\276\200\211\312<\224\344,>D\356\257>\320\202\226\275\364\213\351>z\226\330\276\310\250\266\275\352F\377\276\000\250)=\244K\021>"
+    name: "bias"
+    raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>"
   }
   initializer {
     dims: 5
+    dims: 4
     data_type: 1
-    name: "bias"
-    raw_data: "\324BO\276@\245T>\350\377\245\275\374u\336\276&\212\304>"
+    name: "weight"
+    raw_data: "\212\332\356>@\265u>p\303E\275 \320\306\274\354\201\221>\004\354\261\276\2746*>8\247)\276\340\035\224>\024\2446\276\200\211\312<\224\344,>D\356\257>\320\202\226\275\364\213\351>z\226\330\276\310\250\266\275\352F\377\276\000\250)=\244K\021>"
   }
   input {
     name: "input"
index ecc3e4f..45d243d 100644 (file)
@@ -11,6 +11,7 @@ import itertools
 
 import torch.onnx
 import torch.onnx.operators
+from torch.onnx import ExportTypes
 from torch import nn
 from torch.autograd import Variable, function
 import torch.utils.model_zoo as model_zoo
@@ -207,6 +208,57 @@ class TestCaffe2Backend(unittest.TestCase):
         input = torch.randn(3, 4, requires_grad=True)
         self.run_model_test(model, train=False, batch_size=0, input=input)
 
+    def test_onnx_export_with_parameter_renaming(self):
+        class SimpleFcNet(nn.Module):
+            def __init__(self):
+                super(SimpleFcNet, self).__init__()
+                self.fc1 = nn.Linear(5, 10)
+
+            def forward(self, input):
+                return self.fc1(input)
+
+        model = SimpleFcNet()
+        input = torch.randn(7, 5)
+        output = model(input)
+
+        f = io.BytesIO()
+        # Note that the export call explicitly sets the names of not just the input,
+        # but also the parameters. This test checks that the model can be loaded and
+        # executed in Caffe2 backend correctly.
+        torch.onnx._export(model, input, f, verbose=True, export_type=ExportTypes.ZIP_ARCHIVE,
+                           input_names=['input1', 'parameter1', 'parameter2'])
+
+        f.seek(0)
+        model_c2 = c2.prepare_zip_archive(f)
+        result = model_c2.run(input.numpy())
+        np.testing.assert_almost_equal(output.data.cpu().numpy(), result[0], decimal=3)
+
+    def test_onnx_export_param_name_duplication(self):
+        class SimpleFcNet(nn.Module):
+            def __init__(self):
+                super(SimpleFcNet, self).__init__()
+                self.fc1 = nn.Linear(5, 10)
+
+            def forward(self, input):
+                return self.fc1(input)
+
+        model = SimpleFcNet()
+        input = torch.randn(7, 5)
+        output = model(input)
+
+        f = io.BytesIO()
+        # The export call explicitly sets the names of the input, and the first parameter.
+        # But note that the target first parameter name is the same as the second parameter name.
+        # This test checks that given this edge condition, the model can be loaded and executed
+        # in Caffe2 backend correctly.
+        torch.onnx._export(model, input, f, verbose=True, export_type=ExportTypes.ZIP_ARCHIVE,
+                           input_names=['input1', 'fc1.bias'], _retain_param_name=False)
+
+        f.seek(0)
+        model_c2 = c2.prepare_zip_archive(f)
+        result = model_c2.run(input.numpy())
+        np.testing.assert_almost_equal(output.data.cpu().numpy(), result[0], decimal=3)
+
     def test_lstm_cell(self):
         model = nn.LSTMCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE)
         input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
index fe96902..4112aa9 100644 (file)
@@ -116,15 +116,23 @@ class EncoderBase {
   }
 
  protected:
+  // Using std::map instead of std::unordered_map for initializers
+  // in EncodeGraph cosntructor so that the order in which initializers
+  // get written to the ONNX graph is always the deterministic and 
+  // predictable. While this is not a ONNX requirement, it is needed 
+  // for testing purposes in tests that use _export_to_pretty_string()
+  // for validating ONNX graphs.
   void EncodeGraph(
       onnx::GraphProto* graph_proto,
       const std::shared_ptr<Graph>& graph,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers = 
+        std::map<std::string, at::Tensor>());
 
   void EncodeBlock(
       onnx::GraphProto* graph_proto,
       const Block* block,
-      const std::vector<at::Tensor>& initializers = {});
+      const std::map<std::string, at::Tensor>& initializers = 
+        std::map<std::string, at::Tensor>());
 
   virtual void EncodeTensor(
       onnx::TensorProto* tensor_proto,
@@ -210,14 +218,14 @@ void EncoderBase::EncodeValueInfo(
 void EncoderBase::EncodeGraph(
     onnx::GraphProto* graph_proto,
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   EncodeBlock(graph_proto, graph->block(), initializers);
 }
 
 void EncoderBase::EncodeBlock(
     onnx::GraphProto* graph_proto,
     const Block* block,
-    const std::vector<at::Tensor>& initializers) {
+    const std::map<std::string, at::Tensor>& initializers) {
   AT_ASSERT(graph_proto != nullptr);
   std::string block_name = "torch-jit-export";
   if (num_blocks_) {
@@ -304,16 +312,11 @@ void EncoderBase::EncodeBlock(
       EncodeBlock(false_g, node->blocks()[1]);
     }
   }
-  auto num_initializers = initializers.size();
-  AT_ASSERT(block->inputs().size() >= num_initializers);
-  size_t inputs_count = block->inputs().size() - num_initializers;
-  for (auto& tensor : initializers) {
-    // TODO: stop using positions to determine which initializers
-    // match to which inputs
-    std::string name = graph_proto->input(inputs_count++).name();
+  AT_ASSERT(block->inputs().size() >= initializers.size());
+  for (auto& name_tensor_pair : initializers) {
     auto p = graph_proto->add_initializer();
-    p->set_name(name);
-    EncodeTensor(p, tensor, name);
+    p->set_name(name_tensor_pair.first);
+    EncodeTensor(p, name_tensor_pair.second, name_tensor_pair.first);
   }
 }
 
@@ -387,7 +390,7 @@ class GraphEncoder : public EncoderBase {
       const std::shared_ptr<Graph>& graph,
       int64_t onnx_opset_version,
       onnx_torch::OperatorExportTypes operator_export_type,
-      const std::vector<at::Tensor>& initializers,
+      const std::map<std::string, at::Tensor>& initializers,
       bool defer_weight_export,
       bool strip_doc);
 
@@ -409,7 +412,7 @@ GraphEncoder::GraphEncoder(
     const std::shared_ptr<Graph>& graph,
     int64_t onnx_opset_version,
     onnx_torch::OperatorExportTypes operator_export_type,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     bool defer_weight_export,
     bool strip_doc)
     : EncoderBase(operator_export_type, strip_doc),
@@ -959,7 +962,7 @@ std::string prettyPrint(const onnx::ModelProto& model) {
 
 std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type,
@@ -984,7 +987,7 @@ std::string pretty_print_onnx(
 // libtorch will be able to import the IR and play it back.
 std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type) {
index ae49245..1904723 100644 (file)
@@ -21,7 +21,7 @@ using RawDataExportMap = std::unordered_map<std::string, at::Tensor>;
 
 TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export = false,
     ::torch::onnx::OperatorExportTypes operator_export_type =
@@ -30,7 +30,7 @@ TORCH_API std::tuple<std::string, RawDataExportMap> export_onnx(
 // For testing purposes
 TORCH_API std::string pretty_print_onnx(
     const std::shared_ptr<Graph>& graph,
-    const std::vector<at::Tensor>& initializers,
+    const std::map<std::string, at::Tensor>& initializers,
     int64_t onnx_opset_version,
     bool defer_weight_export,
     ::torch::onnx::OperatorExportTypes operator_export_type =
index 95c06fe..a6a900b 100644 (file)
@@ -221,7 +221,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "_export_onnx",
           [](const std::shared_ptr<Graph> g,
-             const std::vector<at::Tensor>& initializers,
+             const std::map<std::string, at::Tensor>& initializers,
              int64_t onnx_opset_version,
              bool defer_weight_export,
              ::torch::onnx::OperatorExportTypes operator_export_type) {
@@ -255,7 +255,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "_pretty_print_onnx",
           [](const std::shared_ptr<Graph> g,
-             const std::vector<at::Tensor>& initializers,
+             const std::map<std::string, at::Tensor>& initializers,
              int64_t onnx_opset_version,
              bool defer_weight_export,
              ::torch::onnx::OperatorExportTypes operator_export_type,
index 179e054..b6249b2 100644 (file)
@@ -255,10 +255,15 @@ def _model_to_graph(model, args, f, verbose=False, training=False,
             output.inferTypeFrom(tensor)
 
     _set_input_and_output_names(graph, input_names, output_names)
+
+    input_and_param_names = [val.uniqueName() for val in graph.inputs()]
+    param_names = input_and_param_names[len(input_and_param_names) - len(params):]
+    params_dict = dict(zip(param_names, params))
+
     if verbose:
         print(graph)
 
-    return graph, params, torch_out
+    return graph, params_dict, torch_out
 
 
 def export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=False,
@@ -306,18 +311,18 @@ def _export(model, args, f, export_params=True, verbose=False, training=False,
     if opset_version is None:
         opset_version = _default_onnx_opset_version
     _set_opset_version(opset_version)
-    graph, params, torch_out = _model_to_graph(model, args, f, verbose,
-                                               training, input_names,
-                                               output_names, operator_export_type,
-                                               example_outputs, propagate,
-                                               _retain_param_name)
+    graph, params_dict, torch_out = _model_to_graph(model, args, f, verbose,
+                                                    training, input_names,
+                                                    output_names, operator_export_type,
+                                                    example_outputs, propagate,
+                                                    _retain_param_name)
 
     # TODO: Don't allocate a in-memory string for the protobuf
     defer_weight_export = export_type is not ExportTypes.PROTOBUF_FILE
     if export_params:
-        proto, export_map = graph._export_onnx(params, opset_version, defer_weight_export, operator_export_type)
+        proto, export_map = graph._export_onnx(params_dict, opset_version, defer_weight_export, operator_export_type)
     else:
-        proto, export_map = graph._export_onnx([], opset_version, False, operator_export_type)
+        proto, export_map = graph._export_onnx({}, opset_version, False, operator_export_type)
 
     if export_type == ExportTypes.PROTOBUF_FILE:
         assert(len(export_map) == 0)