_(onnx, Gather) \
_(onnx, Gemm) \
_(onnx, LSTM) \
+ _(onnx, MatMul) \
_(onnx, Mul) \
_(onnx, Pow) \
_(onnx, RNN) \
_(onnx, Shape) \
_(onnx, Size) \
_(onnx, Slice) \
+ _(onnx, Softmax) \
_(onnx, Squeeze) \
_(onnx, Sub) \
_(onnx, Transpose) \
_(onnx, ReduceL2) \
_(onnx, Conv) \
_(onnx, BatchNormalization) \
+ _(onnx, ReduceMean) \
_(onnx, ReduceProd) \
+ _(onnx, Relu) \
_(onnx, Neg) \
_(onnx, NonZero) \
_(onnx, Range) \
In order to differentiate between different ``Function`` subclasses, the
symbolic function should use the ``name`` kwarg which gets set to the name of the class.
-:func:`register_custom_op_symbolic` does does not allow registration for ops in
+:func:`register_custom_op_symbolic` does not allow registration for ops in
the ``prim`` namespace, so for this use case, there's a back door: register the
symbolic for ``"::prim_PythonOp"``.
+Please also consider adding shape inference logic when you regiester a custom symbolic function
+via setType API. This can help the exporter to obtain correct shape inference.
+An example of setType is test_aten_embedding_2 in test_operators.py.
+Although it is not required to add shape inference logic,
+the exporter emits a warning message if it is not added.
+
The example below shows how you can access ``requires_grad`` via the ``Node`` object::
class MyClip(torch.autograd.Function):
--- /dev/null
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+ node {
+ output: "3"
+ name: "Constant_0"
+ op_type: "Constant"
+ attribute {
+ name: "value"
+ t {
+ dims: 32
+ data_type: 1
+ raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
+ }
+ type: TENSOR
+ }
+ }
+ name: "torch-jit-export"
+ output {
+ name: "3"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 32
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 12
+}
--- /dev/null
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+ node {
+ input: "emb.weight"
+ input: "input_1"
+ output: "3"
+ name: "ATenOp_0"
+ op_type: "ATenOp"
+ attribute {
+ name: "custom_attributes_json"
+ s: "{\"padding_idx\":-1,\"scale_grad_by_freq\":false,\"sparse\":false}"
+ type: STRING
+ }
+ attribute {
+ name: "name"
+ s: "aten::embedding"
+ type: STRING
+ }
+ domain: "com.microsoft"
+ }
+ node {
+ input: "3"
+ input: "input_2"
+ output: "4"
+ name: "Add_1"
+ op_type: "Add"
+ }
+ node {
+ input: "4"
+ output: "5"
+ name: "Shape_2"
+ op_type: "Shape"
+ }
+ node {
+ output: "6"
+ name: "Constant_3"
+ op_type: "Constant"
+ attribute {
+ name: "value"
+ t {
+ data_type: 7
+ raw_data: "\000\000\000\000\000\000\000\000"
+ }
+ type: TENSOR
+ }
+ }
+ node {
+ input: "5"
+ input: "6"
+ output: "7"
+ name: "Gather_4"
+ op_type: "Gather"
+ attribute {
+ name: "axis"
+ i: 0
+ type: INT
+ }
+ }
+ node {
+ input: "7"
+ output: "8"
+ name: "Unsqueeze_5"
+ op_type: "Unsqueeze"
+ attribute {
+ name: "axes"
+ ints: 0
+ type: INTS
+ }
+ }
+ node {
+ input: "8"
+ output: "9"
+ name: "Concat_6"
+ op_type: "Concat"
+ attribute {
+ name: "axis"
+ i: 0
+ type: INT
+ }
+ }
+ node {
+ input: "9"
+ output: "10"
+ name: "ConstantOfShape_7"
+ op_type: "ConstantOfShape"
+ attribute {
+ name: "value"
+ t {
+ dims: 1
+ data_type: 1
+ raw_data: "\000\000\200?"
+ }
+ type: TENSOR
+ }
+ }
+ name: "torch-jit-export"
+ initializer {
+ dims: 4
+ dims: 8
+ data_type: 1
+ name: "emb.weight"
+ raw_data: "\264\314\344\275\017A\376\276\313\374&>J\266a\277s\306\\=\212\032+?\211[t\275\344[\357\276Dk\\\276OKb?\234\'B\277A\334\274\2767N\257\276\320s\263\277\371+\244>:\314\202\277K\200L??\001\275\275\236u4\2774\032\315\277\214\004\224>Z\320\372>\267B\305\276\346G6\277N\265.\276\343\316\272\277t\364a>\201)|>p\223\251\277Qm2?\346\275)\277\354\235\233?"
+ }
+ input {
+ name: "input_1"
+ type {
+ tensor_type {
+ elem_type: 7
+ shape {
+ dim {
+ dim_param: "input_1_dim_0"
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "input_2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_param: "input_2_dim_0"
+ }
+ dim {
+ dim_param: "input_2_dim_1"
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "10"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_param: "ConstantOfShape10_dim_0"
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 12
+}
+opset_import {
+ domain: "com.microsoft"
+ version: 1
+}
elem_type: 1
shape {
dim {
- dim_value: 0
+ dim_param: "GenerateProposals4_dim_0"
}
dim {
- dim_value: 5
+ dim_param: "GenerateProposals4_dim_1"
}
}
}
elem_type: 1
shape {
dim {
- dim_value: 0
+ dim_param: "GenerateProposals5_dim_0"
}
}
}
--- /dev/null
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+ node {
+ input: "input_1"
+ input: "input_2"
+ output: "2"
+ name: "Add_0"
+ op_type: "Add"
+ }
+ name: "torch-jit-export"
+ input {
+ name: "input_1"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_1_dim_1"
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "input_2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_2_dim_1"
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "Add2_dim_1"
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 12
+}
--- /dev/null
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+ node {
+ input: "input_1"
+ input: "input_2"
+ output: "2"
+ name: "MatMul_0"
+ op_type: "MatMul"
+ }
+ name: "torch-jit-export"
+ input {
+ name: "input_1"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_1_dim_1"
+ }
+ dim {
+ dim_value: 4
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "input_2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 4
+ }
+ dim {
+ dim_param: "input_2_dim_2"
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "2"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_1_dim_1"
+ }
+ dim {
+ dim_param: "input_2_dim_2"
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 12
+}
--- /dev/null
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+ node {
+ input: "input"
+ output: "1"
+ name: "ReduceMean_0"
+ op_type: "ReduceMean"
+ attribute {
+ name: "axes"
+ ints: 1
+ type: INTS
+ }
+ attribute {
+ name: "keepdims"
+ i: 0
+ type: INT
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_dim_1"
+ }
+ dim {
+ dim_param: "input_dim_2"
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "1"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_dim_2"
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 12
+}
--- /dev/null
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+ node {
+ input: "input"
+ output: "1"
+ name: "Transpose_0"
+ op_type: "Transpose"
+ attribute {
+ name: "perm"
+ ints: 1
+ ints: 0
+ type: INTS
+ }
+ }
+ node {
+ input: "1"
+ output: "2"
+ name: "Softmax_1"
+ op_type: "Softmax"
+ attribute {
+ name: "axis"
+ i: 1
+ type: INT
+ }
+ }
+ node {
+ input: "2"
+ output: "3"
+ name: "Transpose_2"
+ op_type: "Transpose"
+ attribute {
+ name: "perm"
+ ints: 1
+ ints: 0
+ type: INTS
+ }
+ }
+ name: "torch-jit-export"
+ input {
+ name: "input"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_dim_1"
+ }
+ }
+ }
+ }
+ }
+ output {
+ name: "3"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_param: "input_dim_1"
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 12
+}
--- /dev/null
+ir_version: 6
+producer_name: "pytorch"
+producer_version: "CURRENT_VERSION"
+graph {
+ node {
+ output: "7"
+ name: "Constant_0"
+ op_type: "Constant"
+ attribute {
+ name: "value"
+ t {
+ dims: 1
+ dims: 2
+ dims: 3
+ data_type: 1
+ raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?"
+ }
+ type: TENSOR
+ }
+ }
+ name: "torch-jit-export"
+ output {
+ name: "7"
+ type {
+ tensor_type {
+ elem_type: 1
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+}
+opset_import {
+ version: 12
+}
-from test_pytorch_common import TestCase, run_tests, flatten, skipIfNoLapack
+from test_pytorch_common import TestCase, run_tests, flatten, skipIfNoLapack, \
+ BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
import torch
import torch.onnx
+from torch.onnx.symbolic_helper import parse_args, _get_tensor_dim_size, _get_tensor_sizes
+from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic
from torch.autograd import Variable, Function
from torch.nn import Module, functional
import torch.nn as nn
y = torch.empty(3, 2, 1, dtype=torch.long).random_(5)
self.assertONNX(torch.nn.CrossEntropyLoss(), (x, y), opset_version=12)
+ def test_lstm_none_sequence_lens(self):
+ """Test symbolic shape inference for LSTM when the input sequence_lens = None."""
+ input = torch.randn(RNN_SEQUENCE_LENGTH, BATCH_SIZE, RNN_INPUT_SIZE)
+ h0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+ c0 = torch.randn(1, BATCH_SIZE, RNN_HIDDEN_SIZE)
+
+ class LSTMModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.rnn = torch.nn.LSTM(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE, 1, bidirectional=False)
+
+ def forward(self, x, h0, c0):
+ a, b = self.rnn(x, (h0, c0))
+ return torch.ones(b[0].shape)
+
+ self.assertONNX(LSTMModel(),
+ (input, h0, c0), input_names=["x", "y"],
+ dynamic_axes={"x" : {0: 'batch'}}, opset_version=12)
+
+ def test_dynamic_axes_add(self):
+ m1 = torch.randn(2, 3, requires_grad=True)
+ m2 = torch.randn(2, 1, requires_grad=True)
+ self.assertONNX(lambda x, y: torch.add(x, y), (m1, m2), input_names=["input_1", "input_2"],
+ dynamic_axes={"input_1": {1: "dim_1"}, "input_2": {1: "dim_2"}},
+ opset_version=12)
+
+ def test_dynamic_axes_matmul(self):
+ m1 = torch.randn(2, 2, 4, requires_grad=True)
+ m2 = torch.randn(2, 4, 3, requires_grad=True)
+ self.assertONNX(lambda x, y: torch.matmul(x, y), (m1, m2), input_names=["input_1", "input_2"],
+ dynamic_axes={"input_1": {1: "dim_0"}, "input_2": {2: "dim_1"}},
+ opset_version=12)
+
+ def test_dynamic_axes_reduce_mean(self):
+ m1 = torch.randn(2, 3, 4, requires_grad=True)
+ self.assertONNX(lambda x: torch.mean(x, dim=1), (m1), input_names=["input"],
+ dynamic_axes={"input": {1: "dim_1", 2: "dim_2"}},
+ opset_version=12)
+
+ def test_dynamic_axes_unchange(self):
+ """Test ProcessUnchangeNode in symbolic shape inference."""
+ m1 = torch.randn(2, 3, requires_grad=True)
+ self.assertONNX(lambda x: torch.softmax(x, dim=0), (m1,), input_names=["input"],
+ dynamic_axes={"input": {1: "dim_1"}},
+ opset_version=12)
+
+ def test_aten_embedding_1(self):
+ _onnx_opset_version = 12
+
+ @parse_args('v', 'v', 'i', 'b', 'b')
+ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
+ custom_attributes_json = (
+ '{'
+ f'"padding_idx":{str(padding_idx)},'
+ f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
+ f'"sparse":{str(sparse).lower()}'
+ '}'
+ )
+ output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
+ custom_attributes_json_s=custom_attributes_json)
+ return output
+
+ register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version)
+
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.emb = torch.nn.Embedding(4, 8)
+
+ def forward(self, x, y):
+ res = self.emb(x)
+ res = res + y
+ return torch.ones(res.shape[0])
+
+ model = Model()
+ x = torch.ones(32, dtype=torch.long)
+ y = torch.randn(1, 8)
+ self.assertONNX(model, (x, y), opset_version=_onnx_opset_version)
+
+ unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
+
+ # This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
+ def test_aten_embedding_2(self):
+ _onnx_opset_version = 12
+
+ @parse_args('v', 'v', 'i', 'b', 'b')
+ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
+ custom_attributes_json = (
+ '{'
+ f'"padding_idx":{str(padding_idx)},'
+ f'"scale_grad_by_freq":{str(scale_grad_by_freq).lower()},'
+ f'"sparse":{str(sparse).lower()}'
+ '}'
+ )
+ output = g.op("com.microsoft::ATenOp", weight, indices, name_s='aten::embedding',
+ custom_attributes_json_s=custom_attributes_json)
+
+ # do shape inference and set it via setType
+ indices_shape = _get_tensor_sizes(indices)
+ if indices_shape is not None and hasattr(weight.type(), 'with_sizes'):
+ output_type = weight.type().with_sizes(indices_shape + [_get_tensor_dim_size(weight, 1)])
+ output.setType(output_type)
+ return output
+
+ register_custom_op_symbolic('::embedding', embedding, _onnx_opset_version)
+
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.emb = torch.nn.Embedding(4, 8)
+
+ def forward(self, x, y):
+ res = self.emb(x)
+ res = res + y
+ return torch.ones(res.shape[0])
+
+ model = Model()
+ x = torch.ones(32, dtype=torch.long)
+ y = torch.randn(1, 8)
+ self.assertONNX(model, (x, y), opset_version=_onnx_opset_version, input_names=['input_1', 'input_2'],
+ dynamic_axes={"input_1": {0: "dim_0"}, 'input_2': {0: "dim_1", 1: "dim_2"}})
+
+ unregister_custom_op_symbolic('::embedding', _onnx_opset_version)
+
if __name__ == "__main__":
no_onnx_dep_flag = "--no-onnx"
_onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS
do_constant_folding=True, keep_initializers_as_inputs=True,
dynamic_axes=None, input_names=None, output_names=None,
fixed_batch_size=False, training=None,
- onnx_shape_inference=False):
+ onnx_shape_inference=True):
# export the model to ONNX
f = io.BytesIO()
input_copy = copy.deepcopy(input)
random_state = torch.rand((1, 1, 10, 30, 30))
self.run_test(model, (random_data, empty_tensor),
input_names=["data", "state"],
- dynamic_axes={"state": [0, 1, 2, 3, 4]},
+ dynamic_axes={"data": [0, 1, 2], "state": [0, 1, 2, 3, 4]},
test_with_inputs=[(random_data, random_state)])
+ self.run_test(model, (random_data, empty_tensor),
+ input_names=["data", "state"],
+ dynamic_axes={"state": [0, 1, 2, 3, 4]},
+ test_with_inputs=[(random_data, random_state)],
+ remained_onnx_input_idx=[1])
self.run_test(model, (random_data, empty_tensor), remained_onnx_input_idx=[])
@skipIfUnsupportedMinOpsetVersion(11)
return {};
}
+bool ConstantFoldCondition(torch::jit::Value* output) {
+ auto fold_condition = output->node()->kind() != c10::onnx::Constant &&
+ ConstantValueMap::HasValue(output->debugName());
+ auto reliable_value =
+ ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false);
+ return fold_condition && reliable_value;
+}
+
void NodeToONNX(
Node* old_node,
Block* new_block,
//
// If onnx shape inference is turned on, the new outputs will have
// types inferred, and they will be merged with the old types.
- if (outputs[i]->node()->kind() != c10::onnx::Constant &&
- ConstantValueMap::HasValue(outputs[i]->debugName())) {
+ if (ConstantFoldCondition(outputs[i])) {
// Create a const node if the node output value is in
// ConstantValueMap.
auto value =
ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
env[old] = const_node->output();
} else {
- outputs[i]->setType(
- MergeInferredType(old->type(), outputs[i]->type()));
+ // ConstantValueMap has been set in shape inference,
+ // set_constant_value_map = false here to avoid redundancy.
+ MergeInferredTypeAndSetMap(
+ outputs[i], old->type(), outputs[i]->type(), false);
// Copy over source location and scope information to all nodes
// created by the symbolic
} else if (node->kind() == onnx::Equal) {
updated_val = at::eq(inputTensorValues[0], inputTensorValues[1]);
return c10::optional<at::Tensor>(updated_val);
+ } else if (node->kind() == onnx::Greater) {
+ updated_val = at::greater(inputTensorValues[0], inputTensorValues[1]);
+ return c10::optional<at::Tensor>(updated_val);
+ } else if (node->kind() == onnx::Less) {
+ updated_val = at::less(inputTensorValues[0], inputTensorValues[1]);
+ return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Neg) {
updated_val = at::neg(inputTensorValues[0]);
return c10::optional<at::Tensor>(updated_val);
const std::string& tensorName,
size_t rankValue) {
ConstantValueMap::getInstance().rankMap.emplace(tensorName, rankValue);
+ ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, true);
}
bool ConstantValueMap::HasRank(const std::string& tensorName) {
const std::string& tensorName,
const c10::SymbolicShape& shapeValue) {
ConstantValueMap::getInstance().shapeMap.emplace(tensorName, shapeValue);
+ ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, true);
}
bool ConstantValueMap::HasShape(const std::string& tensorName) {
return value_vector;
}
+void ConstantValueMap::SetTypeReliable(
+ const std::string& tensorName,
+ bool value) {
+ ConstantValueMap::getInstance().typeReliableMap.emplace(tensorName, value);
+}
+
+bool ConstantValueMap::HasTypeReliable(const std::string& tensorName) {
+ return ConstantValueMap::getInstance().typeReliableMap.find(tensorName) !=
+ ConstantValueMap::getInstance().typeReliableMap.end();
+}
+
+c10::optional<bool> ConstantValueMap::GetTypeReliable(
+ const std::string& tensorName) {
+ if (!HasTypeReliable(tensorName)) {
+ return c10::nullopt;
+ }
+ return ConstantValueMap::getInstance().typeReliableMap[tensorName];
+}
+
+void ConstantValueMap::SetUseInferredType(
+ const std::string& tensorName,
+ bool value) {
+ ConstantValueMap::getInstance().useInferredTypeMap.emplace(tensorName, value);
+}
+
+bool ConstantValueMap::HasUseInferredType(const std::string& tensorName) {
+ return ConstantValueMap::getInstance().useInferredTypeMap.find(tensorName) !=
+ ConstantValueMap::getInstance().useInferredTypeMap.end();
+}
+
+c10::optional<bool> ConstantValueMap::GetUseInferredType(
+ const std::string& tensorName) {
+ if (!HasUseInferredType(tensorName)) {
+ return c10::nullopt;
+ }
+ return ConstantValueMap::getInstance().useInferredTypeMap[tensorName];
+}
+
void ConstantValueMap::ClearMaps() {
ConstantValueMap::getInstance().rankMap.clear();
ConstantValueMap::getInstance().shapeMap.clear();
ConstantValueMap::getInstance().tensorValueMap.clear();
+ ConstantValueMap::getInstance().typeReliableMap.clear();
+ ConstantValueMap::getInstance().useInferredTypeMap.clear();
}
// For debug only.
for (const auto& x : ConstantValueMap::getInstance().tensorValueMap) {
std::cout << "node " << x.first << ": " << x.second << std::endl;
}
+ std::cout << std::endl;
+ std::cout << "Print TypeReliable Maps:" << std::endl;
+ size_t count = 0;
+ for (const auto& x : ConstantValueMap::getInstance().typeReliableMap) {
+ std::cout << "(node " << x.first << ": " << x.second << "), ";
+ count++;
+ if (count % 10 == 0) {
+ std::cout << std::endl;
+ }
+ }
+ std::cout << std::endl;
+ std::cout << "Print UseInferredType Maps:" << std::endl;
+ count = 0;
+ for (const auto& x : ConstantValueMap::getInstance().useInferredTypeMap) {
+ std::cout << "(node " << x.first << ": " << x.second << "), ";
+ count++;
+ if (count % 10 == 0) {
+ std::cout << std::endl;
+ }
+ }
}
} // namespace jit
static std::vector<int64_t> GetValueInto1DInt64Vector(
const std::string& value_name);
+ static void SetTypeReliable(const std::string& tensorName, bool reliable);
+ static bool HasTypeReliable(const std::string& tensorName);
+ static c10::optional<bool> GetTypeReliable(const std::string& tensorName);
+
+ static void SetUseInferredType(
+ const std::string& tensorName,
+ bool useInferredType);
+ static bool HasUseInferredType(const std::string& tensorName);
+ static c10::optional<bool> GetUseInferredType(const std::string& tensorName);
+
static void PrintMaps();
static void ClearMaps();
~ConstantValueMap() = default;
std::unordered_map<std::string, size_t> rankMap;
std::unordered_map<std::string, c10::SymbolicShape> shapeMap;
std::unordered_map<std::string, at::Tensor> tensorValueMap;
+ // This map indicates whether the current type is reliably estimated or not.
+ std::unordered_map<std::string, bool> typeReliableMap;
+ // This map indicates whether the current type is estimated through inference
+ // or tracer.
+ std::unordered_map<std::string, bool> useInferredTypeMap;
};
} // namespace jit
#include <onnx/shape_inference/implementation.h>
#include <algorithm>
#include <cmath>
+#include <unordered_set>
namespace torch {
namespace jit {
// 3. existing type: Scalar[], inferred type: Tensor
// ONNX represents list of scalars by 1-d Tensor. Return inferred type since
// it is more compatible with ONNX.
-TypePtr MergeInferredType(TypePtr existing_type, TypePtr inferred_type) {
+std::pair<TypePtr, bool> MergeInferredType(
+ TypePtr existing_type,
+ TypePtr inferred_type) {
auto new_list_type = inferred_type->cast<ListType>();
+ auto use_inferred_type = false;
if (new_list_type) {
- return inferred_type;
+ return std::make_pair(inferred_type, true);
}
auto new_tensor_type = inferred_type->cast<TensorType>();
auto old_tensor_type = existing_type->cast<TensorType>();
if (new_tensor_type && old_tensor_type) {
if (!old_tensor_type->device()) {
// device not available means this is an invalid tensor type (most likely
- // an empty one) -> return inferred type directly.
- return new_tensor_type;
+ // an empty one) return inferred type directly.
+ return std::make_pair(new_tensor_type, true);
}
auto type = old_tensor_type;
if (new_tensor_type->dim()) {
type = type->withSymbolicShapes(new_tensor_type->symbolic_sizes());
+ use_inferred_type = true;
}
if (new_tensor_type->scalarType().has_value()) {
type = type->withScalarType(new_tensor_type->scalarType());
+ use_inferred_type = true;
}
- return type;
+ return std::make_pair(type, use_inferred_type);
}
if (old_tensor_type) {
- return existing_type;
+ return std::make_pair(existing_type, false);
}
auto old_list_type = existing_type->cast<ListType>();
if (new_tensor_type && old_list_type) {
if (new_tensor_type->sizes().isComplete()) {
- return inferred_type;
+ return std::make_pair(inferred_type, true);
}
- return existing_type;
+ return std::make_pair(existing_type, false);
}
- return inferred_type;
+ return std::make_pair(inferred_type, true);
+}
+
+void MergeInferredTypeAndSetMap(
+ Value* dest_v,
+ TypePtr existing_type,
+ TypePtr inferred_type,
+ bool set_constant_value_map) {
+ TypePtr mergedType;
+ bool inferred;
+ std::tie(mergedType, inferred) =
+ MergeInferredType(existing_type, inferred_type);
+ dest_v->setType(mergedType);
+ if (set_constant_value_map) {
+ ConstantValueMap::SetUseInferredType(dest_v->debugName(), inferred);
+ }
}
namespace {
// Assign a new Symbol, no need to keep track
// of it because there won't be duplicates.
sym = c10::ShapeSymbol::newSymbol();
+ symbol_map[sym.value()] = "";
}
sizes.emplace_back(sym.value());
}
const auto torch_tensor_type =
TorchTensorTypeFromONNX(p_type.tensor_type(), symbol_map);
if (torch_tensor_type) {
- v->setType(MergeInferredType(v->type(), torch_tensor_type));
+ MergeInferredTypeAndSetMap(v, v->type(), torch_tensor_type);
}
} else if (p_type.has_sequence_type()) {
const auto torch_list_type =
TorchListTypeFromONNX(p_type.sequence_type(), symbol_map);
if (torch_list_type) {
- v->setType(MergeInferredType(v->type(), torch_list_type));
+ MergeInferredTypeAndSetMap(v, v->type(), torch_list_type);
}
}
}
}
}
+void UpdateShapeConstantValueMap(
+ const Value* value,
+ const ::c10::SymbolicShape& shape) {
+ ConstantValueMap::SetShape(value->debugName(), shape);
+ if (shape.rank().has_value()) {
+ auto rank = shape.rank().value();
+ ConstantValueMap::SetRank(value->debugName(), rank);
+ }
+}
+
c10::optional<std::vector<int64_t>> GetValueFromListConstructNode(
Node* lc_node) {
auto rank = lc_node->inputs().size();
: c10::nullopt;
}
+void ProcessBroadCastNode(Node* n) {
+ TORCH_INTERNAL_ASSERT(n->inputs().size() == 2);
+ if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
+ ConstantValueMap::HasShape(n->input(1)->debugName())) {
+ auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
+ auto input_shape_value_0 = input_shape_0.value().sizes();
+ auto input_shape_1 = ConstantValueMap::GetShape(n->input(1)->debugName());
+ auto input_shape_value_1 = input_shape_1.value().sizes();
+ size_t rank_0 = input_shape_value_0.value().size();
+ size_t rank_1 = input_shape_value_1.value().size();
+ size_t rank_max = std::max(rank_0, rank_1);
+ size_t rank_min = std::min(rank_0, rank_1);
+ std::vector<::c10::ShapeSymbol> final_shape;
+ final_shape.reserve(rank_max);
+ for (auto idx = 0; idx < rank_max; idx++) {
+ final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
+ }
+ for (auto idx = 0; idx < rank_min; idx++) {
+ auto is_static_0 =
+ input_shape_value_0.value()[rank_0 - 1 - idx].is_static();
+ auto is_static_1 =
+ input_shape_value_1.value()[rank_1 - 1 - idx].is_static();
+ if (is_static_0 && is_static_1) {
+ auto static_0_sz =
+ input_shape_value_0.value()[rank_0 - 1 - idx].static_size();
+ auto static_1_sz =
+ input_shape_value_1.value()[rank_1 - 1 - idx].static_size();
+ final_shape[rank_max - 1 - idx] = ::c10::ShapeSymbol::fromStaticSize(
+ std::max(static_0_sz, static_1_sz));
+ }
+ }
+
+ if (rank_0 < rank_1) {
+ for (auto idx = rank_min; idx < rank_max; idx++) {
+ auto shape_idx = rank_max - 1 - idx;
+ final_shape[shape_idx] = input_shape_value_1.value()[shape_idx];
+ }
+ } else {
+ for (auto idx = rank_min; idx < rank_max; idx++) {
+ auto shape_idx = rank_max - 1 - idx;
+ final_shape[shape_idx] = input_shape_value_0.value()[shape_idx];
+ }
+ }
+
+ UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+ }
+}
+
+void ProcessConcatNode(Node* n) {
+ int axis = n->i(attr::axis);
+ if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
+ auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
+ size_t axis_adjust = 0;
+ if (axis >= 0) {
+ axis_adjust = static_cast<size_t>(axis);
+ } else {
+ axis_adjust = static_cast<size_t>(axis + static_cast<int>(rank));
+ }
+ std::vector<::c10::ShapeSymbol> final_shape;
+ final_shape.reserve(rank);
+ for (auto idx = 0; idx < rank; idx++) {
+ if (idx == axis_adjust) {
+ auto flag = true;
+ int64_t size_total = 0;
+ for (auto input_idx = 0; input_idx < n->inputs().size(); input_idx++) {
+ if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
+ auto input_shape =
+ ConstantValueMap::GetShape(n->input(input_idx)->debugName());
+ auto input_shape_value = input_shape.value().sizes();
+ auto shape_symbol = input_shape_value.value()[idx];
+ if (shape_symbol.is_static()) {
+ size_total += shape_symbol.static_size();
+ } else {
+ flag = false;
+ break;
+ }
+ }
+ }
+ if (flag) {
+ final_shape.emplace_back(
+ ::c10::ShapeSymbol::fromStaticSize(size_total));
+ } else {
+ final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
+ }
+ } else {
+ auto flag = false;
+ for (auto input_idx = 0; input_idx < n->inputs().size(); input_idx++) {
+ if (ConstantValueMap::HasShape(n->input(input_idx)->debugName())) {
+ auto input_shape =
+ ConstantValueMap::GetShape(n->input(input_idx)->debugName());
+ auto input_shape_value = input_shape.value().sizes();
+ auto shape_symbol = input_shape_value.value()[idx];
+ if (shape_symbol.is_static()) {
+ final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(
+ shape_symbol.static_size()));
+ flag = true;
+ break;
+ }
+ }
+ }
+ if (!flag) {
+ final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
+ }
+ }
+ }
+ UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+ }
+}
+
+void ProcessMatMulNode(Node* n) {
+ if (ConstantValueMap::HasShape(n->input(0)->debugName()) &&
+ ConstantValueMap::HasShape(n->input(1)->debugName())) {
+ auto input_shape_0 =
+ ConstantValueMap::GetShape(n->input(0)->debugName()).value();
+ auto input_shape_value_0 = input_shape_0.sizes().value();
+ auto input_shape_1 =
+ ConstantValueMap::GetShape(n->input(1)->debugName()).value();
+ auto input_shape_value_1 = input_shape_1.sizes().value();
+ size_t rank_0 = input_shape_value_0.size();
+ size_t rank_1 = input_shape_value_1.size();
+ auto is_rank_0_1 = false;
+ if (rank_0 == 1) {
+ input_shape_value_0.insert(
+ input_shape_value_0.begin(), ::c10::ShapeSymbol::fromStaticSize(1));
+ rank_0 = 2;
+ is_rank_0_1 = true;
+ }
+ auto is_rank_1_1 = false;
+ if (rank_1 == 1) {
+ input_shape_value_1.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
+ rank_1 = 2;
+ is_rank_1_1 = true;
+ }
+ size_t rank = std::max(rank_0, rank_1);
+ std::vector<::c10::ShapeSymbol> final_shape;
+ final_shape.reserve(rank);
+ if (rank_0 >= rank_1) {
+ for (auto idx = 0; idx < rank_0 - 2; idx++) {
+ final_shape.emplace_back(input_shape_value_0[idx]);
+ }
+ } else {
+ for (auto idx = 0; idx < rank_1 - 2; idx++) {
+ final_shape.emplace_back(input_shape_value_1[idx]);
+ }
+ }
+ final_shape.emplace_back(input_shape_value_0[rank_0 - 2]);
+ final_shape.emplace_back(input_shape_value_1[rank_1 - 1]);
+ if (is_rank_0_1) {
+ final_shape.erase(final_shape.begin());
+ }
+ if (is_rank_1_1) {
+ final_shape.pop_back();
+ }
+ UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+ }
+}
+
+void ProcessReduceNode(Node* n) {
+ if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
+ auto input_shape_0 = ConstantValueMap::GetShape(n->input(0)->debugName());
+ auto input_shape_value_0 = input_shape_0.value().sizes();
+ size_t rank_0 = input_shape_value_0.value().size();
+ std::vector<::c10::ShapeSymbol> final_shape;
+ if (!n->hasAttributeS("axes")) {
+ UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+ return;
+ }
+ final_shape.reserve(rank_0);
+ std::vector<int64_t> axes_vector = n->is(attr::axes);
+ for (auto idx = 0; idx < axes_vector.size(); idx++) {
+ if (axes_vector[idx] < 0) {
+ axes_vector[idx] += rank_0;
+ }
+ }
+ int64_t keepdims = 0;
+ if (n->hasAttributeS("keepdims")) {
+ keepdims = n->i(attr::keepdims);
+ }
+ for (auto idx = 0; idx < rank_0; idx++) {
+ auto it = std::find(axes_vector.begin(), axes_vector.end(), idx);
+ if (it != axes_vector.end()) {
+ if (keepdims != 0) {
+ final_shape.emplace_back(::c10::ShapeSymbol::fromStaticSize(1));
+ }
+ } else {
+ final_shape.emplace_back(input_shape_value_0.value()[idx]);
+ }
+ }
+ UpdateShape(n->output(0), c10::SymbolicShape(final_shape));
+ }
+}
+
void ProcessReshapeNode(Node* n, int opset_version) {
if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
auto shape_temp =
auto shape_vector_0 =
ConstantValueMap::GetShapeInto1DInt64VectorWithOneUnknown(
n->input(0)->debugName());
+ std::vector<int64_t> shape_vector_0_value(0);
if (shape_vector_0.has_value()) {
+ shape_vector_0_value = shape_vector_0.value();
+ }
+ if (shape_vector_0.has_value() || shape_temp.size() > 0) {
auto final_shape = ComputeShapeFromReshape(
- n, shape_vector_0.value(), shape_temp, opset_version);
+ n, shape_vector_0_value, shape_temp, opset_version);
UpdateShapeFromVector(n->output(), final_shape);
return;
}
}
}
+void ProcessUnchangeNode(Node* n) {
+ if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
+ auto shape_size_0 =
+ ConstantValueMap::GetShape(n->input(0)->debugName()).value();
+ UpdateShape(n->output(), shape_size_0);
+ }
+}
+
void ProcessTimeSeriesNode(Node* n) {
auto input0_shape = ConstantValueMap::GetShape(n->input(0)->debugName());
auto input1_shape = ConstantValueMap::GetShape(n->input(1)->debugName());
}
switch (n->kind()) {
+ case ::c10::onnx::Add:
+ case ::c10::onnx::Div:
+ case ::c10::onnx::Equal:
+ case ::c10::onnx::Greater:
+ case ::c10::onnx::GreaterOrEqual:
+ case ::c10::onnx::Less:
+ case ::c10::onnx::LessOrEqual:
+ case ::c10::onnx::Mod:
+ case ::c10::onnx::Mul:
+ case ::c10::onnx::Pow:
+ case ::c10::onnx::Sub: {
+ ProcessBroadCastNode(n);
+ break;
+ }
case ::c10::onnx::Shape: {
auto input_shape =
ConstantValueMap::GetShapeInto1DInt64Vector(n->input()->debugName());
at::Tensor f_copy = at::empty({shape_value_size}, options);
f_copy.copy_(f);
ConstantValueMap::SetValue(n->output()->debugName(), f_copy);
+ std::vector<::c10::ShapeSymbol> final_shape_vector(
+ 1, c10::ShapeSymbol::fromStaticSize(shape_value_size));
+ ::c10::SymbolicShape final_shape(final_shape_vector);
+ UpdateShape(n->output(), final_shape);
}
break;
}
}
break;
}
+ case ::c10::onnx::Concat: {
+ ProcessConcatNode(n);
+ break;
+ }
case ::c10::onnx::ConstantOfShape: {
if (ConstantValueMap::HasValue(n->input()->debugName())) {
auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
}
break;
}
+ case ::c10::onnx::MatMul: {
+ ProcessMatMulNode(n);
+ break;
+ }
+ case ::c10::onnx::ReduceMean:
+ case ::c10::onnx::ReduceProd: {
+ ProcessReduceNode(n);
+ break;
+ }
case ::c10::onnx::RNN:
case ::c10::onnx::LSTM:
case ::c10::onnx::GRU: {
ProcessSliceNode(n, opset_version);
break;
}
+ case ::c10::onnx::Cast:
+ case ::c10::onnx::Relu:
+ case ::c10::onnx::Softmax: {
+ ProcessUnchangeNode(n);
+ break;
+ }
case ::c10::onnx::Tile: {
if (ConstantValueMap::HasShape(n->input(0)->debugName())) {
auto input0_shape_size =
bool AllGraphInputsStatic(const Graph* g) {
for (auto n : g->inputs()) {
- if (!n->isCompleteTensor()) {
+ if (TensorTypePtr input_type = n->type()->cast<TensorType>()) {
+ if (input_type->dim()) {
+ auto shape = input_type->symbolic_sizes();
+ if (!ConstantValueMap::HasShape(n->debugName())) {
+ UpdateShapeConstantValueMap(n, shape);
+ }
+ }
+ }
+ }
+ for (auto n : g->inputs()) {
+ // Some inputs can be non-Tensor type, e.g.,
+ // __torch__.torch.classes.quantized.LinearPackedParamsBase
+ // so we only need check Tensor type here.
+ if (n->type()->cast<TensorType>() && !n->isCompleteTensor()) {
return false;
}
}
} // namespace
+// For some operators, there are some inputs not related to shape inference.
+// For example, LSTM input 4 (sequence_lens) is optional,
+// and the shape inference can be done through other required inputs.
+// When we compute reliable, we don't need this input be reliable.
+static std::unordered_map<std::string, std::unordered_set<int64_t>>
+ non_required_shape_inference_idx_map = {{"onnx::LSTM", {4}}};
+
+std::pair<bool, bool> AreInputsReliableOrStatic(Node* n) {
+ auto reliable = true;
+ auto complete = true;
+ auto input_size = n->inputs().size();
+ std::unordered_set<int64_t> non_required_idx = {};
+ if (non_required_shape_inference_idx_map.find(n->kind().toDisplayString()) !=
+ non_required_shape_inference_idx_map.end()) {
+ non_required_idx =
+ non_required_shape_inference_idx_map[n->kind().toDisplayString()];
+ }
+ for (auto idx = 0; idx < input_size; idx++) {
+ if (!non_required_idx.empty() &&
+ non_required_idx.find(idx) != non_required_idx.end()) {
+ continue;
+ }
+ auto input = n->inputs()[idx];
+ reliable &=
+ ConstantValueMap::GetTypeReliable(input->debugName()).value_or(false);
+ if (auto pt = input->type()->cast<TensorType>()) {
+ if (!pt->sizes().isComplete()) {
+ complete = false;
+ }
+ }
+ }
+ return std::make_pair(reliable, complete);
+}
+
+// There is no need to put onnx type here, but we need this
+// for some legacy tests when onnx_shape_inference=False.
+static std::unordered_set<std::string> nodeTypeReliableForTracer = {
+ "prim::ListConstruct",
+ "onnx::Cast",
+ "onnx::Constant",
+ "onnx::Relu",
+ "com.microsoft::Gelu"};
+
+void UpdateReliable(
+ torch::jit::Value* output,
+ const std::pair<bool, bool>& inferred_type_reliable) {
+ auto inferred =
+ ConstantValueMap::GetUseInferredType(output->debugName()).value_or(false);
+ auto isTypeReliableForTracer =
+ nodeTypeReliableForTracer.find(
+ output->node()->kind().toDisplayString()) !=
+ nodeTypeReliableForTracer.end();
+ if (!inferred && !isTypeReliableForTracer &&
+ !output->node()->kind().is_onnx()) {
+ std::cerr
+ << "WARNING: The shape inference of "
+ << output->node()->kind().toDisplayString()
+ << " type is missing, so it may result in wrong shape inference for the exported graph. "
+ << "Please consider adding it in symbolic function." << std::endl;
+ }
+ auto reliable = false;
+ if (inferred) {
+ reliable = inferred_type_reliable.first;
+ } else {
+ if (inferred_type_reliable.second && isTypeReliableForTracer) {
+ reliable = true;
+ }
+ }
+ // Assume that the tracer can estimate rank correctly,
+ // then the output tensor of Shape should always be reliable.
+ if (output->node()->kind() == ::c10::onnx::Shape) {
+ reliable = true;
+ }
+ ConstantValueMap::SetTypeReliable(output->debugName(), reliable);
+ if (!reliable) {
+ if (auto output_tensor_type = output->type()->cast<TensorType>()) {
+ output->setType(output_tensor_type->withSymbolicShapes(
+ ::c10::SymbolicShape(output_tensor_type->dim())));
+ }
+ }
+}
+
+void UpdateReliable(Node* n) {
+ auto input_reliable = AreInputsReliableOrStatic(n);
+ for (auto output : n->outputs()) {
+ UpdateReliable(output, input_reliable);
+ }
+}
+
+void SetGraphInputTypeReliable(const Graph* g) {
+ for (auto graph_input : g->inputs()) {
+ if (!ConstantValueMap::HasTypeReliable(graph_input->debugName())) {
+ ConstantValueMap::SetTypeReliable(graph_input->debugName(), true);
+ }
+ }
+}
+
void ONNXShapeTypeInference(
Node* n,
const ParamMap& params_dict,
int opset_version) {
+ SetGraphInputTypeReliable(n->owningGraph());
GRAPH_UPDATE(
"Running ONNX shape inference for node: ", n->kind().toDisplayString());
if (IsValidONNXNode(n)) {
SpecialPostProcess(n);
if (IsValidONNXNode(n)) {
ProcessConstantValueMap(n, opset_version);
+ if (n->kind() != prim::ListConstruct) {
+ for (auto input : n->inputs()) {
+ if (input->node()->kind() == prim::ListConstruct) {
+ UpdateReliable(input, AreInputsReliableOrStatic(input->node()));
+ }
+ }
+ }
}
+ UpdateReliable(n);
+
+ // For the node type that does nott have ComputeConstant logic, it may have
+ // reliable shape but its shape is not in ConstantValueMap. So we need this
+ // logic to update ConstantValueMap.
+ for (auto node_output : n->outputs()) {
+ if (ConstantValueMap::HasTypeReliable(node_output->debugName())) {
+ auto reliable =
+ ConstantValueMap::GetTypeReliable(node_output->debugName())
+ .value_or(false);
+ if (reliable && !ConstantValueMap::HasShape(node_output->debugName())) {
+ // TODO: ListType case
+ if (auto output_tensor_type = node_output->type()->cast<TensorType>()) {
+ if (output_tensor_type->dim()) {
+ auto symbolic_sizes = output_tensor_type->symbolic_sizes();
+ UpdateShapeConstantValueMap(node_output, symbolic_sizes);
+ }
+ }
+ }
+ }
+ }
+
GRAPH_DEBUG(
"Torch graph after shape inference:", n->owningGraph()->toString());
}
const at::Tensor& output,
bool onnx_shape_inference) {
if (onnx_shape_inference) {
- graph_output->setType(
- MergeInferredType(TensorType::create(output), graph_output->type()));
+ MergeInferredTypeAndSetMap(
+ graph_output, TensorType::create(output), graph_output->type());
} else {
graph_output->inferTypeFrom(output);
}
->getElementType()
->cast<TensorType>();
elem_type = elem_type->withScalarType(var.scalar_type());
- graph->outputs()
- .at(outputs_index)
- ->setType(MergeInferredType(
- graph->outputs().at(outputs_index)->type(),
- ListType::create(elem_type)));
+ auto graph_output = graph->outputs().at(outputs_index);
+ MergeInferredTypeAndSetMap(
+ graph_output, graph_output->type(), ListType::create(elem_type));
} else {
graph->outputs()
.at(outputs_index)
const ParamMap& params_dict,
int opset_version) {
ConstantValueMap::ClearMaps();
+ SetGraphInputTypeReliable(graph.get());
ONNXShapeTypeInference(graph->block(), params_dict, opset_version);
}
namespace torch {
namespace jit {
-TORCH_API TypePtr
-MergeInferredType(TypePtr existing_type, TypePtr inferred_type);
+void MergeInferredTypeAndSetMap(
+ Value* dest_v,
+ TypePtr existing_type,
+ TypePtr inferred_type,
+ bool set_constant_value_map = true);
// Update graph input types with dynamic axes info.
// Axes that are marked as dynamic will be assigned as dynamic ShapeSymbol.
const ParamMap& params_dict,
int opset_version);
+std::pair<bool, bool> AreInputsReliableOrStatic(Node* n);
+void UpdateReliable(
+ torch::jit::Value* output,
+ const std::pair<bool, bool>& input_reliable);
+
} // namespace jit
} // namespace torch
"""
from torch.onnx import utils
utils.register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version)
+
+
+def unregister_custom_op_symbolic(symbolic_name, opset_version):
+ from torch.onnx import utils
+ utils.unregister_custom_op_symbolic(symbolic_name, opset_version)
# Generate paddings in ONNX order based on pad in pytorch.
# Args:
-# dim: the dimension of the tensor.
+# input: the input tensor.
# pad: the paddings in pytorch.
# The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
# where m is in range [0, n].
-def _prepare_onnx_paddings(g, dim, pad):
+def _prepare_onnx_paddings(g, input, pad):
# The desired order of paddings is
# dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
# n is the dimension of input.
# Assume zero-dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
pad_len = torch.onnx.symbolic_opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
# Set extension = [0] * (dim * 2 - len(pad))
- extension = g.op("Sub", g.op("Mul", g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64)),
+ rank = sym_help._get_tensor_rank(input)
+ if rank is None:
+ rank = g.op("Size", g.op("Shape", input))
+ else:
+ rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
+ extension = g.op("Sub", g.op("Mul", rank,
g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))), pad_len)
# Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
# Currently ONNX only supports int64 type for Pad
mode = "constant"
value = sym_help._maybe_get_scalar(value)
value = sym_help._if_scalar_type_as(g, value, input)
- pad = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
+ pad = _prepare_onnx_paddings(g, input, padding)
return g.op("Pad", input, pad, value, mode_s=mode)
def reflection_pad(g, input, padding):
mode = "reflect"
- paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
+ paddings = _prepare_onnx_paddings(g, input, padding)
return g.op("Pad", input, paddings, mode_s=mode)
def replication_pad(g, input, padding):
mode = "edge"
- paddings = _prepare_onnx_paddings(g, sym_help._get_tensor_rank(input), padding)
+ paddings = _prepare_onnx_paddings(g, input, padding)
return g.op("Pad", input, paddings, mode_s=mode)
global _registry
return (domain, version) in _registry and opname in _registry[(domain, version)]
+def unregister_op(opname, domain, version):
+ global _registry
+ if is_registered_op(opname, domain, version):
+ del _registry[(domain, version)][opname]
+ if not _registry[(domain, version)]:
+ del _registry[(domain, version)]
+ else:
+ warnings.warn("The opname " + opname + " is not registered.")
+
def get_op_supported_version(opname, domain, version):
iter_version = version
while iter_version <= _onnx_main_opset:
torch._C._jit_pass_onnx_scalar_type_analysis(graph, True, _export_onnx_opset_version)
torch._C._jit_pass_lint(graph)
- torch._C._jit_pass_onnx_fold_if(graph)
-
torch._C._jit_pass_onnx_peephole(graph, _export_onnx_opset_version, fixed_batch_size)
torch._C._jit_pass_lint(graph)
export_type=ExportTypes.PROTOBUF_FILE, example_outputs=None,
google_printer=False, opset_version=None, _retain_param_name=True,
keep_initializers_as_inputs=None, custom_opsets=None, add_node_names=True,
- do_constant_folding=True):
+ do_constant_folding=True, dynamic_axes=None):
return _export_to_pretty_string(model, args, f, export_params, verbose, training,
input_names, output_names, operator_export_type,
export_type, example_outputs, google_printer,
do_constant_folding=do_constant_folding,
add_node_names=add_node_names,
keep_initializers_as_inputs=keep_initializers_as_inputs,
- custom_opsets=custom_opsets)
+ custom_opsets=custom_opsets, dynamic_axes=dynamic_axes)
def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, training=None,
google_printer=False, opset_version=None, _retain_param_name=False,
do_constant_folding=True, keep_initializers_as_inputs=None,
fixed_batch_size=False, custom_opsets=None, add_node_names=True,
- onnx_shape_inference=True):
+ onnx_shape_inference=True, dynamic_axes=None):
from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version
from torch.onnx.symbolic_helper import _set_operator_export_type
if opset_version is None:
output_names, operator_export_type,
example_outputs, _retain_param_name,
val_do_constant_folding, fixed_batch_size=fixed_batch_size,
- training=training)
+ training=training, dynamic_axes=dynamic_axes)
return graph._pretty_print_onnx(params_dict, opset_version, False,
operator_export_type, google_printer,
return getattr(self, sel)(k)
-def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+def get_ns_op_name_from_custom_op(symbolic_name):
if not bool(re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)):
raise RuntimeError("Failed to register operator {}. \
The symbolic name must match the format Domain::Name, \
if ns in unaccepted_domain_names:
raise RuntimeError("Failed to register operator {}. The domain {} is already a used domain."
.format(symbolic_name, ns))
+ return ns, op_name
+
+
+# When the user registers symbolic for custom/contrib ops,
+# it is highly recommended to add shape inference for that operator via setType API,
+# otherwise the exported graph may have incorrect shape inference in some extreme cases.
+# An example of setType is test_aten_embedding_2 in test_operators.py..
+def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
+ ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
import torch.onnx.symbolic_registry as sym_registry
from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset
if version >= opset_version:
sym_registry.register_op(op_name, symbolic_fn, ns, version)
+
+def unregister_custom_op_symbolic(symbolic_name, opset_version):
+ ns, op_name = get_ns_op_name_from_custom_op(symbolic_name)
+ import torch.onnx.symbolic_registry as sym_registry
+ from torch.onnx.symbolic_helper import _onnx_stable_opsets, _onnx_main_opset
+
+ for version in _onnx_stable_opsets + [_onnx_main_opset]:
+ if version >= opset_version:
+ sym_registry.unregister_op(op_name, ns, version)
+
+
# This helper function ensures dynamic axes argument is following the expected format
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):
if len(dynamic_axes) == 0: