GraphProto {
name: "torch-jit-export"
inputs: [{name: "0", type:Tensor dims: 3 4}]
- outputs: [{name: "4", type:Tensor dims: 0}]
+ outputs: [{name: "5", type:Tensor dims: 0}]
initializers: []
nodes: [
Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Less", inputs: [0,1], outputs: [2], attributes: []},
Node {type: "Cast", inputs: [2], outputs: [3], attributes: [{ name: 'to', type: int, value: 2}]},
- Node {type: "ATen", inputs: [0,3], outputs: [4], attributes: [{ name: 'operator', type: string, value: 'index'}]}
+ Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
+ Node {type: "ATen", inputs: [0,4], outputs: [5], attributes: [{ name: 'operator', type: string, value: 'index'}]}
]
}
opset_import: [OperatorSetIdProto { domain: }],
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x.1", type:Tensor dims: 1 2 3}]
- outputs: [{name: "4", type:Tensor dims: 1 2 3}]
+ outputs: [{name: "5", type:Tensor dims: 1 2 3}]
initializers: []
nodes: [
Node {type: "ReduceSum", inputs: [x.1], outputs: [1], attributes: [{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [2], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Greater", inputs: [1,2], outputs: [3], attributes: []},
- Node {type: "If", inputs: [3], outputs: [4], attributes: [{ name: 'then_branch', type: graph, value:
+ Node {type: "Cast", inputs: [3], outputs: [4], attributes: [{ name: 'to', type: int, value: 2}]},
+ Node {type: "If", inputs: [4], outputs: [5], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: []
- outputs: [{name: "5", type:Tensor dims: }]
+ outputs: [{name: "6", type:Tensor dims: }]
initializers: []
nodes: [
- Node {type: "Neg", inputs: [x.1], outputs: [5], attributes: []}
+ Node {type: "Neg", inputs: [x.1], outputs: [6], attributes: []}
]
}
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x.1", type:Tensor dims: 1 10}]
- outputs: [{name: "8", type:Tensor dims: 10 1}]
+ outputs: [{name: "9", type:Tensor dims: 10 1}]
initializers: []
nodes: [
Node {type: "Add", inputs: [x.1,x.1], outputs: [1], attributes: []},
Node {type: "ReduceSum", inputs: [1], outputs: [2], attributes: [{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Greater", inputs: [2,3], outputs: [4], attributes: []},
- Node {type: "Transpose", inputs: [1], outputs: [5], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
+ Node {type: "Cast", inputs: [4], outputs: [5], attributes: [{ name: 'to', type: int, value: 2}]},
Node {type: "Transpose", inputs: [1], outputs: [6], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
Node {type: "Transpose", inputs: [1], outputs: [7], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
- Node {type: "If", inputs: [4], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
+ Node {type: "Transpose", inputs: [1], outputs: [8], attributes: [{ name: 'perm', type: ints, values: [1 0]}]},
+ Node {type: "If", inputs: [5], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: []
- outputs: [{name: "9", type:Tensor dims: }]
+ outputs: [{name: "10", type:Tensor dims: }]
initializers: []
nodes: [
- Node {type: "If", inputs: [4], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
+ Node {type: "If", inputs: [5], outputs: [10], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export2"
inputs: []
- outputs: [{name: "5", type:Tensor dims: }]
+ outputs: [{name: "6", type:Tensor dims: }]
initializers: []
nodes: [
GraphProto {
name: "torch-jit-export3"
inputs: []
- outputs: [{name: "6", type:Tensor dims: }]
+ outputs: [{name: "7", type:Tensor dims: }]
initializers: []
nodes: [
GraphProto {
name: "torch-jit-export4"
inputs: []
- outputs: [{name: "7", type:Tensor dims: }]
+ outputs: [{name: "8", type:Tensor dims: }]
initializers: []
nodes: [
GraphProto {
name: "torch-jit-export"
inputs: [{name: "x.1", type:Tensor dims: 1 10},{name: "1", type:Tensor dims: 20 10},{name: "2", type:Tensor dims: 20}]
- outputs: [{name: "7", type:Tensor dims: 1 20}]
+ outputs: [{name: "8", type:Tensor dims: 1 20}]
initializers: [TensorProto shape: [20 10],TensorProto shape: [20]]
nodes: [
Node {type: "Add", inputs: [x.1,x.1], outputs: [3], attributes: []},
Node {type: "ReduceSum", inputs: [3], outputs: [4], attributes: [{ name: 'keepdims', type: int, value: 0}]},
Node {type: "Constant", inputs: [], outputs: [5], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
Node {type: "Greater", inputs: [4,5], outputs: [6], attributes: []},
- Node {type: "If", inputs: [6], outputs: [7], attributes: [{ name: 'then_branch', type: graph, value:
+ Node {type: "Cast", inputs: [6], outputs: [7], attributes: [{ name: 'to', type: int, value: 2}]},
+ Node {type: "If", inputs: [7], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export1"
inputs: []
- outputs: [{name: "8", type:Tensor dims: 1 20}]
+ outputs: [{name: "9", type:Tensor dims: 1 20}]
initializers: []
nodes: [
- Node {type: "If", inputs: [6], outputs: [8], attributes: [{ name: 'then_branch', type: graph, value:
+ Node {type: "If", inputs: [7], outputs: [9], attributes: [{ name: 'then_branch', type: graph, value:
GraphProto {
name: "torch-jit-export2"
inputs: []
- outputs: [{name: "9", type:Tensor dims: 1 20}]
+ outputs: [{name: "10", type:Tensor dims: 1 20}]
initializers: []
nodes: [
- Node {type: "Gemm", inputs: [3,1,2], outputs: [9], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+ Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
GraphProto {
name: "torch-jit-export3"
inputs: []
- outputs: [{name: "10", type:Tensor dims: 1 20}]
+ outputs: [{name: "11", type:Tensor dims: 1 20}]
initializers: []
nodes: [
- Node {type: "Gemm", inputs: [3,1,2], outputs: [10], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+ Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
GraphProto {
name: "torch-jit-export4"
inputs: []
- outputs: [{name: "11", type:Tensor dims: 1 20}]
+ outputs: [{name: "12", type:Tensor dims: 1 20}]
initializers: []
nodes: [
- Node {type: "Gemm", inputs: [3,1,2], outputs: [11], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
+ Node {type: "Gemm", inputs: [3,1,2], outputs: [12], attributes: [{ name: 'alpha', type: float, value: 1},{ name: 'beta', type: float, value: 1},{ name: 'transB', type: int, value: 1}]}
]
}
output: "2"
op_type: "Equal"
}
+ node {
+ input: "2"
+ output: "3"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
}
}
output {
- name: "2"
+ name: "3"
type {
tensor_type {
elem_type: 2
output: "3"
op_type: "Not"
}
+ node {
+ input: "3"
+ output: "4"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
}
}
output {
- name: "3"
+ name: "4"
type {
tensor_type {
elem_type: 2
output: "2"
op_type: "Greater"
}
+ node {
+ input: "2"
+ output: "3"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
}
}
output {
- name: "2"
+ name: "3"
type {
tensor_type {
elem_type: 2
output: "3"
op_type: "Not"
}
+ node {
+ input: "3"
+ output: "4"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
}
}
output {
- name: "3"
+ name: "4"
type {
tensor_type {
elem_type: 2
output: "2"
op_type: "Less"
}
+ node {
+ input: "2"
+ output: "3"
+ op_type: "Cast"
+ attribute {
+ name: "to"
+ i: 2
+ type: INT
+ }
+ }
name: "torch-jit-export"
input {
name: "0"
}
}
output {
- name: "2"
+ name: "3"
type {
tensor_type {
elem_type: 2
mode_s="linear")
+def wrap_logical_op_with_cast_to_uint8(func):
+ def wrap_with_cast(g, input, other):
+ return g.op("Cast", func(g, input, other), to_i=cast_pytorch_to_onnx['Byte'])
+ return wrap_with_cast
+
+
+def wrap_logical_op_with_negation(func):
+ def wrap_with_not(g, input, other):
+ return g.op("Not", func(g, input, other))
+ return wrap_with_not
+
+
+@wrap_logical_op_with_cast_to_uint8
def gt(g, input, other):
+ return gt_impl(g, input, other)
+
+
+def gt_impl(g, input, other):
other = _maybe_get_scalar(other)
return g.op("Greater", input, _if_scalar_type_as(g, other, input))
+@wrap_logical_op_with_cast_to_uint8
def lt(g, input, other):
+ return lt_impl(g, input, other)
+
+
+def lt_impl(g, input, other):
other = _maybe_get_scalar(other)
return g.op("Less", input, _if_scalar_type_as(g, other, input))
+@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_negation
def ge(g, input, other):
other = _maybe_get_scalar(other)
- return g.op("Not", lt(g, input, _if_scalar_type_as(g, other, input)))
+ return lt_impl(g, input, _if_scalar_type_as(g, other, input))
+@wrap_logical_op_with_cast_to_uint8
+@wrap_logical_op_with_negation
def le(g, input, other):
other = _maybe_get_scalar(other)
- return g.op("Not", gt(g, input, _if_scalar_type_as(g, other, input)))
+ return gt_impl(g, input, _if_scalar_type_as(g, other, input))
def where(g, condition, self, other):
outputs=2)
+@wrap_logical_op_with_cast_to_uint8
def eq(g, self, other):
return g.op("Equal", self, other)