Support dim=None for argmax and argmin (#18264)
author Xiang Gao <qasdfgtyuiop@gmail.com>
Tue, 26 Mar 2019 03:36:44 +0000 (20:36 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 03:43:34 +0000 (20:43 -0700)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/18263
cc: houseroad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18264

Reviewed By: ezyang

Differential Revision: D14559234

Pulled By: houseroad

fbshipit-source-id: c5b8623752d6c6af41c6d715fd9585a65294868d

test/onnx/test_pytorch_onnx_caffe2.py
torch/onnx/symbolic.py

index 9c26392..ecc3e4f 100644 (file)
@@ -1121,6 +1121,14 @@ class TestCaffe2Backend(unittest.TestCase):
         x = torch.randn(4, 4, requires_grad=True)
         self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    def test_argmax_none_dim(self):
+        class ArgmaxModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.argmax(input)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
     def test_argmin(self):
         class ArgminModel(torch.nn.Module):
             def forward(self, input):
@@ -1129,6 +1137,14 @@ class TestCaffe2Backend(unittest.TestCase):
         x = torch.randn(4, 4, requires_grad=True)
         self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
+    def test_argmin_none_dim(self):
+        class ArgminModel(torch.nn.Module):
+            def forward(self, input):
+                return torch.argmin(input)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE)
+
     def test_reshape(self):
         class ReshapeModel(torch.nn.Module):
             def forward(self, input):
index 2821b41..878f8a3 100644 (file)
@@ -1690,11 +1690,21 @@ def narrow(g, input, dim, start, length):
     return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
 
 
-@parse_args('v', 'i', 'i')
 def argmax(g, input, dim, keepdim):
-    return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)
+    if dim.node().mustBeNone():
+        flattened = reshape(g, input, (-1,))
+        return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False)
+    else:
+        dim = _parse_arg(dim, 'i')
+        keepdim = _parse_arg(keepdim, 'i')
+        return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim)
 
 
-@parse_args('v', 'i', 'i')
 def argmin(g, input, dim, keepdim):
-    return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)
+    if dim.node().mustBeNone():
+        flattened = reshape(g, input, (-1,))
+        return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False)
+    else:
+        dim = _parse_arg(dim, 'i')
+        keepdim = _parse_arg(keepdim, 'i')
+        return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim)