From: masahi Date: Mon, 7 Sep 2020 15:02:30 +0000 (+0900) Subject: reshape with non constant shape argument (#6411) X-Git-Tag: upstream/0.7.0~153 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=656b4c5b80d52e3dd68e61dcd2cc4b5c384f4143;p=platform%2Fupstream%2Ftvm.git reshape with non constant shape argument (#6411) --- diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index fe5a45c..99d2dae 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1131,10 +1131,13 @@ def _view(): def _reshape(): def _impl(inputs, input_types): data = inputs[0] - if isinstance(inputs[1], list): + if _is_int_seq(inputs[1]): new_shape = inputs[1] else: - new_shape = _infer_shape(inputs[1]) + assert isinstance(inputs[1], list) + infer_res = [_infer_value(_wrap_const(size), {}) for size in inputs[1]] + new_shape = [np.asscalar(res.asnumpy().astype(np.int)) + for res in infer_res] return _op.transform.reshape(data, new_shape) return _impl diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index cfe9507..b35c7d6 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -946,6 +946,7 @@ def test_forward_reshape(): torch.set_grad_enabled(False) input_shape = [2, 1, 10, 1, 10] new_shape = [2, 1, 10, 10] + class Reshape1(Module): def forward(self, *args): return args[0].reshape(new_shape) @@ -954,9 +955,15 @@ def test_forward_reshape(): def forward(self, *args): return args[0].reshape([-1]) + class Reshape3(torch.nn.Module): + def forward(self, x): + x_shape = x.shape + return x.reshape((x_shape[0] * x_shape[1], x_shape[2])) + input_data = torch.rand(input_shape).float() - verify_model(Reshape1().float().eval(), input_data=input_data) - verify_model(Reshape2().float().eval(), input_data=input_data) + verify_model(Reshape1(), input_data=input_data) + verify_model(Reshape2(), input_data=input_data) + verify_model(Reshape3(), input_data=torch.randn(2, 3, 4)) @tvm.testing.uses_gpu