From 59f5cbe921cf329febcd9d6eff2df94d80f1c523 Mon Sep 17 00:00:00 2001 From: Rand Xie Date: Fri, 12 Jun 2020 21:52:45 -0700 Subject: [PATCH] support aten::type_as in the pytorch frontend (#5787) * support aten::type_as in the pytorch frontend * use _convert_data_type to convert torch type to tvm type and add more types in the type_as test --- python/tvm/relay/frontend/pytorch.py | 9 +++++++ tests/python/frontend/pytorch/test_forward.py | 37 +++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a9f4a7b..d2451cd 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1645,6 +1645,14 @@ def _list_len(prelude): return _impl +def _type_as(): + def _impl(inputs, input_types): + assert len(inputs) == 2 + assert len(input_types) == 2 + return _op.cast(inputs[0], _convert_data_type(input_types[1])) + return _impl + + def _add(prelude): # add_ is overloaded for tensor add and list concat def _impl(inputs, input_types): @@ -1953,6 +1961,7 @@ def _get_convert_map(prelude): "aten::stack" : _tensor_array_stack(prelude), "aten::__getitem__" : _list_getitem(prelude), "aten::len" : _list_len(prelude), + "aten::type_as" : _type_as(), } return convert_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 86fb409..f8fb57f 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -27,6 +27,7 @@ import torchvision from tvm import relay from tvm.contrib import graph_runtime +from tvm.contrib.nvcc import have_fp16 from tvm.relay.testing.config import ctx_list @@ -837,6 +838,41 @@ def test_forward_size(): input_data = torch.rand(input_shape).float() verify_model(Size1().float().eval(), input_data=input_data) + +def test_type_as(): + torch.set_grad_enabled(False) + input_shape = [1, 3] + + def _create_module(dtype): + class TypeAs(Module): + def forward(self, *args): + expected_type_tensor = torch.zeros(1, 3, dtype=dtype) + return args[0].type_as(expected_type_tensor) + + return TypeAs() + + input_data = torch.randn(input_shape).float() + verify_model(_create_module(torch.float64), input_data=input_data) + verify_model(_create_module(torch.float32), input_data=input_data) + verify_model(_create_module(torch.int64), input_data=input_data) + verify_model(_create_module(torch.int32), input_data=input_data) + verify_model(_create_module(torch.int16), input_data=input_data) + verify_model(_create_module(torch.int8), input_data=input_data) + + if torch.cuda.is_available(): + check_fp16 = False + try: + # Only check half precision on supported hardwares. + if have_fp16(tvm.gpu(0).compute_version): + check_fp16 = True + except Exception as e: + # If GPU is not enabled in TVM, skip the fp16 test. + pass + + if check_fp16: + verify_model(_create_module(torch.float16), input_data=input_data) + + def test_forward_view(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -2575,6 +2611,7 @@ if __name__ == "__main__": test_upsample() test_forward_upsample3d() test_to() + test_type_as() test_forward_functional_pad() test_forward_zero_pad2d() test_forward_constant_pad1d() -- 2.7.4