From bf2a30cb22afe88b696b41edad6c2c147f9425f2 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Mon, 25 Mar 2019 20:36:44 -0700 Subject: [PATCH] Support dim=None for argmax and argmin (#18264) Summary: Fixes: https://github.com/pytorch/pytorch/issues/18263 cc: houseroad Pull Request resolved: https://github.com/pytorch/pytorch/pull/18264 Reviewed By: ezyang Differential Revision: D14559234 Pulled By: houseroad fbshipit-source-id: c5b8623752d6c6af41c6d715fd9585a65294868d --- test/onnx/test_pytorch_onnx_caffe2.py | 16 ++++++++++++++++ torch/onnx/symbolic.py | 18 ++++++++++++++---- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py index 9c26392..ecc3e4f 100644 --- a/test/onnx/test_pytorch_onnx_caffe2.py +++ b/test/onnx/test_pytorch_onnx_caffe2.py @@ -1121,6 +1121,14 @@ class TestCaffe2Backend(unittest.TestCase): x = torch.randn(4, 4, requires_grad=True) self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE) + def test_argmax_none_dim(self): + class ArgmaxModel(torch.nn.Module): + def forward(self, input): + return torch.argmax(input) + + x = torch.randn(4, 4, requires_grad=True) + self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE) + def test_argmin(self): class ArgminModel(torch.nn.Module): def forward(self, input): @@ -1129,6 +1137,14 @@ class TestCaffe2Backend(unittest.TestCase): x = torch.randn(4, 4, requires_grad=True) self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE) + def test_argmin_none_dim(self): + class ArgminModel(torch.nn.Module): + def forward(self, input): + return torch.argmin(input) + + x = torch.randn(4, 4, requires_grad=True) + self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE) + def test_reshape(self): class ReshapeModel(torch.nn.Module): def forward(self, input): diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py index 2821b41..878f8a3 100644 --- a/torch/onnx/symbolic.py +++ b/torch/onnx/symbolic.py @@ -1690,11 +1690,21 @@ def narrow(g, input, dim, start, length): return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length]) -@parse_args('v', 'i', 'i') def argmax(g, input, dim, keepdim): - return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim) + if dim.node().mustBeNone(): + flattened = reshape(g, input, (-1,)) + return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False) + else: + dim = _parse_arg(dim, 'i') + keepdim = _parse_arg(keepdim, 'i') + return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim) -@parse_args('v', 'i', 'i') def argmin(g, input, dim, keepdim): - return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim) + if dim.node().mustBeNone(): + flattened = reshape(g, input, (-1,)) + return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False) + else: + dim = _parse_arg(dim, 'i') + keepdim = _parse_arg(keepdim, 'i') + return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim) -- 2.7.4