x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(ArgmaxModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_argmax_none_dim(self):
    """Export/run test for torch.argmax with dim omitted (reduce over all elements)."""
    class ArgmaxModel(torch.nn.Module):
        def forward(self, input):
            return torch.argmax(input)

    sample = torch.randn(4, 4, requires_grad=True)
    self.run_model_test(ArgmaxModel(), train=False, input=sample,
                        batch_size=BATCH_SIZE)
def test_argmin(self):
class ArgminModel(torch.nn.Module):
def forward(self, input):
x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(ArgminModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_argmin_none_dim(self):
    """Export/run test for torch.argmin with dim omitted (reduce over all elements)."""
    class ArgminModel(torch.nn.Module):
        def forward(self, input):
            return torch.argmin(input)

    sample = torch.randn(4, 4, requires_grad=True)
    self.run_model_test(ArgminModel(), train=False, input=sample,
                        batch_size=BATCH_SIZE)
def test_reshape(self):
class ReshapeModel(torch.nn.Module):
def forward(self, input):
return g.op("Slice", input, axes_i=[dim], starts_i=[start], ends_i=[start + length])
def argmax(g, input, dim, keepdim):
    """Symbolic for torch.argmax: emits an ONNX ArgMax node.

    If ``dim`` is the None constant, the input is flattened first and
    ArgMax is taken over the single resulting axis with keepdims off
    (``keepdim`` is not consulted in that case). Otherwise ``dim`` and
    ``keepdim`` are parsed as ints and forwarded to ArgMax.
    """
    # Explicit-dim path: parse the scalar arguments and export directly.
    if not dim.node().mustBeNone():
        axis = _parse_arg(dim, 'i')
        keep = _parse_arg(keepdim, 'i')
        return g.op('ArgMax', input, axis_i=axis, keepdims_i=keep)
    # dim is None: reduce across every element via a flattening reshape.
    flat = reshape(g, input, (-1,))
    return g.op('ArgMax', flat, axis_i=0, keepdims_i=False)
def argmin(g, input, dim, keepdim):
    """Symbolic for torch.argmin: emits an ONNX ArgMin node.

    If ``dim`` is the None constant, the input is flattened first and
    ArgMin is taken over the single resulting axis with keepdims off
    (``keepdim`` is not consulted in that case). Otherwise ``dim`` and
    ``keepdim`` are parsed as ints and forwarded to ArgMin.
    """
    # Explicit-dim path: parse the scalar arguments and export directly.
    if not dim.node().mustBeNone():
        axis = _parse_arg(dim, 'i')
        keep = _parse_arg(keepdim, 'i')
        return g.op('ArgMin', input, axis_i=axis, keepdims_i=keep)
    # dim is None: reduce across every element via a flattening reshape.
    flat = reshape(g, input, (-1,))
    return g.op('ArgMin', flat, axis_i=0, keepdims_i=False)