Fixing ONNX export of logical ops to have correct output datatype (#15185)
authorSpandan Tiwari <sptiwari@microsoft.com>
Thu, 20 Dec 2018 20:24:42 +0000 (12:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 20:37:27 +0000 (12:37 -0800)
Summary:
Currently PyTorch ONNX exporter exports the logical ops (`lt`, `gt`, `le`, `ge`, `eq`) with output type in corresponding ONNX ops as type `tensor(uint8)`. But ONNX spec allows for only `tensor(bool)`, which is why models that have these ops fail to load properly.

This issue is captured in https://github.com/pytorch/pytorch/issues/11339. Part of this issue, relating to the allowed input types, has been fixed in ONNX spec by houseroad. This PR fixes the other part pertaining to output type.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15185

Differential Revision: D13494873

Pulled By: houseroad

fbshipit-source-id: 069d2f956a5ae9bf0ac2540a32594a31b01adef8

test/expect/TestScript.test_listconstruct_erasure.expect
test/expect/TestScript.test_onnx_export_script_module_if.expect
test/expect/TestScript.test_onnx_export_speculate-f1.expect
test/expect/TestScript.test_onnx_export_speculate-f2.expect
test/onnx/expect/TestOperators.test_equal.expect
test/onnx/expect/TestOperators.test_ge.expect
test/onnx/expect/TestOperators.test_gt.expect
test/onnx/expect/TestOperators.test_le.expect
test/onnx/expect/TestOperators.test_lt.expect
torch/onnx/symbolic.py

index f6b5b1a..0944262 100644 (file)
@@ -6,13 +6,14 @@ ModelProto {
     GraphProto {
       name: "torch-jit-export"
       inputs: [{name: "0", type:Tensor dims: 3 4}]
-      outputs: [{name: "4", type:Tensor dims: 0}]
+      outputs: [{name: "5", type:Tensor dims: 0}]
       initializers: []
       nodes: [
         Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
         Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
         Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
-        Node {type: "ATen", inputs: [0,3], outputs: [4], attributes: [{ name: 'operator', type: string, value: 'index'}]}
+        Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
+        Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]}
       ]
     }
   opset_import: [OperatorSetIdProto { domain: }],
index 264b863..d436925 100644 (file)
@@ -6,20 +6,21 @@ ModelProto {
     GraphProto {
       name: "torch-jit-export"
       inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
-      outputs: [{name: "4", type:Tensor dims: 1 2 3}]
+      outputs: [{name: "5", type:Tensor dims: 1 2 3}]
       initializers: []
       nodes: [
         Node {type: "ReduceSum", inputs: [x.1], outputs: [1], attributes: [{ name: 'keepdims', type: int, value: 0}]},
         Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
         Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []},
-        Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value:
+        Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
+        Node {type: "If", inputs: [4], outputs: [5], attributes: [{ name: 'then_branch', type: graph, value:
             GraphProto {
               name: "torch-jit-export1"
               inputs: []
-              outputs: [{name: "5", type:Tensor dims: }]
+              outputs: [{name: "6", type:Tensor dims: }]
               initializers: []
               nodes: [
-                Node {type: "Neg", inputs: [x.1], outputs: [5], attributes: []}
+                Node {type: "Neg", inputs: [x.1], outputs: [6], attributes: []}
               ]
             }
 
index 4e8e515..dd094b6 100644 (file)
@@ -6,28 +6,29 @@ ModelProto {
     GraphProto {
       name: "torch-jit-export"
       inputs: [{name: "x.1", type:Tensor dims: 1 10}]
-      outputs: [{name: "8", type:Tensor dims: 10 1}]
+      outputs: [{name: "9", type:Tensor dims: 10 1}]
       initializers: []
       nodes: [
         Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []},
         Node {type: "ReduceSum", inputs: [1], outputs: [2], attributes: [{ name: 'keepdims', type: int, value: 0}]},
         Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
         Node {type: "Greater", inputs: [2,3], outputs: [4], attributes: []},
-        Node {type: "Transpose", inputs: [1], outputs: [5], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
+        Node {type: "Cast", inputs: [4], outputs: [5], attributes: [{ name: 'to', type: int, value: 2}]},
         Node {type: "Transpose", inputs: [1], outputs: [6], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
         Node {type: "Transpose", inputs: [1], outputs: [7], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
-        Node {type: "If", inputs: [4], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
+        Node {type: "Transpose", inputs: [1], outputs: [8], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
+        Node {type: "If", inputs: [5], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
             GraphProto {
               name: "torch-jit-export1"
               inputs: []
-              outputs: [{name: "9", type:Tensor dims: }]
+              outputs: [{name: "10", type:Tensor dims: }]
               initializers: []
               nodes: [
-                Node {type: "If", inputs: [4], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
+                Node {type: "If", inputs: [5], outputs: [10], attributes: [{ name: 'then_branch', type: graph, value:
                     GraphProto {
                       name: "torch-jit-export2"
                       inputs: []
-                      outputs: [{name: "5", type:Tensor dims: }]
+                      outputs: [{name: "6", type:Tensor dims: }]
                       initializers: []
                       nodes: [
                         
@@ -38,7 +39,7 @@ ModelProto {
                     GraphProto {
                       name: "torch-jit-export3"
                       inputs: []
-                      outputs: [{name: "6", type:Tensor dims: }]
+                      outputs: [{name: "7", type:Tensor dims: }]
                       initializers: []
                       nodes: [
                         
@@ -53,7 +54,7 @@ ModelProto {
             GraphProto {
               name: "torch-jit-export4"
               inputs: []
-              outputs: [{name: "7", type:Tensor dims: }]
+              outputs: [{name: "8", type:Tensor dims: }]
               initializers: []
               nodes: [
                 
index 2820ce5..3126f1d 100644 (file)
@@ -6,28 +6,29 @@ ModelProto {
     GraphProto {
       name: "torch-jit-export"
       inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
-      outputs: [{name: "7", type:Tensor dims: 1 20}]
+      outputs: [{name: "8", type:Tensor dims: 1 20}]
       initializers: [TensorProto shape: [20 10],TensorProto shape: [20]]
       nodes: [
         Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []},
         Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]},
         Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
         Node {type: "Greater", inputs: [4,5], outputs: [6], attributes: []},
-        Node {type: "If", inputs: [6], outputs: [7], attributes: [{ name: 'then_branch', type: graph, value:
+        Node {type: "Cast", inputs: [6], outputs: [7], attributes: [{ name: 'to', type: int, value: 2}]},
+        Node {type: "If", inputs: [7], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
             GraphProto {
               name: "torch-jit-export1"
               inputs: []
-              outputs: [{name: "8", type:Tensor dims: 1 20}]
+              outputs: [{name: "9", type:Tensor dims: 1 20}]
               initializers: []
               nodes: [
-                Node {type: "If", inputs: [6], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
+                Node {type: "If", inputs: [7], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
                     GraphProto {
                       name: "torch-jit-export2"
                       inputs: []
-                      outputs: [{name: "9", type:Tensor dims: 1 20}]
+                      outputs: [{name: "10", type:Tensor dims: 1 20}]
                       initializers: []
                       nodes: [
-                        Node {type: "Gemm", inputs: [3,1,2], outputs: [9], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+                        Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
                       ]
                     }
 
@@ -35,10 +36,10 @@ ModelProto {
                     GraphProto {
                       name: "torch-jit-export3"
                       inputs: []
-                      outputs: [{name: "10", type:Tensor dims: 1 20}]
+                      outputs: [{name: "11", type:Tensor dims: 1 20}]
                       initializers: []
                       nodes: [
-                        Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+                        Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
                       ]
                     }
 
@@ -50,10 +51,10 @@ ModelProto {
             GraphProto {
               name: "torch-jit-export4"
               inputs: []
-              outputs: [{name: "11", type:Tensor dims: 1 20}]
+              outputs: [{name: "12", type:Tensor dims: 1 20}]
               initializers: []
               nodes: [
-                Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+                Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
               ]
             }
 
index c511737..53d4f84 100644 (file)
@@ -8,6 +8,16 @@ graph {
     output: "2"
     op_type: "Equal"
   }
+  node {
+    input: "2"
+    output: "3"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 2
+      type: INT
+    }
+  }
   name: "torch-jit-export"
   input {
     name: "0"
@@ -48,7 +58,7 @@ graph {
     }
   }
   output {
-    name: "2"
+    name: "3"
     type {
       tensor_type {
         elem_type: 2
index a66c5f4..8a4b5d5 100644 (file)
@@ -13,6 +13,16 @@ graph {
     output: "3"
     op_type: "Not"
   }
+  node {
+    input: "3"
+    output: "4"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 2
+      type: INT
+    }
+  }
   name: "torch-jit-export"
   input {
     name: "0"
@@ -47,7 +57,7 @@ graph {
     }
   }
   output {
-    name: "3"
+    name: "4"
     type {
       tensor_type {
         elem_type: 2
index 7680e0a..542bf81 100644 (file)
@@ -8,6 +8,16 @@ graph {
     output: "2"
     op_type: "Greater"
   }
+  node {
+    input: "2"
+    output: "3"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 2
+      type: INT
+    }
+  }
   name: "torch-jit-export"
   input {
     name: "0"
@@ -48,7 +58,7 @@ graph {
     }
   }
   output {
-    name: "2"
+    name: "3"
     type {
       tensor_type {
         elem_type: 2
index 0b17740..923489e 100644 (file)
@@ -13,6 +13,16 @@ graph {
     output: "3"
     op_type: "Not"
   }
+  node {
+    input: "3"
+    output: "4"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 2
+      type: INT
+    }
+  }
   name: "torch-jit-export"
   input {
     name: "0"
@@ -47,7 +57,7 @@ graph {
     }
   }
   output {
-    name: "3"
+    name: "4"
     type {
       tensor_type {
         elem_type: 2
index a5c6740..2befe25 100644 (file)
@@ -8,6 +8,16 @@ graph {
     output: "2"
     op_type: "Less"
   }
+  node {
+    input: "2"
+    output: "3"
+    op_type: "Cast"
+    attribute {
+      name: "to"
+      i: 2
+      type: INT
+    }
+  }
   name: "torch-jit-export"
   input {
     name: "0"
@@ -48,7 +58,7 @@ graph {
     }
   }
   output {
-    name: "2"
+    name: "3"
     type {
       tensor_type {
         elem_type: 2
index 0a259be..41338d9 100644 (file)
@@ -664,24 +664,50 @@ def upsample_bilinear2d(g, input, output_size, align_corners):
                 mode_s="linear")
 
 
+def wrap_logical_op_with_cast_to_uint8(func):
+    def wrap_with_cast(g, input, other):
+        return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
+    return wrap_with_cast
+
+
+def wrap_logical_op_with_negation(func):
+    def wrap_with_not(g, input, other):
+        return g.op("Not", func(g, input, other))
+    return wrap_with_not
+
+
+@wrap_logical_op_with_cast_to_uint8
 def gt(g, input, other):
+    return gt_impl(g, input, other)
+
+
+def gt_impl(g, input, other):
     other = _maybe_get_scalar(other)
     return g.op("Greater", input, _if_scalar_type_as(g, other, input))
 
 
+@wrap_logical_op_with_cast_to_uint8
 def lt(g, input, other):
+    return lt_impl(g, input, other)
+
+
+def lt_impl(g, input, other):
     other = _maybe_get_scalar(other)
     return g.op("Less", input, _if_scalar_type_as(g, other, input))
 
 
+@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_negation
 def ge(g, input, other):
     other = _maybe_get_scalar(other)
-    return g.op("Not", lt(g, input, _if_scalar_type_as(g, other, input)))
+    return lt_impl(g, input, _if_scalar_type_as(g, other, input))
 
 
+@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_negation
 def le(g, input, other):
     other = _maybe_get_scalar(other)
-    return g.op("Not", gt(g, input, _if_scalar_type_as(g, other, input)))
+    return gt_impl(g, input, _if_scalar_type_as(g, other, input))
 
 
 def where(g, condition, self, other):
@@ -915,6 +941,7 @@ def min(g, self, dim_or_y=None, keepdim=None):
                     outputs=2)
 
 
+@wrap_logical_op_with_cast_to_uint8
 def eq(g, self, other):
     return g.op("Equal", self, other)