support aten::type_as in the pytorch frontend (#5787)
author Rand Xie <randxiexyy29@gmail.com>
Sat, 13 Jun 2020 04:52:45 +0000 (21:52 -0700)
committer GitHub <noreply@github.com>
Sat, 13 Jun 2020 04:52:45 +0000 (13:52 +0900)
* support aten::type_as in the pytorch frontend

* use _convert_data_type to convert the torch type to a TVM type, and cover more types in the type_as test
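
For context, torch.Tensor.type_as(other) casts a tensor to the dtype of
another tensor; the converter added here maps it to a relay cast via
_op.cast. A minimal sketch of the PyTorch semantics the converter must
match (illustrative only, not part of the patch):

    import torch

    x = torch.randn(1, 3)                       # float32 by default
    ref = torch.zeros(1, 3, dtype=torch.int32)
    y = x.type_as(ref)                          # cast x to ref's dtype
    assert y.dtype == torch.int32

Note that the converter handles only the dtype; it does not model any
device movement that torch.Tensor.type() can imply.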

python/tvm/relay/frontend/pytorch.py
tests/python/frontend/pytorch/test_forward.py

index a9f4a7b..d2451cd 100644 (file)
@@ -1645,6 +1645,14 @@ def _list_len(prelude):
     return _impl
 
 
+def _type_as():
+    def _impl(inputs, input_types):
+        assert len(inputs) == 2  # the tensor to cast and the reference tensor
+        assert len(input_types) == 2
+        return _op.cast(inputs[0], _convert_data_type(input_types[1]))
+    return _impl
+
+
 def _add(prelude):
     # add_ is overloaded for tensor add and list concat
     def _impl(inputs, input_types):
@@ -1953,6 +1961,7 @@ def _get_convert_map(prelude):
         "aten::stack"                           : _tensor_array_stack(prelude),
         "aten::__getitem__"                     : _list_getitem(prelude),
         "aten::len"                             : _list_len(prelude),
+        "aten::type_as"                         : _type_as(),
     }
     return convert_map
 
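
With the converter registered in the convert map, a traced module that
uses type_as imports end to end. A minimal sketch, assuming the usual
relay.frontend.from_pytorch entry point (the input name "input0" and the
module name TypeAs are illustrative):

    import torch
    from tvm import relay

    class TypeAs(torch.nn.Module):
        def forward(self, x):
            ref = torch.zeros(1, 3, dtype=torch.int32)
            return x.type_as(ref)

    scripted = torch.jit.trace(TypeAs().eval(), torch.randn(1, 3))
    shape_list = [("input0", (1, 3))]
    mod, params = relay.frontend.from_pytorch(scripted, shape_list)
    # aten::type_as now lowers to a cast instead of failing as an
    # unsupported operator.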
index 86fb409..f8fb57f 100644 (file)
@@ -27,6 +27,7 @@ import torchvision
 
 from tvm import relay
 from tvm.contrib import graph_runtime
+from tvm.contrib.nvcc import have_fp16
 from tvm.relay.testing.config import ctx_list
 
 
@@ -837,6 +838,41 @@ def test_forward_size():
     input_data = torch.rand(input_shape).float()
     verify_model(Size1().float().eval(), input_data=input_data)
 
+
+def test_type_as():
+    torch.set_grad_enabled(False)
+    input_shape = [1, 3]
+
+    def _create_module(dtype):
+        class TypeAs(Module):
+            def forward(self, *args):
+                expected_type_tensor = torch.zeros(1, 3, dtype=dtype)
+                return args[0].type_as(expected_type_tensor)
+
+        return TypeAs()
+
+    input_data = torch.randn(input_shape).float()
+    verify_model(_create_module(torch.float64), input_data=input_data)
+    verify_model(_create_module(torch.float32), input_data=input_data)
+    verify_model(_create_module(torch.int64), input_data=input_data)
+    verify_model(_create_module(torch.int32), input_data=input_data)
+    verify_model(_create_module(torch.int16), input_data=input_data)
+    verify_model(_create_module(torch.int8), input_data=input_data)
+
+    if torch.cuda.is_available():
+        check_fp16 = False
+        try:
+            # Only check half precision on supported hardware.
+            if have_fp16(tvm.gpu(0).compute_version):
+                check_fp16 = True
+        except Exception:
+            # If GPU is not enabled in TVM, skip the fp16 test.
+            pass
+
+        if check_fp16:
+            verify_model(_create_module(torch.float16), input_data=input_data)
+
+
 def test_forward_view():
     torch.set_grad_enabled(False)
     input_shape = [1, 3, 10, 10]
@@ -2575,6 +2611,7 @@ if __name__ == "__main__":
     test_upsample()
     test_forward_upsample3d()
     test_to()
+    test_type_as()
     test_forward_functional_pad()
     test_forward_zero_pad2d()
     test_forward_constant_pad1d()