[ONNX] Enhance shape (two changes merged) (#64585)
authorBowenBao <bowbao@microsoft.com>
Wed, 15 Sep 2021 19:56:33 +0000 (12:56 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 20:02:19 +0000 (13:02 -0700)
Summary:
Enhanced shape inference by introducing typeReliableMap.
[ONNX] exporter changes for torch hub models (https://github.com/pytorch/pytorch/issues/62856)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64585

Reviewed By: ezyang

Differential Revision: D30870418

Pulled By: msaroufim

fbshipit-source-id: 87a294799cb87d649d1d13b6114a5cfbac9be15c

Co-authored-by: jiafatom <jiafa@microsoft.com>
22 files changed:
aten/src/ATen/core/interned_strings.h
docs/source/onnx.rst
test/onnx/expect/TestOperators.test_aten_embedding_1.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_aten_embedding_2.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_c2_op.expect
test/onnx/expect/TestOperators.test_dynamic_axes_add.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect [new file with mode: 0644]
test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect [new file with mode: 0644]
test/onnx/test_operators.py
test/onnx/test_pytorch_onnx_onnxruntime.py
torch/csrc/jit/passes/onnx.cpp
torch/csrc/jit/passes/onnx/constant_fold.cpp
torch/csrc/jit/passes/onnx/constant_map.cpp
torch/csrc/jit/passes/onnx/constant_map.h
torch/csrc/jit/passes/onnx/shape_type_inference.cpp
torch/csrc/jit/passes/onnx/shape_type_inference.h
torch/onnx/__init__.py
torch/onnx/symbolic_opset11.py
torch/onnx/symbolic_registry.py
torch/onnx/utils.py

index 0b12603..7ed3cf8 100644 (file)
@@ -394,12 +394,14 @@ namespace c10 {
   _(onnx, Gather)                    \
   _(onnx, Gemm)                      \
   _(onnx, LSTM)                      \
+  _(onnx, MatMul)                    \
   _(onnx, Mul)                       \
   _(onnx, Pow)                       \
   _(onnx, RNN)                       \
   _(onnx, Shape)                     \
   _(onnx, Size)                      \
   _(onnx, Slice)                     \
+  _(onnx, Softmax)                   \
   _(onnx, Squeeze)                   \
   _(onnx, Sub)                       \
   _(onnx, Transpose)                 \
@@ -435,7 +437,9 @@ namespace c10 {
   _(onnx, ReduceL2)                  \
   _(onnx, Conv)                      \
   _(onnx, BatchNormalization)        \
+  _(onnx, ReduceMean)                \
   _(onnx, ReduceProd)                \
+  _(onnx, Relu)                      \
   _(onnx, Neg)                       \
   _(onnx, NonZero)                   \
   _(onnx, Range)                     \
index eb6c2c0..8beeb55 100644 (file)
@@ -396,10 +396,16 @@ All autograd ``Function``s are emitted in the TorchScript graph as ``prim::Pytho
 In order to differentiate between different ``Function`` subclasses, the
 symbolic function should use the ``name`` kwarg which gets set to the name of the class.
 
-:func:`register_custom_op_symbolic` does does not allow registration for ops in
+:func:`register_custom_op_symbolic` does not allow registration for ops in
 the ``prim`` namespace, so for this use case, there's a back door: register the
 symbolic for ``"::prim_PythonOp"``.
 
+Please also consider adding shape inference logic when you regiester a custom symbolic function
+via setType API. This can help the exporter to obtain correct shape inference.
+An example of setType is test_aten_embedding_2 in test_operators.py.
+Although it is not required to add shape inference logic,
+the exporter emits a warning message if it is not added.
+
 The example below shows how you can access ``requires_grad`` via the ``Node`` object::
 
     class MyClip(torch.autograd.Function):
diff --git a/test/onnx/expect/TestOperators.test_aten_embedding_1.expect b/test/onnx/expect/TestOperators.test_aten_embedding_1.expect
new file mode 100644 (file)
index 0000000..317fa3a
--- /dev/null
@@ -0,0 +1,36 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    output: "3"
+    name: "Constant_0"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        dims: 32
+        data_type: 1
+        raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
+      }
+      type: TENSOR
+    }
+  }
+  name: "torch-jit-export"
+  output {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 32
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
diff --git a/test/onnx/expect/TestOperators.test_aten_embedding_2.expect b/test/onnx/expect/TestOperators.test_aten_embedding_2.expect
new file mode 100644 (file)
index 0000000..a5ab9d3
--- /dev/null
@@ -0,0 +1,155 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    input: "emb.weight"
+    input: "input_1"
+    output: "3"
+    name: "ATenOp_0"
+    op_type: "ATenOp"
+    attribute {
+      name: "custom_attributes_json"
+      s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
+      type: STRING
+    }
+    attribute {
+      name: "name"
+      s: "aten::embedding"
+      type: STRING
+    }
+    domain: "com.microsoft"
+  }
+  node {
+    input: "3"
+    input: "input_2"
+    output: "4"
+    name: "Add_1"
+    op_type: "Add"
+  }
+  node {
+    input: "4"
+    output: "5"
+    name: "Shape_2"
+    op_type: "Shape"
+  }
+  node {
+    output: "6"
+    name: "Constant_3"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        data_type: 7
+        raw_data: "\000\000\000\000\000\000\000\000"
+      }
+      type: TENSOR
+    }
+  }
+  node {
+    input: "5"
+    input: "6"
+    output: "7"
+    name: "Gather_4"
+    op_type: "Gather"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "7"
+    output: "8"
+    name: "Unsqueeze_5"
+    op_type: "Unsqueeze"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "8"
+    output: "9"
+    name: "Concat_6"
+    op_type: "Concat"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
+    input: "9"
+    output: "10"
+    name: "ConstantOfShape_7"
+    op_type: "ConstantOfShape"
+    attribute {
+      name: "value"
+      t {
+        dims: 1
+        data_type: 1
+        raw_data: "\000\000\200?"
+      }
+      type: TENSOR
+    }
+  }
+  name: "torch-jit-export"
+  initializer {
+    dims: 4
+    dims: 8
+    data_type: 1
+    name: "emb.weight"
+    raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?"
+  }
+  input {
+    name: "input_1"
+    type {
+      tensor_type {
+        elem_type: 7
+        shape {
+          dim {
+            dim_param: "input_1_dim_0"
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "input_2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_param: "input_2_dim_0"
+          }
+          dim {
+            dim_param: "input_2_dim_1"
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "10"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_param: "ConstantOfShape10_dim_0"
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
+opset_import {
+  domain: "com.microsoft"
+  version: 1
+}
index fc191b5..e9aae2c 100644 (file)
@@ -147,10 +147,10 @@ graph {
         elem_type: 1
         shape {
           dim {
-            dim_value: 0
+            dim_param: "GenerateProposals4_dim_0"
           }
           dim {
-            dim_value: 5
+            dim_param: "GenerateProposals4_dim_1"
           }
         }
       }
@@ -163,7 +163,7 @@ graph {
         elem_type: 1
         shape {
           dim {
-            dim_value: 0
+            dim_param: "GenerateProposals5_dim_0"
           }
         }
       }
diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_add.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_add.expect
new file mode 100644 (file)
index 0000000..83e6e74
--- /dev/null
@@ -0,0 +1,64 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    input: "input_1"
+    input: "input_2"
+    output: "2"
+    name: "Add_0"
+    op_type: "Add"
+  }
+  name: "torch-jit-export"
+  input {
+    name: "input_1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_1_dim_1"
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "input_2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_2_dim_1"
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "Add2_dim_1"
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_matmul.expect
new file mode 100644 (file)
index 0000000..038d3dd
--- /dev/null
@@ -0,0 +1,73 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    input: "input_1"
+    input: "input_2"
+    output: "2"
+    name: "MatMul_0"
+    op_type: "MatMul"
+  }
+  name: "torch-jit-export"
+  input {
+    name: "input_1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_1_dim_1"
+          }
+          dim {
+            dim_value: 4
+          }
+        }
+      }
+    }
+  }
+  input {
+    name: "input_2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 4
+          }
+          dim {
+            dim_param: "input_2_dim_2"
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "2"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_1_dim_1"
+          }
+          dim {
+            dim_param: "input_2_dim_2"
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_reduce_mean.expect
new file mode 100644 (file)
index 0000000..24de171
--- /dev/null
@@ -0,0 +1,60 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    input: "input"
+    output: "1"
+    name: "ReduceMean_0"
+    op_type: "ReduceMean"
+    attribute {
+      name: "axes"
+      ints: 1
+      type: INTS
+    }
+    attribute {
+      name: "keepdims"
+      i: 0
+      type: INT
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "input"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_dim_1"
+          }
+          dim {
+            dim_param: "input_dim_2"
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "1"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_dim_2"
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
diff --git a/test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect b/test/onnx/expect/TestOperators.test_dynamic_axes_unchange.expect
new file mode 100644 (file)
index 0000000..e304a96
--- /dev/null
@@ -0,0 +1,76 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    input: "input"
+    output: "1"
+    name: "Transpose_0"
+    op_type: "Transpose"
+    attribute {
+      name: "perm"
+      ints: 1
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    input: "1"
+    output: "2"
+    name: "Softmax_1"
+    op_type: "Softmax"
+    attribute {
+      name: "axis"
+      i: 1
+      type: INT
+    }
+  }
+  node {
+    input: "2"
+    output: "3"
+    name: "Transpose_2"
+    op_type: "Transpose"
+    attribute {
+      name: "perm"
+      ints: 1
+      ints: 0
+      type: INTS
+    }
+  }
+  name: "torch-jit-export"
+  input {
+    name: "input"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_dim_1"
+          }
+        }
+      }
+    }
+  }
+  output {
+    name: "3"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_param: "input_dim_1"
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
diff --git a/test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect b/test/onnx/expect/TestOperators.test_lstm_none_sequence_lens.expect
new file mode 100644 (file)
index 0000000..b544aa6
--- /dev/null
@@ -0,0 +1,44 @@
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+  node {
+    output: "7"
+    name: "Constant_0"
+    op_type: "Constant"
+    attribute {
+      name: "value"
+      t {
+        dims: 1
+        dims: 2
+        dims: 3
+        data_type: 1
+        raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
+      }
+      type: TENSOR
+    }
+  }
+  name: "torch-jit-export"
+  output {
+    name: "7"
+    type {
+      tensor_type {
+        elem_type: 1
+        shape {
+          dim {
+            dim_value: 1
+          }
+          dim {
+            dim_value: 2
+          }
+          dim {
+            dim_value: 3
+          }
+        }
+      }
+    }
+  }
+}
+opset_import {
+  version: 12
+}
index b9e391b..14d4794 100644 (file)
@@ -1,8 +1,11 @@
 
-from test_pytorch_common import TestCase, run_tests, flatten, skipIfNoLapack
+from test_pytorch_common import TestCase, run_tests, flatten, skipIfNoLapack, \
+    BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
 
 import torch
 import torch.onnx
+from torch.onnx.symbolic_helper import parse_args, _get_tensor_dim_size, _get_tensor_sizes
+from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic
 from torch.autograd import Variable, Function
 from torch.nn import Module, functional
 import torch.nn as nn
@@ -907,6 +910,130 @@ class TestOperators(TestCase):
         y = torch.empty(3, 2, 1, dtype=torch.long).random_(5)
         self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)
 
+    def test_lstm_none_sequence_lens(self):
+        """Test symbolic shape inference for LSTM when the input sequence_lens = None."""
+        input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+        h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+        c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+
+        class LSTMModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.rnn = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
+
+            def forward(self, x, h0, c0):
+                a, b = self.rnn(x, (h0, c0))
+                return torch.ones(b[0].shape)
+
+        self.assertONNX(LSTMModel(),
+                        (input, h0, c0), input_names=["x", "y"],
+                        dynamic_axes={"x" : {0: 'batch'}}, opset_version=12)
+
+    def test_dynamic_axes_add(self):
+        m1 = torch.randn(2, 3, requires_grad=True)
+        m2 = torch.randn(2, 1, requires_grad=True)
+        self.assertONNX(lambda x, y: torch.add(x, y), (m1, m2), input_names=["input_1", "input_2"],
+                        dynamic_axes={"input_1": {1: "dim_1"}, "input_2": {1: "dim_2"}},
+                        opset_version=12)
+
+    def test_dynamic_axes_matmul(self):
+        m1 = torch.randn(2, 2, 4, requires_grad=True)
+        m2 = torch.randn(2, 4, 3, requires_grad=True)
+        self.assertONNX(lambda x, y: torch.matmul(x, y), (m1, m2), input_names=["input_1", "input_2"],
+                        dynamic_axes={"input_1": {1: "dim_0"}, "input_2": {2: "dim_1"}},
+                        opset_version=12)
+
+    def test_dynamic_axes_reduce_mean(self):
+        m1 = torch.randn(2, 3, 4, requires_grad=True)
+        self.assertONNX(lambda x: torch.mean(x, dim=1), (m1), input_names=["input"],
+                        dynamic_axes={"input": {1: "dim_1", 2: "dim_2"}},
+                        opset_version=12)
+
+    def test_dynamic_axes_unchange(self):
+        """Test ProcessUnchangeNode in symbolic shape inference."""
+        m1 = torch.randn(2, 3, requires_grad=True)
+        self.assertONNX(lambda x: torch.softmax(x, dim=0), (m1,), input_names=["input"],
+                        dynamic_axes={"input": {1: "dim_1"}},
+                        opset_version=12)
+
+    def test_aten_embedding_1(self):
+        _onnx_opset_version = 12
+
+        @parse_args('v', 'v', 'i', 'b', 'b')
+        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
+            custom_attributes_json = (
+                '{'
+                f'"padding_idx":{str(padding_idx)},'
+                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
+                f'"sparse":{str(sparse).lower()}'
+                '}'
+            )
+            output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
+                          custom_attributes_json_s=custom_attributes_json)
+            return output
+
+        register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb = torch.nn.Embedding(4, 8)
+
+            def forward(self, x, y):
+                res = self.emb(x)
+                res = res + y
+                return torch.ones(res.shape[0])
+
+        model = Model()
+        x = torch.ones(32, dtype=torch.long)
+        y = torch.randn(1, 8)
+        self.assertONNX(model, (x, y), opset_version=_onnx_opset_version)
+
+        unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
+
+    # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
+    def test_aten_embedding_2(self):
+        _onnx_opset_version = 12
+
+        @parse_args('v', 'v', 'i', 'b', 'b')
+        def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
+            custom_attributes_json = (
+                '{'
+                f'"padding_idx":{str(padding_idx)},'
+                f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
+                f'"sparse":{str(sparse).lower()}'
+                '}'
+            )
+            output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
+                          custom_attributes_json_s=custom_attributes_json)
+
+            # do shape inference and set it via setType
+            indices_shape = _get_tensor_sizes(indices)
+            if indices_shape is not None and hasattr(weight.type(), 'with_sizes'):
+                output_type = weight.type().with_sizes(indices_shape + [_get_tensor_dim_size(weight, 1)])
+                output.setType(output_type)
+            return output
+
+        register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.emb = torch.nn.Embedding(4, 8)
+
+            def forward(self, x, y):
+                res = self.emb(x)
+                res = res + y
+                return torch.ones(res.shape[0])
+
+        model = Model()
+        x = torch.ones(32, dtype=torch.long)
+        y = torch.randn(1, 8)
+        self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'],
+                        dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}})
+
+        unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
+
 if __name__ == "__main__":
     no_onnx_dep_flag = "--no-onnx"
     _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS
index 54a116b..60eac0d 100644 (file)
@@ -47,7 +47,7 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
                     do_constant_folding=True, keep_initializers_as_inputs=True,
                     dynamic_axes=None, input_names=None, output_names=None,
                     fixed_batch_size=False, training=None,
-                    onnx_shape_inference=False):
+                    onnx_shape_inference=True):
     # export the model to ONNX
     f = io.BytesIO()
     input_copy = copy.deepcopy(input)
@@ -8582,8 +8582,13 @@ class TestONNXRuntime(unittest.TestCase):
         random_state = torch.rand((1, 1, 10, 30, 30))
         self.run_test(model, (random_data, empty_tensor),
                       input_names=["data", "state"],
-                      dynamic_axes={"state": [0, 1, 2, 3, 4]},
+                      dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]},
                       test_with_inputs=[(random_data, random_state)])
+        self.run_test(model, (random_data, empty_tensor),
+                      input_names=["data", "state"],
+                      dynamic_axes={"state": [0, 1, 2, 3, 4]},
+                      test_with_inputs=[(random_data, random_state)],
+                      remained_onnx_input_idx=[1])
         self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
 
     @skipIfUnsupportedMinOpsetVersion(11)
index 0697f89..03e894c 100644 (file)
@@ -219,6 +219,14 @@ std::unordered_map<Value*, Value*> BlockToONNX(
   return {};
 }
 
+bool ConstantFoldCondition(torch::jit::Value* output) {
+  auto fold_condition = output->node()->kind() != c10::onnx::Constant &&
+      ConstantValueMap::HasValue(output->debugName());
+  auto reliable_value =
+      ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false);
+  return fold_condition && reliable_value;
+}
+
 void NodeToONNX(
     Node* old_node,
     Block* new_block,
@@ -267,8 +275,7 @@ void NodeToONNX(
         //
         // If onnx shape inference is turned on, the new outputs will have
         // types inferred, and they will be merged with the old types.
-        if (outputs[i]->node()->kind() != c10::onnx::Constant &&
-            ConstantValueMap::HasValue(outputs[i]->debugName())) {
+        if (ConstantFoldCondition(outputs[i])) {
           // Create a const node if the node output value is in
           // ConstantValueMap.
           auto value =
@@ -286,8 +293,10 @@ void NodeToONNX(
           ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
           env[old] = const_node->output();
         } else {
-          outputs[i]->setType(
-              MergeInferredType(old->type(), outputs[i]->type()));
+          // ConstantValueMap has been set in shape inference,
+          // set_constant_value_map = false here to avoid redundancy.
+          MergeInferredTypeAndSetMap(
+              outputs[i], old->type(), outputs[i]->type(), false);
 
           // Copy over source location and scope information to all nodes
           // created by the symbolic
index 76c0674..cce5a43 100644 (file)
@@ -489,6 +489,12 @@ c10::optional<at::Tensor> runTorchBackendForOnnx(
   } else if (node->kind() == onnx::Equal) {
     updated_val = at::eq(inputTensorValues[0], inputTensorValues[1]);
     return c10::optional<at::Tensor>(updated_val);
+  } else if (node->kind() == onnx::Greater) {
+    updated_val = at::greater(inputTensorValues[0], inputTensorValues[1]);
+    return c10::optional<at::Tensor>(updated_val);
+  } else if (node->kind() == onnx::Less) {
+    updated_val = at::less(inputTensorValues[0], inputTensorValues[1]);
+    return c10::optional<at::Tensor>(updated_val);
   } else if (node->kind() == onnx::Neg) {
     updated_val = at::neg(inputTensorValues[0]);
     return c10::optional<at::Tensor>(updated_val);
index 8cbec27..69fe70a 100644 (file)
@@ -24,6 +24,7 @@ void ConstantValueMap::SetRank(
     const std::string& tensorName,
     size_t rankValue) {
   ConstantValueMap::getInstance().rankMap.emplace(tensorName, rankValue);
+  ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, true);
 }
 
 bool ConstantValueMap::HasRank(const std::string& tensorName) {
@@ -42,6 +43,7 @@ void ConstantValueMap::SetShape(
     const std::string& tensorName,
     const c10::SymbolicShape& shapeValue) {
   ConstantValueMap::getInstance().shapeMap.emplace(tensorName, shapeValue);
+  ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, true);
 }
 
 bool ConstantValueMap::HasShape(const std::string& tensorName) {
@@ -146,10 +148,50 @@ std::vector<int64_t> ConstantValueMap::GetValueInto1DInt64Vector(
   return value_vector;
 }
 
+void ConstantValueMap::SetTypeReliable(
+    const std::string& tensorName,
+    bool value) {
+  ConstantValueMap::getInstance().typeReliableMap.emplace(tensorName, value);
+}
+
+bool ConstantValueMap::HasTypeReliable(const std::string& tensorName) {
+  return ConstantValueMap::getInstance().typeReliableMap.find(tensorName) !=
+      ConstantValueMap::getInstance().typeReliableMap.end();
+}
+
+c10::optional<bool> ConstantValueMap::GetTypeReliable(
+    const std::string& tensorName) {
+  if (!HasTypeReliable(tensorName)) {
+    return c10::nullopt;
+  }
+  return ConstantValueMap::getInstance().typeReliableMap[tensorName];
+}
+
+void ConstantValueMap::SetUseInferredType(
+    const std::string& tensorName,
+    bool value) {
+  ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, value);
+}
+
+bool ConstantValueMap::HasUseInferredType(const std::string& tensorName) {
+  return ConstantValueMap::getInstance().useInferredTypeMap.find(tensorName) !=
+      ConstantValueMap::getInstance().useInferredTypeMap.end();
+}
+
+c10::optional<bool> ConstantValueMap::GetUseInferredType(
+    const std::string& tensorName) {
+  if (!HasUseInferredType(tensorName)) {
+    return c10::nullopt;
+  }
+  return ConstantValueMap::getInstance().useInferredTypeMap[tensorName];
+}
+
 void ConstantValueMap::ClearMaps() {
   ConstantValueMap::getInstance().rankMap.clear();
   ConstantValueMap::getInstance().shapeMap.clear();
   ConstantValueMap::getInstance().tensorValueMap.clear();
+  ConstantValueMap::getInstance().typeReliableMap.clear();
+  ConstantValueMap::getInstance().useInferredTypeMap.clear();
 }
 
 // For debug only.
@@ -179,6 +221,26 @@ void ConstantValueMap::PrintMaps() {
   for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) {
     std::cout << "node " << x.first << ": " << x.second << std::endl;
   }
+  std::cout << std::endl;
+  std::cout << "Print TypeReliable Maps:" << std::endl;
+  size_t count = 0;
+  for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) {
+    std::cout << "(node " << x.first << ": " << x.second << "), ";
+    count++;
+    if (count % 10 == 0) {
+      std::cout << std::endl;
+    }
+  }
+  std::cout << std::endl;
+  std::cout << "Print UseInferredType Maps:" << std::endl;
+  count = 0;
+  for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) {
+    std::cout << "(node " << x.first << ": " << x.second << "), ";
+    count++;
+    if (count % 10 == 0) {
+      std::cout << std::endl;
+    }
+  }
 }
 
 } // namespace jit
index 97fa140..ab71557 100644 (file)
@@ -35,6 +35,16 @@ class ConstantValueMap {
   static std::vector<int64_t> GetValueInto1DInt64Vector(
       const std::string& value_name);
 
+  static void SetTypeReliable(const std::string& tensorName, bool reliable);
+  static bool HasTypeReliable(const std::string& tensorName);
+  static c10::optional<bool> GetTypeReliable(const std::string& tensorName);
+
+  static void SetUseInferredType(
+      const std::string& tensorName,
+      bool useInferredType);
+  static bool HasUseInferredType(const std::string& tensorName);
+  static c10::optional<bool> GetUseInferredType(const std::string& tensorName);
+
   static void PrintMaps();
   static void ClearMaps();
   ~ConstantValueMap() = default;
@@ -47,6 +57,11 @@ class ConstantValueMap {
   std::unordered_map<std::string, size_t> rankMap;
   std::unordered_map<std::string, c10::SymbolicShape> shapeMap;
   std::unordered_map<std::string, at::Tensor> tensorValueMap;
+  // This map indicates whether the current type is reliably estimated or not.
+  std::unordered_map<std::string, bool> typeReliableMap;
+  // This map indicates whether the current type is estimated through inference
+  // or tracer.
+  std::unordered_map<std::string, bool> useInferredTypeMap;
 };
 
 } // namespace jit
index 8ade722..5760c48 100644 (file)
@@ -16,6 +16,7 @@
 #include <onnx/shape_inference/implementation.h>
 #include <algorithm>
 #include <cmath>
+#include <unordered_set>
 
 namespace torch {
 namespace jit {
@@ -36,10 +37,13 @@ namespace jit {
 //  3. existing type: Scalar[], inferred type: Tensor
 //    ONNX represents list of scalars by 1-d Tensor. Return inferred type since
 //    it is more compatible with ONNX.
-TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) {
+std::pair<TypePtr, bool> MergeInferredType(
+    TypePtr existing_type,
+    TypePtr inferred_type) {
   auto new_list_type = inferred_type->cast<ListType>();
+  auto use_inferred_type = false;
   if (new_list_type) {
-    return inferred_type;
+    return std::make_pair(inferred_type, true);
   }
   auto new_tensor_type = inferred_type->cast<TensorType>();
   auto old_tensor_type = existing_type->cast<TensorType>();
@@ -47,32 +51,49 @@ TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) {
   if (new_tensor_type && old_tensor_type) {
     if (!old_tensor_type->device()) {
       // device not available means this is an invalid tensor type (most likely
-      // an empty one) -> return inferred type directly.
-      return new_tensor_type;
+      // an empty one) return inferred type directly.
+      return std::make_pair(new_tensor_type, true);
     }
     auto type = old_tensor_type;
     if (new_tensor_type->dim()) {
       type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes());
+      use_inferred_type = true;
     }
     if (new_tensor_type->scalarType().has_value()) {
       type = type->withScalarType(new_tensor_type->scalarType());
+      use_inferred_type = true;
     }
-    return type;
+    return std::make_pair(type, use_inferred_type);
   }
 
   if (old_tensor_type) {
-    return existing_type;
+    return std::make_pair(existing_type, false);
   }
 
   auto old_list_type = existing_type->cast<ListType>();
   if (new_tensor_type && old_list_type) {
     if (new_tensor_type->sizes().isComplete()) {
-      return inferred_type;
+      return std::make_pair(inferred_type, true);
     }
-    return existing_type;
+    return std::make_pair(existing_type, false);
   }
 
-  return inferred_type;
+  return std::make_pair(inferred_type, true);
+}
+
+void MergeInferredTypeAndSetMap(
+    Value* dest_v,
+    TypePtr existing_type,
+    TypePtr inferred_type,
+    bool set_constant_value_map) {
+  TypePtr mergedType;
+  bool inferred;
+  std::tie(mergedType, inferred) =
+      MergeInferredType(existing_type, inferred_type);
+  dest_v->setType(mergedType);
+  if (set_constant_value_map) {
+    ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred);
+  }
 }
 
 namespace {
@@ -123,6 +144,7 @@ TensorTypePtr TorchTensorTypeFromONNX(
           // Assign a new Symbol, no need to keep track
           // of it because there won't be duplicates.
           sym = c10::ShapeSymbol::newSymbol();
+          symbol_map[sym.value()] = "";
         }
         sizes.emplace_back(sym.value());
       }
@@ -170,13 +192,13 @@ void UpdateTorchValueByOnnxValueInfo(
     const auto torch_tensor_type =
         TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_map);
     if (torch_tensor_type) {
-      v->setType(MergeInferredType(v->type(), torch_tensor_type));
+      MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type);
     }
   } else if (p_type.has_sequence_type()) {
     const auto torch_list_type =
         TorchListTypeFromONNX(p_type.sequence_type(), symbol_map);
     if (torch_list_type) {
-      v->setType(MergeInferredType(v->type(), torch_list_type));
+      MergeInferredTypeAndSetMap(v, v->type(), torch_list_type);
     }
   }
 }
@@ -596,6 +618,16 @@ void UpdateShape(Value* value, const ::c10::SymbolicShape& shape) {
   }
 }
 
+void UpdateShapeConstantValueMap(
+    const Value* value,
+    const ::c10::SymbolicShape& shape) {
+  ConstantValueMap::SetShape(value->debugName(), shape);
+  if (shape.rank().has_value()) {
+    auto rank = shape.rank().value();
+    ConstantValueMap::SetRank(value->debugName(), rank);
+  }
+}
+
 c10::optional<std::vector<int64_t>> GetValueFromListConstructNode(
     Node* lc_node) {
   auto rank = lc_node->inputs().size();
@@ -618,6 +650,198 @@ c10::optional<std::vector<int64_t>> GetValueFromListConstructNode(
       : c10::nullopt;
 }
 
+void ProcessBroadCastNode(Node* n) {
+  TORCH_INTERNAL_ASSERT(n->inputs().size() == 2);
+  if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
+      ConstantValueMap::HasShape(n->input(1)->debugName())) {
+    auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
+    auto input_shape_value_0 = input_shape_0.value().sizes();
+    auto input_shape_1 = ConstantValueMap::GetShape(n->input(1)->debugName());
+    auto input_shape_value_1 = input_shape_1.value().sizes();
+    size_t rank_0 = input_shape_value_0.value().size();
+    size_t rank_1 = input_shape_value_1.value().size();
+    size_t rank_max = std::max(rank_0, rank_1);
+    size_t rank_min = std::min(rank_0, rank_1);
+    std::vector<::c10::ShapeSymbol> final_shape;
+    final_shape.reserve(rank_max);
+    for (auto idx = 0; idx < rank_max; idx++) {
+      final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
+    }
+    for (auto idx = 0; idx < rank_min; idx++) {
+      auto is_static_0 =
+          input_shape_value_0.value()[rank_0 - 1 - idx].is_static();
+      auto is_static_1 =
+          input_shape_value_1.value()[rank_1 - 1 - idx].is_static();
+      if (is_static_0 && is_static_1) {
+        auto static_0_sz =
+            input_shape_value_0.value()[rank_0 - 1 - idx].static_size();
+        auto static_1_sz =
+            input_shape_value_1.value()[rank_1 - 1 - idx].static_size();
+        final_shape[rank_max - 1 - idx] = ::c10::ShapeSymbol::fromStaticSize(
+            std::max(static_0_sz, static_1_sz));
+      }
+    }
+
+    if (rank_0 < rank_1) {
+      for (auto idx = rank_min; idx < rank_max; idx++) {
+        auto shape_idx = rank_max - 1 - idx;
+        final_shape[shape_idx] = input_shape_value_1.value()[shape_idx];
+      }
+    } else {
+      for (auto idx = rank_min; idx < rank_max; idx++) {
+        auto shape_idx = rank_max - 1 - idx;
+        final_shape[shape_idx] = input_shape_value_0.value()[shape_idx];
+      }
+    }
+
+    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+  }
+}
+
+void ProcessConcatNode(Node* n) {
+  int axis = n->i(attr::axis);
+  if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
+    auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
+    size_t axis_adjust = 0;
+    if (axis >= 0) {
+      axis_adjust = static_cast<size_t>(axis);
+    } else {
+      axis_adjust = static_cast<size_t>(axis + static_cast<int>(rank));
+    }
+    std::vector<::c10::ShapeSymbol> final_shape;
+    final_shape.reserve(rank);
+    for (auto idx = 0; idx < rank; idx++) {
+      if (idx == axis_adjust) {
+        auto flag = true;
+        int64_t size_total = 0;
+        for (auto input_idx = 0; input_idx < n->inputs().size(); input_idx++) {
+          if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
+            auto input_shape =
+                ConstantValueMap::GetShape(n->input(input_idx)->debugName());
+            auto input_shape_value = input_shape.value().sizes();
+            auto shape_symbol = input_shape_value.value()[idx];
+            if (shape_symbol.is_static()) {
+              size_total += shape_symbol.static_size();
+            } else {
+              flag = false;
+              break;
+            }
+          }
+        }
+        if (flag) {
+          final_shape.emplace_back(
+              ::c10::ShapeSymbol::fromStaticSize(size_total));
+        } else {
+          final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
+        }
+      } else {
+        auto flag = false;
+        for (auto input_idx = 0; input_idx < n->inputs().size(); input_idx++) {
+          if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
+            auto input_shape =
+                ConstantValueMap::GetShape(n->input(input_idx)->debugName());
+            auto input_shape_value = input_shape.value().sizes();
+            auto shape_symbol = input_shape_value.value()[idx];
+            if (shape_symbol.is_static()) {
+              final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(
+                  shape_symbol.static_size()));
+              flag = true;
+              break;
+            }
+          }
+        }
+        if (!flag) {
+          final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
+        }
+      }
+    }
+    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+  }
+}
+
+void ProcessMatMulNode(Node* n) {
+  if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
+      ConstantValueMap::HasShape(n->input(1)->debugName())) {
+    auto input_shape_0 =
+        ConstantValueMap::GetShape(n->input(0)->debugName()).value();
+    auto input_shape_value_0 = input_shape_0.sizes().value();
+    auto input_shape_1 =
+        ConstantValueMap::GetShape(n->input(1)->debugName()).value();
+    auto input_shape_value_1 = input_shape_1.sizes().value();
+    size_t rank_0 = input_shape_value_0.size();
+    size_t rank_1 = input_shape_value_1.size();
+    auto is_rank_0_1 = false;
+    if (rank_0 == 1) {
+      input_shape_value_0.insert(
+          input_shape_value_0.begin(), ::c10::ShapeSymbol::fromStaticSize(1));
+      rank_0 = 2;
+      is_rank_0_1 = true;
+    }
+    auto is_rank_1_1 = false;
+    if (rank_1 == 1) {
+      input_shape_value_1.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
+      rank_1 = 2;
+      is_rank_1_1 = true;
+    }
+    size_t rank = std::max(rank_0, rank_1);
+    std::vector<::c10::ShapeSymbol> final_shape;
+    final_shape.reserve(rank);
+    if (rank_0 >= rank_1) {
+      for (auto idx = 0; idx < rank_0 - 2; idx++) {
+        final_shape.emplace_back(input_shape_value_0[idx]);
+      }
+    } else {
+      for (auto idx = 0; idx < rank_1 - 2; idx++) {
+        final_shape.emplace_back(input_shape_value_1[idx]);
+      }
+    }
+    final_shape.emplace_back(input_shape_value_0[rank_0 - 2]);
+    final_shape.emplace_back(input_shape_value_1[rank_1 - 1]);
+    if (is_rank_0_1) {
+      final_shape.erase(final_shape.begin());
+    }
+    if (is_rank_1_1) {
+      final_shape.pop_back();
+    }
+    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+  }
+}
+
+void ProcessReduceNode(Node* n) {
+  if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
+    auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
+    auto input_shape_value_0 = input_shape_0.value().sizes();
+    size_t rank_0 = input_shape_value_0.value().size();
+    std::vector<::c10::ShapeSymbol> final_shape;
+    if (!n->hasAttributeS("axes")) {
+      UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+      return;
+    }
+    final_shape.reserve(rank_0);
+    std::vector<int64_t> axes_vector = n->is(attr::axes);
+    for (auto idx = 0; idx < axes_vector.size(); idx++) {
+      if (axes_vector[idx] < 0) {
+        axes_vector[idx] += rank_0;
+      }
+    }
+    int64_t keepdims = 0;
+    if (n->hasAttributeS("keepdims")) {
+      keepdims = n->i(attr::keepdims);
+    }
+    for (auto idx = 0; idx < rank_0; idx++) {
+      auto it = std::find(axes_vector.begin(), axes_vector.end(), idx);
+      if (it != axes_vector.end()) {
+        if (keepdims != 0) {
+          final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
+        }
+      } else {
+        final_shape.emplace_back(input_shape_value_0.value()[idx]);
+      }
+    }
+    UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+  }
+}
+
 void ProcessReshapeNode(Node* n, int opset_version) {
   if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
     auto shape_temp =
@@ -625,9 +849,13 @@ void ProcessReshapeNode(Node* n, int opset_version) {
     auto shape_vector_0 =
         ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
             n->input(0)->debugName());
+    std::vector<int64_t> shape_vector_0_value(0);
     if (shape_vector_0.has_value()) {
+      shape_vector_0_value = shape_vector_0.value();
+    }
+    if (shape_vector_0.has_value() || shape_temp.size() > 0) {
       auto final_shape = ComputeShapeFromReshape(
-          n, shape_vector_0.value(), shape_temp, opset_version);
+          n, shape_vector_0_value, shape_temp, opset_version);
       UpdateShapeFromVector(n->output(), final_shape);
       return;
     }
@@ -786,6 +1014,14 @@ void ProcessSliceNode(Node* n, int opset_version) {
   }
 }
 
+void ProcessUnchangeNode(Node* n) {
+  if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
+    auto shape_size_0 =
+        ConstantValueMap::GetShape(n->input(0)->debugName()).value();
+    UpdateShape(n->output(), shape_size_0);
+  }
+}
+
 void ProcessTimeSeriesNode(Node* n) {
   auto input0_shape = ConstantValueMap::GetShape(n->input(0)->debugName());
   auto input1_shape = ConstantValueMap::GetShape(n->input(1)->debugName());
@@ -870,6 +1106,20 @@ void ComputeConstant(Node* n, int opset_version) {
   }
 
   switch (n->kind()) {
+    case ::c10::onnx::Add:
+    case ::c10::onnx::Div:
+    case ::c10::onnx::Equal:
+    case ::c10::onnx::Greater:
+    case ::c10::onnx::GreaterOrEqual:
+    case ::c10::onnx::Less:
+    case ::c10::onnx::LessOrEqual:
+    case ::c10::onnx::Mod:
+    case ::c10::onnx::Mul:
+    case ::c10::onnx::Pow:
+    case ::c10::onnx::Sub: {
+      ProcessBroadCastNode(n);
+      break;
+    }
     case ::c10::onnx::Shape: {
       auto input_shape =
           ConstantValueMap::GetShapeInto1DInt64Vector(n->input()->debugName());
@@ -885,6 +1135,10 @@ void ComputeConstant(Node* n, int opset_version) {
         at::Tensor f_copy = at::empty({shape_value_size}, options);
         f_copy.copy_(f);
         ConstantValueMap::SetValue(n->output()->debugName(), f_copy);
+        std::vector<::c10::ShapeSymbol> final_shape_vector(
+            1, c10::ShapeSymbol::fromStaticSize(shape_value_size));
+        ::c10::SymbolicShape final_shape(final_shape_vector);
+        UpdateShape(n->output(), final_shape);
       }
       break;
     }
@@ -948,6 +1202,10 @@ void ComputeConstant(Node* n, int opset_version) {
       }
       break;
     }
+    case ::c10::onnx::Concat: {
+      ProcessConcatNode(n);
+      break;
+    }
     case ::c10::onnx::ConstantOfShape: {
       if (ConstantValueMap::HasValue(n->input()->debugName())) {
         auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
@@ -1025,6 +1283,15 @@ void ComputeConstant(Node* n, int opset_version) {
       }
       break;
     }
+    case ::c10::onnx::MatMul: {
+      ProcessMatMulNode(n);
+      break;
+    }
+    case ::c10::onnx::ReduceMean:
+    case ::c10::onnx::ReduceProd: {
+      ProcessReduceNode(n);
+      break;
+    }
     case ::c10::onnx::RNN:
     case ::c10::onnx::LSTM:
     case ::c10::onnx::GRU: {
@@ -1061,6 +1328,12 @@ void ComputeConstant(Node* n, int opset_version) {
       ProcessSliceNode(n, opset_version);
       break;
     }
+    case ::c10::onnx::Cast:
+    case ::c10::onnx::Relu:
+    case ::c10::onnx::Softmax: {
+      ProcessUnchangeNode(n);
+      break;
+    }
     case ::c10::onnx::Tile: {
       if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
         auto input0_shape_size =
@@ -1108,7 +1381,20 @@ bool IsListConstructIntType(const Value* v) {
 
 bool AllGraphInputsStatic(const Graph* g) {
   for (auto n : g->inputs()) {
-    if (!n->isCompleteTensor()) {
+    if (TensorTypePtr input_type = n->type()->cast<TensorType>()) {
+      if (input_type->dim()) {
+        auto shape = input_type->symbolic_sizes();
+        if (!ConstantValueMap::HasShape(n->debugName())) {
+          UpdateShapeConstantValueMap(n, shape);
+        }
+      }
+    }
+  }
+  for (auto n : g->inputs()) {
+    // Some inputs can be non-Tensor type, e.g.,
+    // __torch__.torch.classes.quantized.LinearPackedParamsBase
+    // so we only need check Tensor type here.
+    if (n->type()->cast<TensorType>() && !n->isCompleteTensor()) {
       return false;
     }
   }
@@ -1410,10 +1696,108 @@ void ONNXShapeTypeInference(
 
 } // namespace
 
+// For some operators, there are some inputs not related to shape inference.
+// For example, LSTM input 4 (sequence_lens) is optional,
+// and the shape inference can be done through other required inputs.
+// When we compute reliable, we don't need this input be reliable.
+static std::unordered_map<std::string, std::unordered_set<int64_t>>
+    non_required_shape_inference_idx_map = {{"onnx::LSTM", {4}}};
+
+std::pair<bool, bool> AreInputsReliableOrStatic(Node* n) {
+  auto reliable = true;
+  auto complete = true;
+  auto input_size = n->inputs().size();
+  std::unordered_set<int64_t> non_required_idx = {};
+  if (non_required_shape_inference_idx_map.find(n->kind().toDisplayString()) !=
+      non_required_shape_inference_idx_map.end()) {
+    non_required_idx =
+        non_required_shape_inference_idx_map[n->kind().toDisplayString()];
+  }
+  for (auto idx = 0; idx < input_size; idx++) {
+    if (!non_required_idx.empty() &&
+        non_required_idx.find(idx) != non_required_idx.end()) {
+      continue;
+    }
+    auto input = n->inputs()[idx];
+    reliable &=
+        ConstantValueMap::GetTypeReliable(input->debugName()).value_or(false);
+    if (auto pt = input->type()->cast<TensorType>()) {
+      if (!pt->sizes().isComplete()) {
+        complete = false;
+      }
+    }
+  }
+  return std::make_pair(reliable, complete);
+}
+
+// There is no need to put onnx type here, but we need this
+// for some legacy tests when onnx_shape_inference=False.
+static std::unordered_set<std::string> nodeTypeReliableForTracer = {
+    "prim::ListConstruct",
+    "onnx::Cast",
+    "onnx::Constant",
+    "onnx::Relu",
+    "com.microsoft::Gelu"};
+
+void UpdateReliable(
+    torch::jit::Value* output,
+    const std::pair<bool, bool>& inferred_type_reliable) {
+  auto inferred =
+      ConstantValueMap::GetUseInferredType(output->debugName()).value_or(false);
+  auto isTypeReliableForTracer =
+      nodeTypeReliableForTracer.find(
+          output->node()->kind().toDisplayString()) !=
+      nodeTypeReliableForTracer.end();
+  if (!inferred && !isTypeReliableForTracer &&
+      !output->node()->kind().is_onnx()) {
+    std::cerr
+        << "WARNING: The shape inference of "
+        << output->node()->kind().toDisplayString()
+        << " type is missing, so it may result in wrong shape inference for the exported graph. "
+        << "Please consider adding it in symbolic function." << std::endl;
+  }
+  auto reliable = false;
+  if (inferred) {
+    reliable = inferred_type_reliable.first;
+  } else {
+    if (inferred_type_reliable.second && isTypeReliableForTracer) {
+      reliable = true;
+    }
+  }
+  // Assume that the tracer can estimate rank correctly,
+  // then the output tensor of Shape should always be reliable.
+  if (output->node()->kind() == ::c10::onnx::Shape) {
+    reliable = true;
+  }
+  ConstantValueMap::SetTypeReliable(output->debugName(), reliable);
+  if (!reliable) {
+    if (auto output_tensor_type = output->type()->cast<TensorType>()) {
+      output->setType(output_tensor_type->withSymbolicShapes(
+          ::c10::SymbolicShape(output_tensor_type->dim())));
+    }
+  }
+}
+
+void UpdateReliable(Node* n) {
+  auto input_reliable = AreInputsReliableOrStatic(n);
+  for (auto output : n->outputs()) {
+    UpdateReliable(output, input_reliable);
+  }
+}
+
+void SetGraphInputTypeReliable(const Graph* g) {
+  for (auto graph_input : g->inputs()) {
+    if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) {
+      ConstantValueMap::SetTypeReliable(graph_input->debugName(), true);
+    }
+  }
+}
+
 void ONNXShapeTypeInference(
     Node* n,
     const ParamMap& params_dict,
     int opset_version) {
+  SetGraphInputTypeReliable(n->owningGraph());
   GRAPH_UPDATE(
       "Running ONNX shape inference for node: ", n->kind().toDisplayString());
   if (IsValidONNXNode(n)) {
@@ -1476,7 +1860,36 @@ void ONNXShapeTypeInference(
   SpecialPostProcess(n);
   if (IsValidONNXNode(n)) {
     ProcessConstantValueMap(n, opset_version);
+    if (n->kind() != prim::ListConstruct) {
+      for (auto input : n->inputs()) {
+        if (input->node()->kind() == prim::ListConstruct) {
+          UpdateReliable(input, AreInputsReliableOrStatic(input->node()));
+        }
+      }
+    }
   }
+  UpdateReliable(n);
+
+  // For the node type that does nott have ComputeConstant logic, it may have
+  // reliable shape but its shape is not in ConstantValueMap. So we need this
+  // logic to update ConstantValueMap.
+  for (auto node_output : n->outputs()) {
+    if (ConstantValueMap::HasTypeReliable(node_output->debugName())) {
+      auto reliable =
+          ConstantValueMap::GetTypeReliable(node_output->debugName())
+              .value_or(false);
+      if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) {
+        // TODO: ListType case
+        if (auto output_tensor_type = node_output->type()->cast<TensorType>()) {
+          if (output_tensor_type->dim()) {
+            auto symbolic_sizes = output_tensor_type->symbolic_sizes();
+            UpdateShapeConstantValueMap(node_output, symbolic_sizes);
+          }
+        }
+      }
+    }
+  }
+
   GRAPH_DEBUG(
       "Torch graph after shape inference:", n->owningGraph()->toString());
 }
@@ -1549,8 +1962,8 @@ void ONNXUpdateTypeFromTensor(
     const at::Tensor& output,
     bool onnx_shape_inference) {
   if (onnx_shape_inference) {
-    graph_output->setType(
-        MergeInferredType(TensorType::create(output), graph_output->type()));
+    MergeInferredTypeAndSetMap(
+        graph_output, TensorType::create(output), graph_output->type());
   } else {
     graph_output->inferTypeFrom(output);
   }
@@ -1615,11 +2028,9 @@ size_t ONNXAssignOutputShape(
                              ->getElementType()
                              ->cast<TensorType>();
         elem_type = elem_type->withScalarType(var.scalar_type());
-        graph->outputs()
-            .at(outputs_index)
-            ->setType(MergeInferredType(
-                graph->outputs().at(outputs_index)->type(),
-                ListType::create(elem_type)));
+        auto graph_output = graph->outputs().at(outputs_index);
+        MergeInferredTypeAndSetMap(
+            graph_output, graph_output->type(), ListType::create(elem_type));
       } else {
         graph->outputs()
             .at(outputs_index)
@@ -1696,6 +2107,7 @@ void ONNXShapeTypeInference(
     const ParamMap& params_dict,
     int opset_version) {
   ConstantValueMap::ClearMaps();
+  SetGraphInputTypeReliable(graph.get());
   ONNXShapeTypeInference(graph->block(), params_dict, opset_version);
 }
 
index 69fbff1..f4347ca 100644 (file)
@@ -7,8 +7,11 @@
 namespace torch {
 namespace jit {
 
-TORCH_API TypePtr
-MergeInferredType(TypePtr existing_type, TypePtr inferred_type);
+void MergeInferredTypeAndSetMap(
+    Value* dest_v,
+    TypePtr existing_type,
+    TypePtr inferred_type,
+    bool set_constant_value_map = true);
 
 // Update graph input types with dynamic axes info.
 // Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol.
@@ -49,5 +52,10 @@ TORCH_API void ONNXShapeTypeInference(
     const ParamMap& params_dict,
     int opset_version);
 
+std::pair<bool, bool> AreInputsReliableOrStatic(Node* n);
+void UpdateReliable(
+    torch::jit::Value* output,
+    const std::pair<bool, bool>& input_reliable);
+
 } // namespace jit
 } // namespace torch
index b726b2b..01c98b5 100644 (file)
@@ -379,3 +379,8 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
     """
     from torch.onnx import utils
     utils.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)
+
+
+def unregister_custom_op_symbolic(symbolic_name, opset_version):
+    from torch.onnx import utils
+    utils.unregister_custom_op_symbolic(symbolic_name, opset_version)
index 53440f1..09b8ae5 100644 (file)
@@ -432,18 +432,23 @@ def unbind(g, self, dim=0, _outputs=None):
 
 # Generate paddings in ONNX order based on pad in pytorch.
 # Args:
-#     dim: the dimension of the tensor.
+#     input: the input tensor.
 #     pad: the paddings in pytorch.
 #          The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
 #          where m is in range [0, n].
-def _prepare_onnx_paddings(g, dim, pad):
+def _prepare_onnx_paddings(g, input, pad):
     # The desired order of paddings is
     # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
     # n is the dimension of input.
     # Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
     pad_len = torch.onnx.symbolic_opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
     # Set extension = [0] * (dim * 2 - len(pad))
-    extension = g.op("Sub", g.op("Mul", g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64)),
+    rank = sym_help._get_tensor_rank(input)
+    if rank is None:
+        rank = g.op("Size", g.op("Shape", input))
+    else:
+        rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
+    extension = g.op("Sub", g.op("Mul", rank,
                      g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), pad_len)
     # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
     # Currently ONNX only supports int64 type for Pad
@@ -464,19 +469,19 @@ def constant_pad_nd(g, input, padding, value=None):
     mode = "constant"
     value = sym_help._maybe_get_scalar(value)
     value = sym_help._if_scalar_type_as(g, value, input)
-    pad = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
+    pad = _prepare_onnx_paddings(g, input, padding)
     return g.op("Pad", input, pad, value, mode_s=mode)
 
 
 def reflection_pad(g, input, padding):
     mode = "reflect"
-    paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
+    paddings = _prepare_onnx_paddings(g, input, padding)
     return g.op("Pad", input, paddings, mode_s=mode)
 
 
 def replication_pad(g, input, padding):
     mode = "edge"
-    paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
+    paddings = _prepare_onnx_paddings(g, input, padding)
     return g.op("Pad", input, paddings, mode_s=mode)
 
 
index fd3f6af..ebd379f 100644 (file)
@@ -93,6 +93,15 @@ def is_registered_op(opname, domain, version):
     global _registry
     return (domain, version) in _registry and opname in _registry[(domain, version)]
 
+def unregister_op(opname, domain, version):
+    global _registry
+    if is_registered_op(opname, domain, version):
+        del _registry[(domain, version)][opname]
+        if not _registry[(domain, version)]:
+            del _registry[(domain, version)]
+    else:
+        warnings.warn("The opname " + opname + " is not registered.")
+
 def get_op_supported_version(opname, domain, version):
     iter_version = version
     while iter_version <= _onnx_main_opset:
index 0b447a9..8ee0eca 100644 (file)
@@ -198,8 +198,6 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
     torch._C._jit_pass_onnx_scalar_type_analysis(graph, True, _export_onnx_opset_version)
     torch._C._jit_pass_lint(graph)
 
-    torch._C._jit_pass_onnx_fold_if(graph)
-
     torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version, fixed_batch_size)
     torch._C._jit_pass_lint(graph)
 
@@ -530,7 +528,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
                             export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
                             google_printer=False, opset_version=None, _retain_param_name=True,
                             keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True,
-                            do_constant_folding=True):
+                            do_constant_folding=True, dynamic_axes=None):
     return _export_to_pretty_string(model, args, f, export_params, verbose, training,
                                     input_names, output_names, operator_export_type,
                                     export_type, example_outputs, google_printer,
@@ -538,7 +536,7 @@ def export_to_pretty_string(model, args, f, export_params=True, verbose=False, t
                                     do_constant_folding=do_constant_folding,
                                     add_node_names=add_node_names,
                                     keep_initializers_as_inputs=keep_initializers_as_inputs,
-                                    custom_opsets=custom_opsets)
+                                    custom_opsets=custom_opsets, dynamic_axes=dynamic_axes)
 
 
 def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
@@ -547,7 +545,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
                              google_printer=False, opset_version=None, _retain_param_name=False,
                              do_constant_folding=True, keep_initializers_as_inputs=None,
                              fixed_batch_size=False, custom_opsets=None, add_node_names=True,
-                             onnx_shape_inference=True):
+                             onnx_shape_inference=True, dynamic_axes=None):
     from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
     from torch.onnx.symbolic_helper import _set_operator_export_type
     if opset_version is None:
@@ -569,7 +567,7 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False,
                                                         output_names, operator_export_type,
                                                         example_outputs, _retain_param_name,
                                                         val_do_constant_folding, fixed_batch_size=fixed_batch_size,
-                                                        training=training)
+                                                        training=training, dynamic_axes=dynamic_axes)
 
         return graph._pretty_print_onnx(params_dict, opset_version, False,
                                         operator_export_type, google_printer,
@@ -1187,7 +1185,7 @@ def _node_getitem(self, k):
     return getattr(self, sel)(k)
 
 
-def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+def get_ns_op_name_from_custom_op(symbolic_name):
     if not bool(re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)):
         raise RuntimeError("Failed to register operator {}. \
                            The symbolic name must match the format Domain::Name, \
@@ -1199,6 +1197,15 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
     if ns in unaccepted_domain_names:
         raise RuntimeError("Failed to register operator {}. The domain {} is already a used domain."
                            .format(symbolic_name, ns))
+    return ns, op_name
+
+
+# When the user registers symbolic for custom/contrib ops,
+# it is highly recommended to add shape inference for that operator via setType API,
+# otherwise the exported graph may have incorrect shape inference in some extreme cases.
+# An example of setType is test_aten_embedding_2 in test_operators.py..
+def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+    ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
     import torch.onnx.symbolic_registry as sym_registry
     from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset
 
@@ -1206,6 +1213,17 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
         if version >= opset_version:
             sym_registry.register_op(op_name, symbolic_fn, ns, version)
 
+
+def unregister_custom_op_symbolic(symbolic_name, opset_version):
+    ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
+    import torch.onnx.symbolic_registry as sym_registry
+    from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset
+
+    for version in _onnx_stable_opsets + [_onnx_main_opset]:
+        if version >= opset_version:
+            sym_registry.unregister_op(op_name, ns, version)
+
+
 # This helper function ensures dynamic axes argument is following the expected format
 def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
     if len(dynamic_axes) == 0: