Add bindings for .cpu() & .cuda() to script (#15904)
author Elias Ellison <eellison@fb.com>
Fri, 11 Jan 2019 18:00:37 +0000 (10:00 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 11 Jan 2019 18:04:08 +0000 (10:04 -0800)
Summary:
Adding bindings for .cpu() and .cuda() to script.

It's worth noting that if the device remains unchanged, then the returned tensor aliases the input; if the device does change, then the result does not alias the input (hence the Tensor(a) alias annotations in the operator schemas below).
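
For illustration only (not part of this patch), a minimal sketch of the aliasing behavior once these bindings exist, assuming a CPU-only run:

    import torch

    @torch.jit.script
    def to_cpu(x):
        return x.cpu()

    x = torch.ones(3, 4)            # already on the CPU
    y = to_cpu(x)                   # device unchanged, so y aliases x
    y[0, 0] = 42.0                  # write through the alias
    assert x[0, 0].item() == 42.0   # the write is visible through x

Moving across devices (e.g. calling .cuda() on a CPU tensor) copies the data instead, so in that case the result does not share storage with the input.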
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15904

Differential Revision: D13632879

Pulled By: eellison

fbshipit-source-id: 024a04f267909674aa1e510562efd9cb081f407c

test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp

diff --git a/test/test_jit.py b/test/test_jit.py
index 3cce42f..c6cbe30 100644
@@ -2914,6 +2914,25 @@ class TestScript(JitTestCase):
 
         self.checkScript(to_device, (torch.ones(3, 4),))
 
+    def test_tensor_to_cpu(self):
+        def to_cpu(x):
+            return x.cpu()
+
+        x = torch.ones(3, 4)
+        script_fn = torch.jit.script(to_cpu)
+        self.assertEqual(to_cpu(x).device, script_fn(x).device)
+        self.checkScript(to_cpu, (x,))
+
+    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
+    def test_tensor_to_cuda(self):
+        def to_cuda(x):
+            return x.cuda()
+
+        x = torch.ones(3, 4)
+        script_fn = torch.jit.script(to_cuda)
+        self.assertEqual(to_cuda(x).device, script_fn(x).device)
+        self.checkScript(to_cuda, (x,))
+
     def test_generic_list_errors(self):
         with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
             @torch.jit.script
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index ee2ac40..e468fa9 100644
@@ -315,6 +315,26 @@ RegisterOperators reg({
           };
         }),
     Operator(
+        "aten::cpu(Tensor(a) self) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            at::Tensor a;
+            pop(stack, a);
+            push(stack, a.cpu());
+            return 0;
+          };
+        }),
+    Operator(
+        "aten::cuda(Tensor(a) self) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            at::Tensor a;
+            pop(stack, a);
+            push(stack, a.cuda());
+            return 0;
+          };
+        }),
+    Operator(
         "prim::Undefined() -> Tensor",
         [](const Node* node) {
           return [](Stack& stack) {