x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_model_test(FlattenModel(), train=False, input=x, batch_size=BATCH_SIZE)
def test_reshape(self):
    """End-to-end export check for Tensor.reshape with a static target shape."""
    class ReshapeModel(torch.nn.Module):
        def forward(self, x):
            # Single-element input becomes a 2-D (1, 1) tensor.
            return x.reshape(1, 1)

    data = torch.randn(1, requires_grad=True)
    self.run_model_test(ReshapeModel(), train=False, input=data, batch_size=BATCH_SIZE)
+
def test_reshape_as(self):
    """End-to-end export check for Tensor.reshape_as against a fixed-shape target."""
    class ReshapeAsModel(torch.nn.Module):
        def forward(self, x):
            # Target tensor only supplies the shape (3, 1, 2, 1); its values are unused.
            target = torch.randn(3, 1, 2, 1, requires_grad=False)
            return x.reshape_as(target)

    data = torch.randn(2, 3, requires_grad=True)
    self.run_model_test(ReshapeAsModel(), train=False, input=data, batch_size=BATCH_SIZE)
+
# a bit of metaprogramming to set up all the rnn tests
return g.op('Reshape', input, shape)
def reshape(g, self, shape):
    """Symbolic for aten::reshape.

    reshape and view lower to the same ONNX ``Reshape`` node, so this
    simply delegates to the ``view`` symbolic.
    """
    return view(g, self, shape)
+
+
def reshape_as(g, self, other):
    """Symbolic for aten::reshape_as.

    Emits an ONNX ``Shape`` node to capture ``other``'s runtime shape,
    then reshapes ``self`` to it via the ``reshape`` symbolic.
    """
    target_shape = g.op('Shape', other)
    return reshape(g, self, target_shape)
+
+
def add(g, self, other, alpha=None):
# default alpha arg is to allow no-alpha add (aten add st overload no alpha)
if alpha and _scalar(_maybe_get_scalar(alpha)) != 1: