add tensor.to to script (#15976)
author Elias Ellison <eellison@fb.com>
Mon, 14 Jan 2019 23:44:50 +0000 (15:44 -0800)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 14 Jan 2019 23:49:31 +0000 (15:49 -0800)
Summary:
Previously tensor.to was only scriptable with keyword arguments. This registers the positional aten::to overloads (device, dtype, non_blocking, copy), making the scripted call fully compatible with its eager-mode counterpart.

Fix for: https://github.com/pytorch/pytorch/issues/15478
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15976

Differential Revision: D13643979

Pulled By: eellison

fbshipit-source-id: 6a47bce7db362da80452adffebd2732f8e62a240
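
As a rough illustration (a sketch, not part of this commit): with these
overloads registered, positional device/dtype arguments to Tensor.to now
compile under torch.jit.script, e.g.

    import torch

    @torch.jit.script
    def convert(t):
        # type: (Tensor) -> Tensor
        # positional device and dtype arguments, previously script-incompatible
        return t.to('cpu', torch.float32)

    print(convert(torch.tensor(5)).dtype)  # torch.float32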

test/test_jit.py
torch/csrc/jit/register_prim_ops.cpp

diff --git a/test/test_jit.py b/test/test_jit.py
index fe56294..f28b61a 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4644,6 +4644,73 @@ a")
                 self.assertEqual(t1.dtype, t2.dtype)
                 self.assertEqual(t1.device, t2.device)
 
+    # adapted from test in test_torch
+    def test_tensor_to(self):
+        template = dedent('''
+        def func(t):
+            cuda = "{cuda}"
+            device = "{device}"
+            non_blocking = {non_blocking}
+            return {to_str}
+        ''')
+
+        def s(t, to_str, non_blocking=None, device=None, cuda=None):
+            device = device if device is not None else str(t.device)
+            non_blocking = non_blocking if non_blocking is not None else False
+            cuda = "cuda" if cuda is None else cuda
+            code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
+            scope = {}
+            cu = torch.jit.CompilationUnit(code)
+            return cu.func(t)
+
+        def test_copy_behavior(t, non_blocking=False):
+            self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
+            self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
+            self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
+            self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
+            self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
+            self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
+
+            devices = [t.device]
+            if t.device.type == 'cuda':
+                if t.device.index == -1:
+                    devices.append('cuda:{}'.format(torch.cuda.current_device()))
+                elif t.device.index == torch.cuda.current_device():
+                    devices.append('cuda')
+            for device in devices:
+                self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
+                self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
+                self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
+                self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
+                                      non_blocking, device))
+
+        t = torch.tensor(5)
+        test_copy_behavior(t)
+
+        self.assertEqual(t.device, s(t, "t.to('cpu')").device)
+        self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
+        self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
+        self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
+        self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
+        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
+        self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
+        self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
+        self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())
+
+        a = torch.tensor(5)
+        if torch.cuda.is_available():
+            for non_blocking in [True, False]:
+                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
+                    b = torch.tensor(5., device=cuda)
+                    test_copy_behavior(b, non_blocking)
+                    self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
+                    self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
+                    self.assertEqual(b.device, s(a, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
+                    self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
+                    self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
+                    self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
+                    self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
+
     @unittest.skipIf(not RUN_CUDA, "No CUDA")
     @skipIfRocm
     def test_tensor_number_math_cuda(self):
diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp
index e468fa9..ea6e5c6 100644
--- a/torch/csrc/jit/register_prim_ops.cpp
+++ b/torch/csrc/jit/register_prim_ops.cpp
@@ -9,6 +9,7 @@
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/jit_exception.h>
+#include <ATen/Context.h>
 
 #include <ATen/ExpandUtils.h>
 #include <ATen/WrapDimUtils.h>
@@ -86,6 +87,23 @@ static int64_t floordiv(int64_t a, int64_t b) {
   }
 }
 
+// Mirrors the dispatch logic of THPVariable_to in python_variable_methods.cpp.
+static at::Tensor to_dispatch(at::Tensor self, c10::optional<at::Device> device,
+      c10::optional<at::ScalarType> scalarType, bool non_blocking, bool copy) {
+  if (device && device->is_cuda()) {
+    at::globalContext().lazyInitCUDA();
+  }
+  if (!device && !scalarType && !copy) {
+    return self;
+  } else if (!device) {
+    return self.to(*scalarType, non_blocking, copy);
+  } else if (!scalarType) {
+    return self.to(*device, non_blocking, copy);
+  } else {
+    return self.to(*device, *scalarType, non_blocking, copy);
+  }
+}
+
 RegisterOperators reg({
     Operator(
         prim::FusionGroup,
@@ -256,6 +274,49 @@ RegisterOperators reg({
             return 0;
           };
         }),
+    // These overloads mirror parse_to_conversion in python_arg_parsing.h.
+    Operator(
+        "aten::to(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            bool non_blocking;
+            bool copy;
+            pop(stack, non_blocking, copy);
+            c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>();
+            c10::optional<c10::Device> device = pop(stack).toOptional<c10::Device>();
+            at::Tensor self = pop(stack).toTensor();
+            push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+            return 0;
+          };
+        }),
+    Operator(
+        "aten::to(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            bool non_blocking;
+            bool copy;
+            pop(stack, non_blocking, copy);
+            c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>();
+            c10::optional<c10::Device> device = c10::nullopt;
+            at::Tensor self = pop(stack).toTensor();
+            push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+            return 0;
+          };
+        }),
+    Operator(
+        "aten::to(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            at::Tensor self;
+            bool non_blocking;
+            bool copy;
+            pop(stack, self, non_blocking, copy);
+            c10::optional<c10::Device> device = c10::nullopt;
+            c10::optional<at::ScalarType> scalarType = c10::nullopt;
+            push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+            return 0;
+          };
+        }),
     Operator(
         "aten::eq(Device a, Device b) -> bool",
         [](const Node* node) -> Operation {
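
To sanity-check which of the overloads above a scripted call binds to
(again just a sketch, assuming a build containing this change), the
compiled graph can be inspected directly:

    import torch

    @torch.jit.script
    def f(t):
        # type: (Tensor) -> Tensor
        return t.to(dtype=torch.int32)

    # the call appears as an aten::to node using the dtype-only schema above
    print(f.graph)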