Fix pytorch frontend prim::Constant issue (#6051)
authorJinyu Xie <xjygr08@gmail.com>
Tue, 14 Jul 2020 06:51:54 +0000 (02:51 -0400)
committerGitHub <noreply@github.com>
Tue, 14 Jul 2020 06:51:54 +0000 (15:51 +0900)
python/tvm/relay/frontend/pytorch.py

index 0e04ffd..57b64ac 100644 (file)
@@ -2172,6 +2172,8 @@ def _get_constant(node):
             return node.f(attr_name)
         elif ty in ["TensorType", "CompleteTensorType"]:
             tensor = node.t(attr_name)
+            if tensor.is_cuda:
+                tensor = tensor.cpu()
             if len(tensor.shape) == 0:  # tensor(0.1)
                 # TODO(t-vi): When is this needed?
                 return tensor.item()