From 2d752d25bccd58f1c48b795fd275fde3a10c015d Mon Sep 17 00:00:00 2001
From: masahi
Date: Sat, 29 Aug 2020 15:14:02 +0900
Subject: [PATCH] [Torch] Add cast to double, fix flatten conversion (#6357)

* support cast to double and fix flatten conversion

* also support batch flatten, add test

* add flatten test

* clean up
---
 python/tvm/relay/frontend/pytorch.py          | 15 ++++++++++++++-
 tests/python/frontend/pytorch/test_forward.py | 21 +++++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 21cf9c3..108d1d8 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -996,12 +996,23 @@ def _transpose(prelude):
         return _op.transform.transpose(data, axes)
     return _impl
 
+
 def _flatten():
     def _impl(inputs, input_types):
         data = inputs[0]
-        return _op.nn.batch_flatten(data)
+        start_dim = inputs[1] if len(inputs) > 0 else 0
+        end_dim = inputs[2] if len(inputs) > 1 else -1
+
+        if start_dim == 0 and end_dim == -1:
+            return _op.transform.reshape(data, (-1,))
+        if start_dim == 1 and end_dim == -1:
+            return _op.nn.batch_flatten(data)
+
+        raise NotImplementedError("Only support 1d flatten or batch flatten")
+
     return _impl
 
+
 def _dense():
     def _impl(inputs, input_types):
         use_bias = isinstance(inputs[0], _expr.Expr)
@@ -1509,11 +1520,13 @@ def _to():
         # this happens when converting upsampling with scale factor
         cast_func = {
             6: float,
+            7: float,
             3: int,
             4: int
         }
         cast_func_expr = {
             6: lambda x: _op.cast(x, "float32"),
+            7: lambda x: _op.cast(x, "float64"),
             3: lambda x: _op.cast(x, "int32"),
             4: lambda x: _op.cast(x, "int64"),
         }
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 946712d..2e54ac4 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -881,6 +881,21 @@ def test_forward_reshape():
     verify_model(Reshape1().float().eval(), input_data=input_data)
     verify_model(Reshape2().float().eval(), input_data=input_data)
 
+
+def test_flatten():
+    class Flatten(Module):
+        def forward(self, x):
+            return torch.flatten(x)
+
+    class BatchFlatten(Module):
+        def forward(self, x):
+            return torch.flatten(x, start_dim=1)
+
+    inp = torch.rand((5, 2, 2))
+    verify_model(Flatten(), input_data=inp)
+    verify_model(BatchFlatten(), input_data=inp)
+
+
 def test_forward_transpose():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -1311,12 +1326,17 @@ def test_to():
         def forward(self, x):
             return x.long()
 
+    class ToDouble(Module):
+        def forward(self, x):
+            return x.double()
+
     verify_model(ToCPU().eval(), torch.rand((1, 3, 32, 32)))
     verify_model(ToFloat().eval(), torch.zeros((1, 3, 32, 32), dtype=torch.int))
     verify_model(ToFloat().eval(), torch.tensor(2, dtype=torch.int))
     verify_model(ToInt().eval(), torch.zeros((1, 3, 32, 32)))
     verify_model(ToInt().eval(), torch.tensor(0.8))
     verify_model(ToLong().eval(), torch.tensor(0.8))
+    verify_model(ToDouble().eval(), torch.tensor(0.8))
 
 
 def test_adaptive_pool3d():
@@ -2901,6 +2921,7 @@ if __name__ == "__main__":
     test_forward_upsample3d()
     test_forward_nms()
     test_to()
+    test_flatten()
     test_type_as()
     test_forward_functional_pad()
     test_forward_zero_pad2d()
--
2.7.4
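
Not part of the patch: a minimal, hypothetical sketch of how the two changes above could be exercised end to end outside of the verify_model test harness, assuming the usual torch.jit.trace plus tvm.relay.frontend.from_pytorch flow. The module name CastAndFlatten, the input name "input0", and the input shape are illustrative only.

# Hypothetical usage sketch; names and shapes are illustrative, not from the patch.
import torch
from tvm import relay


class CastAndFlatten(torch.nn.Module):
    def forward(self, x):
        # x.double() exercises the new float64 cast (PyTorch dtype code 7);
        # torch.flatten(..., start_dim=1) exercises the nn.batch_flatten path.
        return torch.flatten(x.double(), start_dim=1)


inp = torch.rand((5, 2, 2))
traced = torch.jit.trace(CastAndFlatten().eval(), inp)
# The shape list pairs each graph input name with its shape.
mod, params = relay.frontend.from_pytorch(traced, [("input0", (5, 2, 2))])
print(mod["main"])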