def _reshape():
def _impl(inputs, input_types):
data = inputs[0]
- if isinstance(inputs[1], list):
+ if _is_int_seq(inputs[1]):
new_shape = inputs[1]
else:
- new_shape = _infer_shape(inputs[1])
+ assert isinstance(inputs[1], list)
+ infer_res = [_infer_value(_wrap_const(size), {}) for size in inputs[1]]
+ new_shape = [np.asscalar(res.asnumpy().astype(np.int))
+ for res in infer_res]
return _op.transform.reshape(data, new_shape)
return _impl
torch.set_grad_enabled(False)
input_shape = [2, 1, 10, 1, 10]
new_shape = [2, 1, 10, 10]
+
class Reshape1(Module):
def forward(self, *args):
return args[0].reshape(new_shape)
def forward(self, *args):
return args[0].reshape([-1])
+ class Reshape3(torch.nn.Module):
+ def forward(self, x):
+ x_shape = x.shape
+ return x.reshape((x_shape[0] * x_shape[1], x_shape[2]))
+
input_data = torch.rand(input_shape).float()
- verify_model(Reshape1().float().eval(), input_data=input_data)
- verify_model(Reshape2().float().eval(), input_data=input_data)
+ verify_model(Reshape1(), input_data=input_data)
+ verify_model(Reshape2(), input_data=input_data)
+ verify_model(Reshape3(), input_data=torch.randn(2, 3, 4))
@tvm.testing.uses_gpu