From f0f9277c3cda0b022af0dfaf21d043641bea80d2 Mon Sep 17 00:00:00 2001 From: Spandan Tiwari Date: Thu, 20 Dec 2018 12:24:42 -0800 Subject: [PATCH] Fixing ONNX export of logical ops to have correct output datatype (#15185) Summary: Currently PyTorch ONNX exporter exports the logical ops (`lt`, `gt`, `le`, `ge`, `eq`) with output type in corresponding ONNX ops as type `tensor(uint8)`. But ONNX spec allows for only `tensor(bool)`, which is why models that have these ops fail to load properly. This issue is captured in https://github.com/pytorch/pytorch/issues/11339. Part of this issue, relating to the allowed input types, has been fixed in ONNX spec by houseroad. This PR fixes the other part pertaining to output type. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15185 Differential Revision: D13494873 Pulled By: houseroad fbshipit-source-id: 069d2f956a5ae9bf0ac2540a32594a31b01adef8 --- .../TestScript.test_listconstruct_erasure.expect | 5 ++-- ...Script.test_onnx_export_script_module_if.expect | 9 ++++--- ...TestScript.test_onnx_export_speculate-f1.expect | 17 ++++++------ ...TestScript.test_onnx_export_speculate-f2.expect | 21 ++++++++------- test/onnx/expect/TestOperators.test_equal.expect | 12 ++++++++- test/onnx/expect/TestOperators.test_ge.expect | 12 ++++++++- test/onnx/expect/TestOperators.test_gt.expect | 12 ++++++++- test/onnx/expect/TestOperators.test_le.expect | 12 ++++++++- test/onnx/expect/TestOperators.test_lt.expect | 12 ++++++++- torch/onnx/symbolic.py | 31 ++++++++++++++++++++-- 10 files changed, 112 insertions(+), 31 deletions(-) diff --git a/test/expect/TestScript.test_listconstruct_erasure.expect b/test/expect/TestScript.test_listconstruct_erasure.expect index f6b5b1a..0944262 100644 --- a/test/expect/TestScript.test_listconstruct_erasure.expect +++ b/test/expect/TestScript.test_listconstruct_erasure.expect @@ -6,13 +6,14 @@ ModelProto { GraphProto { name: "torch-jit-export" inputs: [{name: "0", type:Tensor dims: 3 4}] - outputs: [{name: "4", type:Tensor dims: 0}] + outputs: [{name: "5", type:Tensor dims: 0}] initializers: [] nodes: [ Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []}, Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]}, - Node {type: "ATen", inputs: [0,3], outputs: [4], attributes: [{ name: 'operator', type: string, value: 'index'}]} + Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]}, + Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]} ] } opset_import: [OperatorSetIdProto { domain: }], diff --git a/test/expect/TestScript.test_onnx_export_script_module_if.expect b/test/expect/TestScript.test_onnx_export_script_module_if.expect index 264b863..d436925 100644 --- a/test/expect/TestScript.test_onnx_export_script_module_if.expect +++ b/test/expect/TestScript.test_onnx_export_script_module_if.expect @@ -6,20 +6,21 @@ ModelProto { GraphProto { name: "torch-jit-export" inputs: [{name: "x.1", type:Tensor dims: 1 2 3}] - outputs: [{name: "4", type:Tensor dims: 1 2 3}] + outputs: [{name: "5", type:Tensor dims: 1 2 3}] initializers: [] nodes: [ Node {type: "ReduceSum", inputs: [x.1], outputs: [1], attributes: [{ name: 'keepdims', type: int, value: 0}]}, Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []}, - Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value: + Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]}, + Node {type: "If", inputs: [4], outputs: [5], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { name: "torch-jit-export1" inputs: [] - outputs: [{name: "5", type:Tensor dims: }] + outputs: [{name: "6", type:Tensor dims: }] initializers: [] nodes: [ - Node {type: "Neg", inputs: [x.1], outputs: [5], attributes: []} + Node {type: "Neg", inputs: [x.1], outputs: [6], attributes: []} ] } diff --git a/test/expect/TestScript.test_onnx_export_speculate-f1.expect b/test/expect/TestScript.test_onnx_export_speculate-f1.expect index 4e8e515..dd094b6 100644 --- a/test/expect/TestScript.test_onnx_export_speculate-f1.expect +++ b/test/expect/TestScript.test_onnx_export_speculate-f1.expect @@ -6,28 +6,29 @@ ModelProto { GraphProto { name: "torch-jit-export" inputs: [{name: "x.1", type:Tensor dims: 1 10}] - outputs: [{name: "8", type:Tensor dims: 10 1}] + outputs: [{name: "9", type:Tensor dims: 10 1}] initializers: [] nodes: [ Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []}, Node {type: "ReduceSum", inputs: [1], outputs: [2], attributes: [{ name: 'keepdims', type: int, value: 0}]}, Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Greater", inputs: [2,3], outputs: [4], attributes: []}, - Node {type: "Transpose", inputs: [1], outputs: [5], attributes: [{ name: 'perm', type: ints, values: [1 0]}]}, + Node {type: "Cast", inputs: [4], outputs: [5], attributes: [{ name: 'to', type: int, value: 2}]}, Node {type: "Transpose", inputs: [1], outputs: [6], attributes: [{ name: 'perm', type: ints, values: [1 0]}]}, Node {type: "Transpose", inputs: [1], outputs: [7], attributes: [{ name: 'perm', type: ints, values: [1 0]}]}, - Node {type: "If", inputs: [4], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value: + Node {type: "Transpose", inputs: [1], outputs: [8], attributes: [{ name: 'perm', type: ints, values: [1 0]}]}, + Node {type: "If", inputs: [5], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { name: "torch-jit-export1" inputs: [] - outputs: [{name: "9", type:Tensor dims: }] + outputs: [{name: "10", type:Tensor dims: }] initializers: [] nodes: [ - Node {type: "If", inputs: [4], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value: + Node {type: "If", inputs: [5], outputs: [10], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { name: "torch-jit-export2" inputs: [] - outputs: [{name: "5", type:Tensor dims: }] + outputs: [{name: "6", type:Tensor dims: }] initializers: [] nodes: [ @@ -38,7 +39,7 @@ ModelProto { GraphProto { name: "torch-jit-export3" inputs: [] - outputs: [{name: "6", type:Tensor dims: }] + outputs: [{name: "7", type:Tensor dims: }] initializers: [] nodes: [ @@ -53,7 +54,7 @@ ModelProto { GraphProto { name: "torch-jit-export4" inputs: [] - outputs: [{name: "7", type:Tensor dims: }] + outputs: [{name: "8", type:Tensor dims: }] initializers: [] nodes: [ diff --git a/test/expect/TestScript.test_onnx_export_speculate-f2.expect b/test/expect/TestScript.test_onnx_export_speculate-f2.expect index 2820ce5..3126f1d 100644 --- a/test/expect/TestScript.test_onnx_export_speculate-f2.expect +++ b/test/expect/TestScript.test_onnx_export_speculate-f2.expect @@ -6,28 +6,29 @@ ModelProto { GraphProto { name: "torch-jit-export" inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}] - outputs: [{name: "7", type:Tensor dims: 1 20}] + outputs: [{name: "8", type:Tensor dims: 1 20}] initializers: [TensorProto shape: [20 10],TensorProto shape: [20]] nodes: [ Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []}, Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]}, Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]}, Node {type: "Greater", inputs: [4,5], outputs: [6], attributes: []}, - Node {type: "If", inputs: [6], outputs: [7], attributes: [{ name: 'then_branch', type: graph, value: + Node {type: "Cast", inputs: [6], outputs: [7], attributes: [{ name: 'to', type: int, value: 2}]}, + Node {type: "If", inputs: [7], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { name: "torch-jit-export1" inputs: [] - outputs: [{name: "8", type:Tensor dims: 1 20}] + outputs: [{name: "9", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "If", inputs: [6], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value: + Node {type: "If", inputs: [7], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value: GraphProto { name: "torch-jit-export2" inputs: [] - outputs: [{name: "9", type:Tensor dims: 1 20}] + outputs: [{name: "10", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,1,2], outputs: [9], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } @@ -35,10 +36,10 @@ ModelProto { GraphProto { name: "torch-jit-export3" inputs: [] - outputs: [{name: "10", type:Tensor dims: 1 20}] + outputs: [{name: "11", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } @@ -50,10 +51,10 @@ ModelProto { GraphProto { name: "torch-jit-export4" inputs: [] - outputs: [{name: "11", type:Tensor dims: 1 20}] + outputs: [{name: "12", type:Tensor dims: 1 20}] initializers: [] nodes: [ - Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} + Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]} ] } diff --git a/test/onnx/expect/TestOperators.test_equal.expect b/test/onnx/expect/TestOperators.test_equal.expect index c511737..53d4f84 100644 --- a/test/onnx/expect/TestOperators.test_equal.expect +++ b/test/onnx/expect/TestOperators.test_equal.expect @@ -8,6 +8,16 @@ graph { output: "2" op_type: "Equal" } + node { + input: "2" + output: "3" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -48,7 +58,7 @@ graph { } } output { - name: "2" + name: "3" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_ge.expect b/test/onnx/expect/TestOperators.test_ge.expect index a66c5f4..8a4b5d5 100644 --- a/test/onnx/expect/TestOperators.test_ge.expect +++ b/test/onnx/expect/TestOperators.test_ge.expect @@ -13,6 +13,16 @@ graph { output: "3" op_type: "Not" } + node { + input: "3" + output: "4" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -47,7 +57,7 @@ graph { } } output { - name: "3" + name: "4" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_gt.expect b/test/onnx/expect/TestOperators.test_gt.expect index 7680e0a..542bf81 100644 --- a/test/onnx/expect/TestOperators.test_gt.expect +++ b/test/onnx/expect/TestOperators.test_gt.expect @@ -8,6 +8,16 @@ graph { output: "2" op_type: "Greater" } + node { + input: "2" + output: "3" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -48,7 +58,7 @@ graph { } } output { - name: "2" + name: "3" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_le.expect b/test/onnx/expect/TestOperators.test_le.expect index 0b17740..923489e 100644 --- a/test/onnx/expect/TestOperators.test_le.expect +++ b/test/onnx/expect/TestOperators.test_le.expect @@ -13,6 +13,16 @@ graph { output: "3" op_type: "Not" } + node { + input: "3" + output: "4" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -47,7 +57,7 @@ graph { } } output { - name: "3" + name: "4" type { tensor_type { elem_type: 2 diff --git a/test/onnx/expect/TestOperators.test_lt.expect b/test/onnx/expect/TestOperators.test_lt.expect index a5c6740..2befe25 100644 --- a/test/onnx/expect/TestOperators.test_lt.expect +++ b/test/onnx/expect/TestOperators.test_lt.expect @@ -8,6 +8,16 @@ graph { output: "2" op_type: "Less" } + node { + input: "2" + output: "3" + op_type: "Cast" + attribute { + name: "to" + i: 2 + type: INT + } + } name: "torch-jit-export" input { name: "0" @@ -48,7 +58,7 @@ graph { } } output { - name: "2" + name: "3" type { tensor_type { elem_type: 2 diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 0a259be..41338d9 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -664,24 +664,50 @@ def upsample_bilinear2d(g, input, output_size, align_corners): mode_s="linear") +def wrap_logical_op_with_cast_to_uint8(func): + def wrap_with_cast(g, input, other): + return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte']) + return wrap_with_cast + + +def wrap_logical_op_with_negation(func): + def wrap_with_not(g, input, other): + return g.op("Not", func(g, input, other)) + return wrap_with_not + + +@wrap_logical_op_with_cast_to_uint8 def gt(g, input, other): + return gt_impl(g, input, other) + + +def gt_impl(g, input, other): other = _maybe_get_scalar(other) return g.op("Greater", input, _if_scalar_type_as(g, other, input)) +@wrap_logical_op_with_cast_to_uint8 def lt(g, input, other): + return lt_impl(g, input, other) + + +def lt_impl(g, input, other): other = _maybe_get_scalar(other) return g.op("Less", input, _if_scalar_type_as(g, other, input)) +@wrap_logical_op_with_cast_to_uint8 +@wrap_logical_op_with_negation def ge(g, input, other): other = _maybe_get_scalar(other) - return g.op("Not", lt(g, input, _if_scalar_type_as(g, other, input))) + return lt_impl(g, input, _if_scalar_type_as(g, other, input)) +@wrap_logical_op_with_cast_to_uint8 +@wrap_logical_op_with_negation def le(g, input, other): other = _maybe_get_scalar(other) - return g.op("Not", gt(g, input, _if_scalar_type_as(g, other, input))) + return gt_impl(g, input, _if_scalar_type_as(g, other, input)) def where(g, condition, self, other): @@ -915,6 +941,7 @@ def min(g, self, dim_or_y=None, keepdim=None): outputs=2) +@wrap_logical_op_with_cast_to_uint8 def eq(g, self, other): return g.op("Equal", self, other) -- 2.7.4