From f4c504593cc73c4c3939f8d3e8e012c4d47bfd8d Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Tue, 18 Dec 2018 11:28:04 -0800 Subject: [PATCH] Fix the (reduce)min and (reduce)max ONNX exporting (#15241) Summary: max and reducemax are smashed together, we need to support one input case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15241 Reviewed By: yinghai Differential Revision: D13473312 Pulled By: houseroad fbshipit-source-id: 9b8c847286a2631b006ca900271bc0d26574101a --- .../expect/TestOperators.test_reducemax.expect | 51 ++++++++++++++++++++++ .../expect/TestOperators.test_reducemin.expect | 51 ++++++++++++++++++++++ test/onnx/test_operators.py | 8 ++++ torch/onnx/symbolic.py | 8 +++- 4 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 test/onnx/expect/TestOperators.test_reducemax.expect create mode 100644 test/onnx/expect/TestOperators.test_reducemin.expect diff --git a/test/onnx/expect/TestOperators.test_reducemax.expect b/test/onnx/expect/TestOperators.test_reducemax.expect new file mode 100644 index 0000000..207a230 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_reducemax.expect @@ -0,0 +1,51 @@ +ir_version: 3 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + input: "0" + output: "1" + op_type: "ReduceMax" + attribute { + name: "keepdims" + i: 0 + type: INT + } + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_reducemin.expect b/test/onnx/expect/TestOperators.test_reducemin.expect new file mode 100644 index 0000000..44ec6b5 --- /dev/null +++ b/test/onnx/expect/TestOperators.test_reducemin.expect @@ -0,0 +1,51 @@ +ir_version: 3 +producer_name: "pytorch" +producer_version: "0.4" +graph { + node { + input: "0" + output: "1" + op_type: "ReduceMin" + attribute { + name: "keepdims" + i: 0 + type: INT + } + } + name: "torch-jit-export" + input { + name: "0" + type { + tensor_type { + elem_type: 1 + shape { + dim { + dim_value: 1 + } + dim { + dim_value: 2 + } + dim { + dim_value: 3 + } + dim { + dim_value: 4 + } + } + } + } + } + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 4913585..625cce2 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -496,6 +496,14 @@ class TestOperators(TestCase): y = torch.randn(1, 4, requires_grad=False).int() self.assertONNX(lambda x, y: torch.ne(x, y), (x, y)) + def test_reducemax(self): + x = torch.randn(1, 2, 3, 4) + self.assertONNX(lambda x: torch.max(x), x) + + def test_reducemin(self): + x = torch.randn(1, 2, 3, 4) + self.assertONNX(lambda x: torch.min(x), x) + if __name__ == '__main__': no_onnx_dep_flag = '--no-onnx' _onnx_dep = no_onnx_dep_flag not in common.UNITTEST_ARGS diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index e5e2536..0a259be 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -881,7 +881,9 @@ def clamp_max(g, self, max): # torch.max (same for torch.min) actually has two interfaces smashed together: # torch.max(x, dim, keepdim) and torch.max(x, y) -def max(g, self, dim_or_y, keepdim=None): +def max(g, self, dim_or_y=None, keepdim=None): + if dim_or_y is None and keepdim is None: + return g.op("ReduceMax", self, keepdims_i=0) if keepdim is None: return g.op("Max", self, dim_or_y) else: @@ -896,7 +898,9 @@ def max(g, self, dim_or_y, keepdim=None): outputs=2) -def min(g, self, dim_or_y, keepdim=None): +def min(g, self, dim_or_y=None, keepdim=None): + if dim_or_y is None and keepdim is None: + return g.op("ReduceMin", self, keepdims_i=0) if keepdim is None: return g.op("Min", self, dim_or_y) else: -- 2.7.4