From 6512838fabcbbf2854597790c62599af1e7f74ed Mon Sep 17 00:00:00 2001 From: BowenBao Date: Wed, 15 Sep 2021 12:56:33 -0700 Subject: [PATCH] [ONNX] Enhance shape (two changes merged) (#64585) Summary: Enhanced shape inference by introducing typeReliableMap. [ONNX] exporter changes for torch hub models (https://github.com/pytorch/pytorch/issues/62856) Pull Request resolved: https://github.com/pytorch/pytorch/pull/64585 Reviewed By: ezyang Differential Revision: D30870418 Pulled By: msaroufim fbshipit-source-id: 87a294799cb87d649d1d13b6114a5cfbac9be15c Co-authored-by: jiafatom --- aten/src/ATen/core/interned_strings.h | 4 + docs/source/onnx.rst | 8 +- .../TestOperators.test_aten_embedding_1.expect | 36 ++ .../TestOperators.test_aten_embedding_2.expect | 155 +++++++ test/onnx/expect/TestOperators.test_c2_op.expect | 6 +- .../TestOperators.test_dynamic_axes_add.expect | 64 +++ .../TestOperators.test_dynamic_axes_matmul.expect | 73 ++++ ...tOperators.test_dynamic_axes_reduce_mean.expect | 60 +++ ...TestOperators.test_dynamic_axes_unchange.expect | 76 ++++ ...stOperators.test_lstm_none_sequence_lens.expect | 44 ++ test/onnx/test_operators.py | 129 +++++- test/onnx/test_pytorch_onnx_onnxruntime.py | 9 +- torch/csrc/jit/passes/onnx.cpp | 17 +- torch/csrc/jit/passes/onnx/constant_fold.cpp | 6 + torch/csrc/jit/passes/onnx/constant_map.cpp | 62 +++ torch/csrc/jit/passes/onnx/constant_map.h | 15 + .../csrc/jit/passes/onnx/shape_type_inference.cpp | 452 ++++++++++++++++++++- torch/csrc/jit/passes/onnx/shape_type_inference.h | 12 +- torch/onnx/__init__.py | 5 + torch/onnx/symbolic_opset11.py | 17 +- torch/onnx/symbolic_registry.py | 9 + torch/onnx/utils.py | 32 +- 22 files changed, 1245 insertions(+), 46 deletions(-) create mode 100644 test/onnx/expect/TestOperators.test_aten_embedding_1.expect create mode 100644 test/onnx/expect/TestOperators.test_aten_embedding_2.expect create mode 100644 test/onnx/expect/TestOperators.test_dynamic_axes_add.expect create mode 100644 test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect create mode 100644 test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect create mode 100644 test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect create mode 100644 test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 0b12603..7ed3cf8 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -394,12 +394,14 @@ namespace c10 { _(onnx, Gather) \ _(onnx, Gemm) \ _(onnx, LSTM) \ + _(onnx, MatMul) \ _(onnx, Mul) \ _(onnx, Pow) \ _(onnx, RNN) \ _(onnx, Shape) \ _(onnx, Size) \ _(onnx, Slice) \ + _(onnx, Softmax) \ _(onnx, Squeeze) \ _(onnx, Sub) \ _(onnx, Transpose) \ @@ -435,7 +437,9 @@ namespace c10 { _(onnx, ReduceL2) \ _(onnx, Conv) \ _(onnx, BatchNormalization) \ + _(onnx, ReduceMean) \ _(onnx, ReduceProd) \ + _(onnx, Relu) \ _(onnx, Neg) \ _(onnx, NonZero) \ _(onnx, Range) \ diff --git a/docs/source/onnx.rst b/docs/source/onnx.rst index eb6c2c0..8beeb55 100644 --- a/docs/source/onnx.rst +++ b/docs/source/onnx.rst @@ -396,10 +396,16 @@ All autograd ``Function``s are emitted in the TorchScript graph as ``prim::Pytho In order to differentiate between different ``Function`` subclasses, the symbolic function should use the ``name`` kwarg which gets set to the name of the class. 
-:func:`register_custom_op_symbolic` does does not allow registration for ops in
+:func:`register_custom_op_symbolic` does not allow registration for ops in
 the ``prim`` namespace, so for this use case, there's a back door: register
 the symbolic for ``"::prim_PythonOp"``.
 
+Please also consider adding shape inference logic when you register a custom symbolic function
+via the ``setType`` API. This helps the exporter obtain correct shape inference.
+An example of ``setType`` is ``test_aten_embedding_2`` in ``test_operators.py``.
+Adding shape inference logic is not required, but the exporter emits a warning
+if it is missing.
+
 The example below shows how you can access ``requires_grad`` via the ``Node`` object::
 
     class MyClip(torch.autograd.Function):
diff --git a/test/onnx/expect/TestOperators.test_aten_embedding_1.expect b/test/onnx/expect/TestOperators.test_aten_embedding_1.expect
new file mode 100644
index 0000000..317fa3a
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_aten_embedding_1.expect
@@ -0,0 +1,36 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    output: "3"
+    name: "Constant_0"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        dims: 32
+        data_type: 1
+        raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
+      }
+      type: TENSOR
+    }
+  }
+  name: "torch-jit-export"
+  output {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 32
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
diff --git a/test/onnx/expect/TestOperators.test_aten_embedding_2.expect b/test/onnx/expect/TestOperators.test_aten_embedding_2.expect
new file mode 100644
index 0000000..a5ab9d3
--- /dev/null
+++ b/test/onnx/expect/TestOperators.test_aten_embedding_2.expect
@@ -0,0 +1,155 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    input: "emb.weight"
+    input: "input_1"
+    output: "3"
+    name: "ATenOp_0"
+    op_type: "ATenOp"
+    attribute {
+      name: "custom_attributes_json"
+      s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
+      type: STRING
+    }
+    attribute {
+      name: "name"
+      s: "aten::embedding"
+      type: STRING
+    }
+    domain: "com.microsoft"
+  }
+  node {
+    input: "3"
+    input: "input_2"
+    output: "4"
+    name: "Add_1"
+    op_type: "Add"
+  }
+  node {
+    input: "4"
+    output: "5"
+    name: "Shape_2"
+    op_type: "Shape"
+  }
+  node {
+    output: "6"
+    name: "Constant_3"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: 7
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "5"
+    input: "6"
+    output: "7"
+    name: "Gather_4"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "7"
+    output: "8"
+    name: "Unsqueeze_5"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "8"
+    output: "9"
+    name: "Concat_6"
+    op_type: "Concat"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "9"
+    output: "10"
+    name: "ConstantOfShape_7"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        dims: 1
+        data_type: 1
+        raw_data: 
"\000\000\200?" + } + type: TENSOR + } + } + name: "torch-jit-export" + initializer { + dims: 4 + dims: 8 + data_type: 1 + name: "emb.weight" + raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?" + } + input { + name: "input_1" + type { + tensor_type { + elem_type: 7 + shape { + dim { + dim_param: "input_1_dim_0" + } + } + } + } + } + input { + name: "input_2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "input_2_dim_0" + } + dim { + dim_param: "input_2_dim_1" + } + } + } + } + } + output { + name: "10" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_param: "ConstantOfShape10_dim_0" + } + } + } + } + } +} +opset_import { + version: 12 +} +opset_import { + domain: "com.microsoft" + version: 1 +} diff --git a/test/onnx/expect/TestOperators.test_c2_op.expect b/test/onnx/expect/TestOperators.test_c2_op.expect index fc191b5..e9aae2c 100644 --- a/test/onnx/expect/TestOperators.test_c2_op.expect +++ b/test/onnx/expect/TestOperators.test_c2_op.expect @@ -147,10 +147,10 @@ graph { elem_type: 1 shape { dim { - dim_value: 0 + dim_param: "GenerateProposals4_dim_0" } dim { - dim_value: 5 + dim_param: "GenerateProposals4_dim_1" } } } @@ -163,7 +163,7 @@ graph { elem_type: 1 shape { dim { - dim_value: 0 + dim_param: "GenerateProposals5_dim_0" } } } diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_add.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_add.expect new file mode 100644 index 0000000..83e6e74 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_dynamic_axes_add.expect @@ -0,0 +1,64 @@ +ir_version: 6 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + input: "input_1" + input: "input_2" + output: "2" + name: "Add_0" + op_type: "Add" + } + name: "torch-jit-export" + input { + name: "input_1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "input_1_dim_1" + } + } + } + } + } + input { + name: "input_2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "input_2_dim_1" + } + } + } + } + } + output { + name: "2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "Add2_dim_1" + } + } + } + } + } +} +opset_import { + version: 12 +} diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect new file mode 100644 index 0000000..038d3dd --- /dev/null +++ b/test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect @@ -0,0 +1,73 @@ +ir_version: 6 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + input: "input_1" + input: "input_2" + output: "2" + name: "MatMul_0" + op_type: "MatMul" + } + name: "torch-jit-export" + input { + name: "input_1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "input_1_dim_1" + } + dim { + dim_value: 4 + } + } + } + } + } + input { + name: "input_2" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_value: 4 + } + dim { + dim_param: "input_2_dim_2" + } + } + } + } + } + output { + name: "2" + type { + tensor_type { + elem_type: 1 + shape 
{ + dim { + dim_value: 2 + } + dim { + dim_param: "input_1_dim_1" + } + dim { + dim_param: "input_2_dim_2" + } + } + } + } + } +} +opset_import { + version: 12 +} diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect new file mode 100644 index 0000000..24de171 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect @@ -0,0 +1,60 @@ +ir_version: 6 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + input: "input" + output: "1" + name: "ReduceMean_0" + op_type: "ReduceMean" + attribute { + name: "axes" + ints: 1 + type: INTS + } + attribute { + name: "keepdims" + i: 0 + type: INT + } + } + name: "torch-jit-export" + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "input_dim_1" + } + dim { + dim_param: "input_dim_2" + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "input_dim_2" + } + } + } + } + } +} +opset_import { + version: 12 +} diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect new file mode 100644 index 0000000..e304a96 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect @@ -0,0 +1,76 @@ +ir_version: 6 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + input: "input" + output: "1" + name: "Transpose_0" + op_type: "Transpose" + attribute { + name: "perm" + ints: 1 + ints: 0 + type: INTS + } + } + node { + input: "1" + output: "2" + name: "Softmax_1" + op_type: "Softmax" + attribute { + name: "axis" + i: 1 + type: INT + } + } + node { + input: "2" + output: "3" + name: "Transpose_2" + op_type: "Transpose" + attribute { + name: "perm" + ints: 1 + ints: 0 + type: INTS + } + } + name: "torch-jit-export" + input { + name: "input" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "input_dim_1" + } + } + } + } + } + output { + name: "3" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 2 + } + dim { + dim_param: "input_dim_1" + } + } + } + } + } +} +opset_import { + version: 12 +} diff --git a/test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect b/test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect new file mode 100644 index 0000000..b544aa6 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect @@ -0,0 +1,44 @@ +ir_version: 6 +producer_name: "pytorch" +producer_version: "CURRENT_VERSION" +graph { + node { + output: "7" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + dims: 1 + dims: 2 + dims: 3 + data_type: 1 + raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?" 
+ } + type: TENSOR + } + } + name: "torch-jit-export" + output { + name: "7" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + } + } + } + } +} +opset_import { + version: 12 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index b9e391b..14d4794 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -1,8 +1,11 @@ -from test_pytorch_common import TestCase, run_tests, flatten, skipIfNoLapack +from test_pytorch_common import TestCase, run_tests, flatten, skipIfNoLapack, \ + BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE import torch import torch.onnx +from torch.onnx.symbolic_helper import parse_args, _get_tensor_dim_size, _get_tensor_sizes +from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic from torch.autograd import Variable, Function from torch.nn import Module, functional import torch.nn as nn @@ -907,6 +910,130 @@ class TestOperators(TestCase): y = torch.empty(3, 2, 1, dtype=torch.long).random_(5) self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12) + def test_lstm_none_sequence_lens(self): + """Test symbolic shape inference for LSTM when the input sequence_lens = None.""" + input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE) + h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) + c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE) + + class LSTMModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.rnn = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False) + + def forward(self, x, h0, c0): + a, b = self.rnn(x, (h0, c0)) + return torch.ones(b[0].shape) + + self.assertONNX(LSTMModel(), + (input, h0, c0), input_names=["x", "y"], + dynamic_axes={"x" : {0: 'batch'}}, opset_version=12) + + def test_dynamic_axes_add(self): + m1 = torch.randn(2, 3, requires_grad=True) + m2 = torch.randn(2, 1, requires_grad=True) + self.assertONNX(lambda x, y: torch.add(x, y), (m1, m2), input_names=["input_1", "input_2"], + dynamic_axes={"input_1": {1: "dim_1"}, "input_2": {1: "dim_2"}}, + opset_version=12) + + def test_dynamic_axes_matmul(self): + m1 = torch.randn(2, 2, 4, requires_grad=True) + m2 = torch.randn(2, 4, 3, requires_grad=True) + self.assertONNX(lambda x, y: torch.matmul(x, y), (m1, m2), input_names=["input_1", "input_2"], + dynamic_axes={"input_1": {1: "dim_0"}, "input_2": {2: "dim_1"}}, + opset_version=12) + + def test_dynamic_axes_reduce_mean(self): + m1 = torch.randn(2, 3, 4, requires_grad=True) + self.assertONNX(lambda x: torch.mean(x, dim=1), (m1), input_names=["input"], + dynamic_axes={"input": {1: "dim_1", 2: "dim_2"}}, + opset_version=12) + + def test_dynamic_axes_unchange(self): + """Test ProcessUnchangeNode in symbolic shape inference.""" + m1 = torch.randn(2, 3, requires_grad=True) + self.assertONNX(lambda x: torch.softmax(x, dim=0), (m1,), input_names=["input"], + dynamic_axes={"input": {1: "dim_1"}}, + opset_version=12) + + def test_aten_embedding_1(self): + _onnx_opset_version = 12 + + @parse_args('v', 'v', 'i', 'b', 'b') + def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): + custom_attributes_json = ( + '{' + f'"padding_idx":{str(padding_idx)},' + f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},' + f'"sparse":{str(sparse).lower()}' + '}' + ) + output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding', + custom_attributes_json_s=custom_attributes_json) + return output + + 
register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(4, 8) + + def forward(self, x, y): + res = self.emb(x) + res = res + y + return torch.ones(res.shape[0]) + + model = Model() + x = torch.ones(32, dtype=torch.long) + y = torch.randn(1, 8) + self.assertONNX(model, (x, y), opset_version=_onnx_opset_version) + + unregister_custom_op_symbolic('::embedding', _onnx_opset_version) + + # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding. + def test_aten_embedding_2(self): + _onnx_opset_version = 12 + + @parse_args('v', 'v', 'i', 'b', 'b') + def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): + custom_attributes_json = ( + '{' + f'"padding_idx":{str(padding_idx)},' + f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},' + f'"sparse":{str(sparse).lower()}' + '}' + ) + output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding', + custom_attributes_json_s=custom_attributes_json) + + # do shape inference and set it via setType + indices_shape = _get_tensor_sizes(indices) + if indices_shape is not None and hasattr(weight.type(), 'with_sizes'): + output_type = weight.type().with_sizes(indices_shape + [_get_tensor_dim_size(weight, 1)]) + output.setType(output_type) + return output + + register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.emb = torch.nn.Embedding(4, 8) + + def forward(self, x, y): + res = self.emb(x) + res = res + y + return torch.ones(res.shape[0]) + + model = Model() + x = torch.ones(32, dtype=torch.long) + y = torch.randn(1, 8) + self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'], + dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}}) + + unregister_custom_op_symbolic('::embedding', _onnx_opset_version) + if __name__ == "__main__": no_onnx_dep_flag = "--no-onnx" _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 54a116b..60eac0d 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -47,7 +47,7 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None, do_constant_folding=True, keep_initializers_as_inputs=True, dynamic_axes=None, input_names=None, output_names=None, fixed_batch_size=False, training=None, - onnx_shape_inference=False): + onnx_shape_inference=True): # export the model to ONNX f = io.BytesIO() input_copy = copy.deepcopy(input) @@ -8582,8 +8582,13 @@ class TestONNXRuntime(unittest.TestCase): random_state = torch.rand((1, 1, 10, 30, 30)) self.run_test(model, (random_data, empty_tensor), input_names=["data", "state"], - dynamic_axes={"state": [0, 1, 2, 3, 4]}, + dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]}, test_with_inputs=[(random_data, random_state)]) + self.run_test(model, (random_data, empty_tensor), + input_names=["data", "state"], + dynamic_axes={"state": [0, 1, 2, 3, 4]}, + test_with_inputs=[(random_data, random_state)], + remained_onnx_input_idx=[1]) self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[]) @skipIfUnsupportedMinOpsetVersion(11) diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 0697f89..03e894c 100644 --- 
a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -219,6 +219,14 @@ std::unordered_map BlockToONNX( return {}; } +bool ConstantFoldCondition(torch::jit::Value* output) { + auto fold_condition = output->node()->kind() != c10::onnx::Constant && + ConstantValueMap::HasValue(output->debugName()); + auto reliable_value = + ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false); + return fold_condition && reliable_value; +} + void NodeToONNX( Node* old_node, Block* new_block, @@ -267,8 +275,7 @@ void NodeToONNX( // // If onnx shape inference is turned on, the new outputs will have // types inferred, and they will be merged with the old types. - if (outputs[i]->node()->kind() != c10::onnx::Constant && - ConstantValueMap::HasValue(outputs[i]->debugName())) { + if (ConstantFoldCondition(outputs[i])) { // Create a const node if the node output value is in // ConstantValueMap. auto value = @@ -286,8 +293,10 @@ void NodeToONNX( ONNXShapeTypeInference(const_node, empty_params_dict, opset_version); env[old] = const_node->output(); } else { - outputs[i]->setType( - MergeInferredType(old->type(), outputs[i]->type())); + // ConstantValueMap has been set in shape inference, + // set_constant_value_map = false here to avoid redundancy. + MergeInferredTypeAndSetMap( + outputs[i], old->type(), outputs[i]->type(), false); // Copy over source location and scope information to all nodes // created by the symbolic diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 76c0674..cce5a43 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -489,6 +489,12 @@ c10::optional runTorchBackendForOnnx( } else if (node->kind() == onnx::Equal) { updated_val = at::eq(inputTensorValues[0], inputTensorValues[1]); return c10::optional(updated_val); + } else if (node->kind() == onnx::Greater) { + updated_val = at::greater(inputTensorValues[0], inputTensorValues[1]); + return c10::optional(updated_val); + } else if (node->kind() == onnx::Less) { + updated_val = at::less(inputTensorValues[0], inputTensorValues[1]); + return c10::optional(updated_val); } else if (node->kind() == onnx::Neg) { updated_val = at::neg(inputTensorValues[0]); return c10::optional(updated_val); diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 8cbec27..69fe70a 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -24,6 +24,7 @@ void ConstantValueMap::SetRank( const std::string& tensorName, size_t rankValue) { ConstantValueMap::getInstance().rankMap.emplace(tensorName, rankValue); + ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, true); } bool ConstantValueMap::HasRank(const std::string& tensorName) { @@ -42,6 +43,7 @@ void ConstantValueMap::SetShape( const std::string& tensorName, const c10::SymbolicShape& shapeValue) { ConstantValueMap::getInstance().shapeMap.emplace(tensorName, shapeValue); + ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, true); } bool ConstantValueMap::HasShape(const std::string& tensorName) { @@ -146,10 +148,50 @@ std::vector ConstantValueMap::GetValueInto1DInt64Vector( return value_vector; } +void ConstantValueMap::SetTypeReliable( + const std::string& tensorName, + bool value) { + ConstantValueMap::getInstance().typeReliableMap.emplace(tensorName, value); +} + +bool ConstantValueMap::HasTypeReliable(const std::string& 
tensorName) { + return ConstantValueMap::getInstance().typeReliableMap.find(tensorName) != + ConstantValueMap::getInstance().typeReliableMap.end(); +} + +c10::optional ConstantValueMap::GetTypeReliable( + const std::string& tensorName) { + if (!HasTypeReliable(tensorName)) { + return c10::nullopt; + } + return ConstantValueMap::getInstance().typeReliableMap[tensorName]; +} + +void ConstantValueMap::SetUseInferredType( + const std::string& tensorName, + bool value) { + ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, value); +} + +bool ConstantValueMap::HasUseInferredType(const std::string& tensorName) { + return ConstantValueMap::getInstance().useInferredTypeMap.find(tensorName) != + ConstantValueMap::getInstance().useInferredTypeMap.end(); +} + +c10::optional ConstantValueMap::GetUseInferredType( + const std::string& tensorName) { + if (!HasUseInferredType(tensorName)) { + return c10::nullopt; + } + return ConstantValueMap::getInstance().useInferredTypeMap[tensorName]; +} + void ConstantValueMap::ClearMaps() { ConstantValueMap::getInstance().rankMap.clear(); ConstantValueMap::getInstance().shapeMap.clear(); ConstantValueMap::getInstance().tensorValueMap.clear(); + ConstantValueMap::getInstance().typeReliableMap.clear(); + ConstantValueMap::getInstance().useInferredTypeMap.clear(); } // For debug only. @@ -179,6 +221,26 @@ void ConstantValueMap::PrintMaps() { for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) { std::cout << "node " << x.first << ": " << x.second << std::endl; } + std::cout << std::endl; + std::cout << "Print TypeReliable Maps:" << std::endl; + size_t count = 0; + for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) { + std::cout << "(node " << x.first << ": " << x.second << "), "; + count++; + if (count % 10 == 0) { + std::cout << std::endl; + } + } + std::cout << std::endl; + std::cout << "Print UseInferredType Maps:" << std::endl; + count = 0; + for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) { + std::cout << "(node " << x.first << ": " << x.second << "), "; + count++; + if (count % 10 == 0) { + std::cout << std::endl; + } + } } } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/constant_map.h b/torch/csrc/jit/passes/onnx/constant_map.h index 97fa140..ab71557 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.h +++ b/torch/csrc/jit/passes/onnx/constant_map.h @@ -35,6 +35,16 @@ class ConstantValueMap { static std::vector GetValueInto1DInt64Vector( const std::string& value_name); + static void SetTypeReliable(const std::string& tensorName, bool reliable); + static bool HasTypeReliable(const std::string& tensorName); + static c10::optional GetTypeReliable(const std::string& tensorName); + + static void SetUseInferredType( + const std::string& tensorName, + bool useInferredType); + static bool HasUseInferredType(const std::string& tensorName); + static c10::optional GetUseInferredType(const std::string& tensorName); + static void PrintMaps(); static void ClearMaps(); ~ConstantValueMap() = default; @@ -47,6 +57,11 @@ class ConstantValueMap { std::unordered_map rankMap; std::unordered_map shapeMap; std::unordered_map tensorValueMap; + // This map indicates whether the current type is reliably estimated or not. + std::unordered_map typeReliableMap; + // This map indicates whether the current type is estimated through inference + // or tracer. 
+ std::unordered_map useInferredTypeMap; }; } // namespace jit diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 8ade722..5760c48 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -36,10 +37,13 @@ namespace jit { // 3. existing type: Scalar[], inferred type: Tensor // ONNX represents list of scalars by 1-d Tensor. Return inferred type since // it is more compatible with ONNX. -TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) { +std::pair MergeInferredType( + TypePtr existing_type, + TypePtr inferred_type) { auto new_list_type = inferred_type->cast(); + auto use_inferred_type = false; if (new_list_type) { - return inferred_type; + return std::make_pair(inferred_type, true); } auto new_tensor_type = inferred_type->cast(); auto old_tensor_type = existing_type->cast(); @@ -47,32 +51,49 @@ TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) { if (new_tensor_type && old_tensor_type) { if (!old_tensor_type->device()) { // device not available means this is an invalid tensor type (most likely - // an empty one) -> return inferred type directly. - return new_tensor_type; + // an empty one) return inferred type directly. + return std::make_pair(new_tensor_type, true); } auto type = old_tensor_type; if (new_tensor_type->dim()) { type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes()); + use_inferred_type = true; } if (new_tensor_type->scalarType().has_value()) { type = type->withScalarType(new_tensor_type->scalarType()); + use_inferred_type = true; } - return type; + return std::make_pair(type, use_inferred_type); } if (old_tensor_type) { - return existing_type; + return std::make_pair(existing_type, false); } auto old_list_type = existing_type->cast(); if (new_tensor_type && old_list_type) { if (new_tensor_type->sizes().isComplete()) { - return inferred_type; + return std::make_pair(inferred_type, true); } - return existing_type; + return std::make_pair(existing_type, false); } - return inferred_type; + return std::make_pair(inferred_type, true); +} + +void MergeInferredTypeAndSetMap( + Value* dest_v, + TypePtr existing_type, + TypePtr inferred_type, + bool set_constant_value_map) { + TypePtr mergedType; + bool inferred; + std::tie(mergedType, inferred) = + MergeInferredType(existing_type, inferred_type); + dest_v->setType(mergedType); + if (set_constant_value_map) { + ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred); + } } namespace { @@ -123,6 +144,7 @@ TensorTypePtr TorchTensorTypeFromONNX( // Assign a new Symbol, no need to keep track // of it because there won't be duplicates. 
sym = c10::ShapeSymbol::newSymbol(); + symbol_map[sym.value()] = ""; } sizes.emplace_back(sym.value()); } @@ -170,13 +192,13 @@ void UpdateTorchValueByOnnxValueInfo( const auto torch_tensor_type = TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_map); if (torch_tensor_type) { - v->setType(MergeInferredType(v->type(), torch_tensor_type)); + MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type); } } else if (p_type.has_sequence_type()) { const auto torch_list_type = TorchListTypeFromONNX(p_type.sequence_type(), symbol_map); if (torch_list_type) { - v->setType(MergeInferredType(v->type(), torch_list_type)); + MergeInferredTypeAndSetMap(v, v->type(), torch_list_type); } } } @@ -596,6 +618,16 @@ void UpdateShape(Value* value, const ::c10::SymbolicShape& shape) { } } +void UpdateShapeConstantValueMap( + const Value* value, + const ::c10::SymbolicShape& shape) { + ConstantValueMap::SetShape(value->debugName(), shape); + if (shape.rank().has_value()) { + auto rank = shape.rank().value(); + ConstantValueMap::SetRank(value->debugName(), rank); + } +} + c10::optional> GetValueFromListConstructNode( Node* lc_node) { auto rank = lc_node->inputs().size(); @@ -618,6 +650,198 @@ c10::optional> GetValueFromListConstructNode( : c10::nullopt; } +void ProcessBroadCastNode(Node* n) { + TORCH_INTERNAL_ASSERT(n->inputs().size() == 2); + if (ConstantValueMap::HasShape(n->input(0)->debugName()) && + ConstantValueMap::HasShape(n->input(1)->debugName())) { + auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName()); + auto input_shape_value_0 = input_shape_0.value().sizes(); + auto input_shape_1 = ConstantValueMap::GetShape(n->input(1)->debugName()); + auto input_shape_value_1 = input_shape_1.value().sizes(); + size_t rank_0 = input_shape_value_0.value().size(); + size_t rank_1 = input_shape_value_1.value().size(); + size_t rank_max = std::max(rank_0, rank_1); + size_t rank_min = std::min(rank_0, rank_1); + std::vector<::c10::ShapeSymbol> final_shape; + final_shape.reserve(rank_max); + for (auto idx = 0; idx < rank_max; idx++) { + final_shape.emplace_back(::c10::ShapeSymbol::newSymbol()); + } + for (auto idx = 0; idx < rank_min; idx++) { + auto is_static_0 = + input_shape_value_0.value()[rank_0 - 1 - idx].is_static(); + auto is_static_1 = + input_shape_value_1.value()[rank_1 - 1 - idx].is_static(); + if (is_static_0 && is_static_1) { + auto static_0_sz = + input_shape_value_0.value()[rank_0 - 1 - idx].static_size(); + auto static_1_sz = + input_shape_value_1.value()[rank_1 - 1 - idx].static_size(); + final_shape[rank_max - 1 - idx] = ::c10::ShapeSymbol::fromStaticSize( + std::max(static_0_sz, static_1_sz)); + } + } + + if (rank_0 < rank_1) { + for (auto idx = rank_min; idx < rank_max; idx++) { + auto shape_idx = rank_max - 1 - idx; + final_shape[shape_idx] = input_shape_value_1.value()[shape_idx]; + } + } else { + for (auto idx = rank_min; idx < rank_max; idx++) { + auto shape_idx = rank_max - 1 - idx; + final_shape[shape_idx] = input_shape_value_0.value()[shape_idx]; + } + } + + UpdateShape(n->output(0), c10::SymbolicShape(final_shape)); + } +} + +void ProcessConcatNode(Node* n) { + int axis = n->i(attr::axis); + if (ConstantValueMap::HasRank(n->input(0)->debugName())) { + auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value(); + size_t axis_adjust = 0; + if (axis >= 0) { + axis_adjust = static_cast(axis); + } else { + axis_adjust = static_cast(axis + static_cast(rank)); + } + std::vector<::c10::ShapeSymbol> final_shape; + final_shape.reserve(rank); + for (auto idx = 0; 
idx < rank; idx++) { + if (idx == axis_adjust) { + auto flag = true; + int64_t size_total = 0; + for (auto input_idx = 0; input_idx < n->inputs().size(); input_idx++) { + if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) { + auto input_shape = + ConstantValueMap::GetShape(n->input(input_idx)->debugName()); + auto input_shape_value = input_shape.value().sizes(); + auto shape_symbol = input_shape_value.value()[idx]; + if (shape_symbol.is_static()) { + size_total += shape_symbol.static_size(); + } else { + flag = false; + break; + } + } + } + if (flag) { + final_shape.emplace_back( + ::c10::ShapeSymbol::fromStaticSize(size_total)); + } else { + final_shape.emplace_back(::c10::ShapeSymbol::newSymbol()); + } + } else { + auto flag = false; + for (auto input_idx = 0; input_idx < n->inputs().size(); input_idx++) { + if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) { + auto input_shape = + ConstantValueMap::GetShape(n->input(input_idx)->debugName()); + auto input_shape_value = input_shape.value().sizes(); + auto shape_symbol = input_shape_value.value()[idx]; + if (shape_symbol.is_static()) { + final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize( + shape_symbol.static_size())); + flag = true; + break; + } + } + } + if (!flag) { + final_shape.emplace_back(::c10::ShapeSymbol::newSymbol()); + } + } + } + UpdateShape(n->output(0), c10::SymbolicShape(final_shape)); + } +} + +void ProcessMatMulNode(Node* n) { + if (ConstantValueMap::HasShape(n->input(0)->debugName()) && + ConstantValueMap::HasShape(n->input(1)->debugName())) { + auto input_shape_0 = + ConstantValueMap::GetShape(n->input(0)->debugName()).value(); + auto input_shape_value_0 = input_shape_0.sizes().value(); + auto input_shape_1 = + ConstantValueMap::GetShape(n->input(1)->debugName()).value(); + auto input_shape_value_1 = input_shape_1.sizes().value(); + size_t rank_0 = input_shape_value_0.size(); + size_t rank_1 = input_shape_value_1.size(); + auto is_rank_0_1 = false; + if (rank_0 == 1) { + input_shape_value_0.insert( + input_shape_value_0.begin(), ::c10::ShapeSymbol::fromStaticSize(1)); + rank_0 = 2; + is_rank_0_1 = true; + } + auto is_rank_1_1 = false; + if (rank_1 == 1) { + input_shape_value_1.emplace_back(::c10::ShapeSymbol::fromStaticSize(1)); + rank_1 = 2; + is_rank_1_1 = true; + } + size_t rank = std::max(rank_0, rank_1); + std::vector<::c10::ShapeSymbol> final_shape; + final_shape.reserve(rank); + if (rank_0 >= rank_1) { + for (auto idx = 0; idx < rank_0 - 2; idx++) { + final_shape.emplace_back(input_shape_value_0[idx]); + } + } else { + for (auto idx = 0; idx < rank_1 - 2; idx++) { + final_shape.emplace_back(input_shape_value_1[idx]); + } + } + final_shape.emplace_back(input_shape_value_0[rank_0 - 2]); + final_shape.emplace_back(input_shape_value_1[rank_1 - 1]); + if (is_rank_0_1) { + final_shape.erase(final_shape.begin()); + } + if (is_rank_1_1) { + final_shape.pop_back(); + } + UpdateShape(n->output(0), c10::SymbolicShape(final_shape)); + } +} + +void ProcessReduceNode(Node* n) { + if (ConstantValueMap::HasShape(n->input(0)->debugName())) { + auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName()); + auto input_shape_value_0 = input_shape_0.value().sizes(); + size_t rank_0 = input_shape_value_0.value().size(); + std::vector<::c10::ShapeSymbol> final_shape; + if (!n->hasAttributeS("axes")) { + UpdateShape(n->output(0), c10::SymbolicShape(final_shape)); + return; + } + final_shape.reserve(rank_0); + std::vector axes_vector = n->is(attr::axes); + for (auto idx = 0; idx < 
axes_vector.size(); idx++) { + if (axes_vector[idx] < 0) { + axes_vector[idx] += rank_0; + } + } + int64_t keepdims = 0; + if (n->hasAttributeS("keepdims")) { + keepdims = n->i(attr::keepdims); + } + for (auto idx = 0; idx < rank_0; idx++) { + auto it = std::find(axes_vector.begin(), axes_vector.end(), idx); + if (it != axes_vector.end()) { + if (keepdims != 0) { + final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(1)); + } + } else { + final_shape.emplace_back(input_shape_value_0.value()[idx]); + } + } + UpdateShape(n->output(0), c10::SymbolicShape(final_shape)); + } +} + void ProcessReshapeNode(Node* n, int opset_version) { if (ConstantValueMap::HasValue(n->input(1)->debugName())) { auto shape_temp = @@ -625,9 +849,13 @@ void ProcessReshapeNode(Node* n, int opset_version) { auto shape_vector_0 = ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown( n->input(0)->debugName()); + std::vector shape_vector_0_value(0); if (shape_vector_0.has_value()) { + shape_vector_0_value = shape_vector_0.value(); + } + if (shape_vector_0.has_value() || shape_temp.size() > 0) { auto final_shape = ComputeShapeFromReshape( - n, shape_vector_0.value(), shape_temp, opset_version); + n, shape_vector_0_value, shape_temp, opset_version); UpdateShapeFromVector(n->output(), final_shape); return; } @@ -786,6 +1014,14 @@ void ProcessSliceNode(Node* n, int opset_version) { } } +void ProcessUnchangeNode(Node* n) { + if (ConstantValueMap::HasShape(n->input(0)->debugName())) { + auto shape_size_0 = + ConstantValueMap::GetShape(n->input(0)->debugName()).value(); + UpdateShape(n->output(), shape_size_0); + } +} + void ProcessTimeSeriesNode(Node* n) { auto input0_shape = ConstantValueMap::GetShape(n->input(0)->debugName()); auto input1_shape = ConstantValueMap::GetShape(n->input(1)->debugName()); @@ -870,6 +1106,20 @@ void ComputeConstant(Node* n, int opset_version) { } switch (n->kind()) { + case ::c10::onnx::Add: + case ::c10::onnx::Div: + case ::c10::onnx::Equal: + case ::c10::onnx::Greater: + case ::c10::onnx::GreaterOrEqual: + case ::c10::onnx::Less: + case ::c10::onnx::LessOrEqual: + case ::c10::onnx::Mod: + case ::c10::onnx::Mul: + case ::c10::onnx::Pow: + case ::c10::onnx::Sub: { + ProcessBroadCastNode(n); + break; + } case ::c10::onnx::Shape: { auto input_shape = ConstantValueMap::GetShapeInto1DInt64Vector(n->input()->debugName()); @@ -885,6 +1135,10 @@ void ComputeConstant(Node* n, int opset_version) { at::Tensor f_copy = at::empty({shape_value_size}, options); f_copy.copy_(f); ConstantValueMap::SetValue(n->output()->debugName(), f_copy); + std::vector<::c10::ShapeSymbol> final_shape_vector( + 1, c10::ShapeSymbol::fromStaticSize(shape_value_size)); + ::c10::SymbolicShape final_shape(final_shape_vector); + UpdateShape(n->output(), final_shape); } break; } @@ -948,6 +1202,10 @@ void ComputeConstant(Node* n, int opset_version) { } break; } + case ::c10::onnx::Concat: { + ProcessConcatNode(n); + break; + } case ::c10::onnx::ConstantOfShape: { if (ConstantValueMap::HasValue(n->input()->debugName())) { auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector( @@ -1025,6 +1283,15 @@ void ComputeConstant(Node* n, int opset_version) { } break; } + case ::c10::onnx::MatMul: { + ProcessMatMulNode(n); + break; + } + case ::c10::onnx::ReduceMean: + case ::c10::onnx::ReduceProd: { + ProcessReduceNode(n); + break; + } case ::c10::onnx::RNN: case ::c10::onnx::LSTM: case ::c10::onnx::GRU: { @@ -1061,6 +1328,12 @@ void ComputeConstant(Node* n, int opset_version) { ProcessSliceNode(n, opset_version); break; } + 
case ::c10::onnx::Cast: + case ::c10::onnx::Relu: + case ::c10::onnx::Softmax: { + ProcessUnchangeNode(n); + break; + } case ::c10::onnx::Tile: { if (ConstantValueMap::HasShape(n->input(0)->debugName())) { auto input0_shape_size = @@ -1108,7 +1381,20 @@ bool IsListConstructIntType(const Value* v) { bool AllGraphInputsStatic(const Graph* g) { for (auto n : g->inputs()) { - if (!n->isCompleteTensor()) { + if (TensorTypePtr input_type = n->type()->cast()) { + if (input_type->dim()) { + auto shape = input_type->symbolic_sizes(); + if (!ConstantValueMap::HasShape(n->debugName())) { + UpdateShapeConstantValueMap(n, shape); + } + } + } + } + for (auto n : g->inputs()) { + // Some inputs can be non-Tensor type, e.g., + // __torch__.torch.classes.quantized.LinearPackedParamsBase + // so we only need check Tensor type here. + if (n->type()->cast() && !n->isCompleteTensor()) { return false; } } @@ -1410,10 +1696,108 @@ void ONNXShapeTypeInference( } // namespace +// For some operators, there are some inputs not related to shape inference. +// For example, LSTM input 4 (sequence_lens) is optional, +// and the shape inference can be done through other required inputs. +// When we compute reliable, we don't need this input be reliable. +static std::unordered_map> + non_required_shape_inference_idx_map = {{"onnx::LSTM", {4}}}; + +std::pair AreInputsReliableOrStatic(Node* n) { + auto reliable = true; + auto complete = true; + auto input_size = n->inputs().size(); + std::unordered_set non_required_idx = {}; + if (non_required_shape_inference_idx_map.find(n->kind().toDisplayString()) != + non_required_shape_inference_idx_map.end()) { + non_required_idx = + non_required_shape_inference_idx_map[n->kind().toDisplayString()]; + } + for (auto idx = 0; idx < input_size; idx++) { + if (!non_required_idx.empty() && + non_required_idx.find(idx) != non_required_idx.end()) { + continue; + } + auto input = n->inputs()[idx]; + reliable &= + ConstantValueMap::GetTypeReliable(input->debugName()).value_or(false); + if (auto pt = input->type()->cast()) { + if (!pt->sizes().isComplete()) { + complete = false; + } + } + } + return std::make_pair(reliable, complete); +} + +// There is no need to put onnx type here, but we need this +// for some legacy tests when onnx_shape_inference=False. +static std::unordered_set nodeTypeReliableForTracer = { + "prim::ListConstruct", + "onnx::Cast", + "onnx::Constant", + "onnx::Relu", + "com.microsoft::Gelu"}; + +void UpdateReliable( + torch::jit::Value* output, + const std::pair& inferred_type_reliable) { + auto inferred = + ConstantValueMap::GetUseInferredType(output->debugName()).value_or(false); + auto isTypeReliableForTracer = + nodeTypeReliableForTracer.find( + output->node()->kind().toDisplayString()) != + nodeTypeReliableForTracer.end(); + if (!inferred && !isTypeReliableForTracer && + !output->node()->kind().is_onnx()) { + std::cerr + << "WARNING: The shape inference of " + << output->node()->kind().toDisplayString() + << " type is missing, so it may result in wrong shape inference for the exported graph. " + << "Please consider adding it in symbolic function." << std::endl; + } + auto reliable = false; + if (inferred) { + reliable = inferred_type_reliable.first; + } else { + if (inferred_type_reliable.second && isTypeReliableForTracer) { + reliable = true; + } + } + // Assume that the tracer can estimate rank correctly, + // then the output tensor of Shape should always be reliable. 
+ if (output->node()->kind() == ::c10::onnx::Shape) { + reliable = true; + } + ConstantValueMap::SetTypeReliable(output->debugName(), reliable); + if (!reliable) { + if (auto output_tensor_type = output->type()->cast()) { + output->setType(output_tensor_type->withSymbolicShapes( + ::c10::SymbolicShape(output_tensor_type->dim()))); + } + } +} + +void UpdateReliable(Node* n) { + auto input_reliable = AreInputsReliableOrStatic(n); + for (auto output : n->outputs()) { + UpdateReliable(output, input_reliable); + } +} + +void SetGraphInputTypeReliable(const Graph* g) { + for (auto graph_input : g->inputs()) { + if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) { + ConstantValueMap::SetTypeReliable(graph_input->debugName(), true); + } + } +} + void ONNXShapeTypeInference( Node* n, const ParamMap& params_dict, int opset_version) { + SetGraphInputTypeReliable(n->owningGraph()); GRAPH_UPDATE( "Running ONNX shape inference for node: ", n->kind().toDisplayString()); if (IsValidONNXNode(n)) { @@ -1476,7 +1860,36 @@ void ONNXShapeTypeInference( SpecialPostProcess(n); if (IsValidONNXNode(n)) { ProcessConstantValueMap(n, opset_version); + if (n->kind() != prim::ListConstruct) { + for (auto input : n->inputs()) { + if (input->node()->kind() == prim::ListConstruct) { + UpdateReliable(input, AreInputsReliableOrStatic(input->node())); + } + } + } } + UpdateReliable(n); + + // For the node type that does nott have ComputeConstant logic, it may have + // reliable shape but its shape is not in ConstantValueMap. So we need this + // logic to update ConstantValueMap. + for (auto node_output : n->outputs()) { + if (ConstantValueMap::HasTypeReliable(node_output->debugName())) { + auto reliable = + ConstantValueMap::GetTypeReliable(node_output->debugName()) + .value_or(false); + if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) { + // TODO: ListType case + if (auto output_tensor_type = node_output->type()->cast()) { + if (output_tensor_type->dim()) { + auto symbolic_sizes = output_tensor_type->symbolic_sizes(); + UpdateShapeConstantValueMap(node_output, symbolic_sizes); + } + } + } + } + } + GRAPH_DEBUG( "Torch graph after shape inference:", n->owningGraph()->toString()); } @@ -1549,8 +1962,8 @@ void ONNXUpdateTypeFromTensor( const at::Tensor& output, bool onnx_shape_inference) { if (onnx_shape_inference) { - graph_output->setType( - MergeInferredType(TensorType::create(output), graph_output->type())); + MergeInferredTypeAndSetMap( + graph_output, TensorType::create(output), graph_output->type()); } else { graph_output->inferTypeFrom(output); } @@ -1615,11 +2028,9 @@ size_t ONNXAssignOutputShape( ->getElementType() ->cast(); elem_type = elem_type->withScalarType(var.scalar_type()); - graph->outputs() - .at(outputs_index) - ->setType(MergeInferredType( - graph->outputs().at(outputs_index)->type(), - ListType::create(elem_type))); + auto graph_output = graph->outputs().at(outputs_index); + MergeInferredTypeAndSetMap( + graph_output, graph_output->type(), ListType::create(elem_type)); } else { graph->outputs() .at(outputs_index) @@ -1696,6 +2107,7 @@ void ONNXShapeTypeInference( const ParamMap& params_dict, int opset_version) { ConstantValueMap::ClearMaps(); + SetGraphInputTypeReliable(graph.get()); ONNXShapeTypeInference(graph->block(), params_dict, opset_version); } diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.h b/torch/csrc/jit/passes/onnx/shape_type_inference.h index 69fbff1..f4347ca 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.h +++ 
b/torch/csrc/jit/passes/onnx/shape_type_inference.h @@ -7,8 +7,11 @@ namespace torch { namespace jit { -TORCH_API TypePtr -MergeInferredType(TypePtr existing_type, TypePtr inferred_type); +void MergeInferredTypeAndSetMap( + Value* dest_v, + TypePtr existing_type, + TypePtr inferred_type, + bool set_constant_value_map = true); // Update graph input types with dynamic axes info. // Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol. @@ -49,5 +52,10 @@ TORCH_API void ONNXShapeTypeInference( const ParamMap& params_dict, int opset_version); +std::pair AreInputsReliableOrStatic(Node* n); +void UpdateReliable( + torch::jit::Value* output, + const std::pair& input_reliable); + } // namespace jit } // namespace torch diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index b726b2b..01c98b5 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -379,3 +379,8 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version): """ from torch.onnx import utils utils.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version) + + +def unregister_custom_op_symbolic(symbolic_name, opset_version): + from torch.onnx import utils + utils.unregister_custom_op_symbolic(symbolic_name, opset_version) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 53440f1..09b8ae5 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -432,18 +432,23 @@ def unbind(g, self, dim=0, _outputs=None): # Generate paddings in ONNX order based on pad in pytorch. # Args: -# dim: the dimension of the tensor. +# input: the input tensor. # pad: the paddings in pytorch. # The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end, # where m is in range [0, n]. -def _prepare_onnx_paddings(g, dim, pad): +def _prepare_onnx_paddings(g, input, pad): # The desired order of paddings is # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end. # n is the dimension of input. # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning pad_len = torch.onnx.symbolic_opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0]))) # Set extension = [0] * (dim * 2 - len(pad)) - extension = g.op("Sub", g.op("Mul", g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64)), + rank = sym_help._get_tensor_rank(input) + if rank is None: + rank = g.op("Size", g.op("Shape", input)) + else: + rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64)) + extension = g.op("Sub", g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), pad_len) # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... 
] # Currently ONNX only supports int64 type for Pad @@ -464,19 +469,19 @@ def constant_pad_nd(g, input, padding, value=None): mode = "constant" value = sym_help._maybe_get_scalar(value) value = sym_help._if_scalar_type_as(g, value, input) - pad = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) + pad = _prepare_onnx_paddings(g, input, padding) return g.op("Pad", input, pad, value, mode_s=mode) def reflection_pad(g, input, padding): mode = "reflect" - paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) + paddings = _prepare_onnx_paddings(g, input, padding) return g.op("Pad", input, paddings, mode_s=mode) def replication_pad(g, input, padding): mode = "edge" - paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding) + paddings = _prepare_onnx_paddings(g, input, padding) return g.op("Pad", input, paddings, mode_s=mode) diff --git a/torch/onnx/symbolic_registry.py b/torch/onnx/symbolic_registry.py index fd3f6af..ebd379f 100644 --- a/torch/onnx/symbolic_registry.py +++ b/torch/onnx/symbolic_registry.py @@ -93,6 +93,15 @@ def is_registered_op(opname, domain, version): global _registry return (domain, version) in _registry and opname in _registry[(domain, version)] +def unregister_op(opname, domain, version): + global _registry + if is_registered_op(opname, domain, version): + del _registry[(domain, version)][opname] + if not _registry[(domain, version)]: + del _registry[(domain, version)] + else: + warnings.warn("The opname " + opname + " is not registered.") + def get_op_supported_version(opname, domain, version): iter_version = version while iter_version <= _onnx_main_opset: diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 0b447a9..8ee0eca 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -198,8 +198,6 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_onnx_scalar_type_analysis(graph, True, _export_onnx_opset_version) torch._C._jit_pass_lint(graph) - torch._C._jit_pass_onnx_fold_if(graph) - torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version, fixed_batch_size) torch._C._jit_pass_lint(graph) @@ -530,7 +528,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None, google_printer=False, opset_version=None, _retain_param_name=True, keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True, - do_constant_folding=True): + do_constant_folding=True, dynamic_axes=None): return _export_to_pretty_string(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, export_type, example_outputs, google_printer, @@ -538,7 +536,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t do_constant_folding=do_constant_folding, add_node_names=add_node_names, keep_initializers_as_inputs=keep_initializers_as_inputs, - custom_opsets=custom_opsets) + custom_opsets=custom_opsets, dynamic_axes=dynamic_axes) def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None, @@ -547,7 +545,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, google_printer=False, opset_version=None, _retain_param_name=False, do_constant_folding=True, keep_initializers_as_inputs=None, fixed_batch_size=False, custom_opsets=None, add_node_names=True, - onnx_shape_inference=True): + onnx_shape_inference=True, dynamic_axes=None): from 
torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version from torch.onnx.symbolic_helper import _set_operator_export_type if opset_version is None: @@ -569,7 +567,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, output_names, operator_export_type, example_outputs, _retain_param_name, val_do_constant_folding, fixed_batch_size=fixed_batch_size, - training=training) + training=training, dynamic_axes=dynamic_axes) return graph._pretty_print_onnx(params_dict, opset_version, False, operator_export_type, google_printer, @@ -1187,7 +1185,7 @@ def _node_getitem(self, k): return getattr(self, sel)(k) -def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version): +def get_ns_op_name_from_custom_op(symbolic_name): if not bool(re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)): raise RuntimeError("Failed to register operator {}. \ The symbolic name must match the format Domain::Name, \ @@ -1199,6 +1197,15 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version): if ns in unaccepted_domain_names: raise RuntimeError("Failed to register operator {}. The domain {} is already a used domain." .format(symbolic_name, ns)) + return ns, op_name + + +# When the user registers symbolic for custom/contrib ops, +# it is highly recommended to add shape inference for that operator via setType API, +# otherwise the exported graph may have incorrect shape inference in some extreme cases. +# An example of setType is test_aten_embedding_2 in test_operators.py.. +def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version): + ns, op_name = get_ns_op_name_from_custom_op(symbolic_name) import torch.onnx.symbolic_registry as sym_registry from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset @@ -1206,6 +1213,17 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version): if version >= opset_version: sym_registry.register_op(op_name, symbolic_fn, ns, version) + +def unregister_custom_op_symbolic(symbolic_name, opset_version): + ns, op_name = get_ns_op_name_from_custom_op(symbolic_name) + import torch.onnx.symbolic_registry as sym_registry + from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset + + for version in _onnx_stable_opsets + [_onnx_main_opset]: + if version >= opset_version: + sym_registry.unregister_op(op_name, ns, version) + + # This helper function ensures dynamic axes argument is following the expected format def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names): if len(dynamic_axes) == 0: -- 2.7.4
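
Note on the setType recommendation above: the pattern described by the new onnx.rst text and exercised by test_aten_embedding_2 looks roughly like the sketch below. It is adapted from that test; opset 12, the com.microsoft::ATenOp target, and the symbolic_helper utilities come straight from this patch, while the torch.onnx.export call and the tensor sizes are only illustrative.

# Sketch: register a custom symbolic with setType-based shape inference,
# export, then unregister (adapted from test_aten_embedding_2 above).
import io
import torch
from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args, _get_tensor_sizes, _get_tensor_dim_size

_OPSET = 12

@parse_args('v', 'v', 'i', 'b', 'b')
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    output = g.op("com.microsoft::ATenOp", weight, indices,
                  name_s='aten::embedding',
                  custom_attributes_json_s=(
                      '{'
                      f'"padding_idx":{padding_idx},'
                      f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
                      f'"sparse":{str(sparse).lower()}'
                      '}'))
    # Shape inference via setType: the output shape is the indices' shape plus
    # the embedding dimension, so downstream nodes see a concrete shape.
    indices_shape = _get_tensor_sizes(indices)
    if indices_shape is not None and hasattr(weight.type(), 'with_sizes'):
        output.setType(weight.type().with_sizes(
            indices_shape + [_get_tensor_dim_size(weight, 1)]))
    return output

register_custom_op_symbolic('::embedding', embedding, _OPSET)

model = torch.nn.Embedding(4, 8)
x = torch.ones(32, dtype=torch.long)
f = io.BytesIO()
torch.onnx.export(model, (x,), f, opset_version=_OPSET)

unregister_custom_op_symbolic('::embedding', _OPSET)

Without the setType call the output's shape is not considered reliable and, as the docs change states, the exporter emits a warning that shape inference for the custom op is missing.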