self.assertEqual(t1.dtype, t2.dtype)
self.assertEqual(t1.device, t2.device)
+ # adapted from test in test_torch
def test_tensor_to(self):
    """Exercise ``Tensor.to`` overloads from inside TorchScript.

    Adapted from the eager-mode ``to`` test in test_torch.  Every case is
    rendered into a tiny script function via ``template``, compiled with
    ``torch.jit.CompilationUnit`` and run on a tensor.  The results are
    checked for aliasing (same object vs. copy), device and dtype.
    """
    template = dedent('''
    def func(t):
        cuda = "{cuda}"
        device = "{device}"
        non_blocking = {non_blocking}
        return {to_str}
    ''')

    def s(t, to_str, non_blocking=None, device=None, cuda=None):
        # Compile ``to_str`` as the return expression of a script function
        # and evaluate it on ``t``.  Defaults are filled in so that every
        # template placeholder always renders to valid script source.
        device = device if device is not None else str(t.device)
        non_blocking = non_blocking if non_blocking is not None else False
        cuda = "cuda" if cuda is None else cuda
        code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
        cu = torch.jit.CompilationUnit(code)
        return cu.func(t)

    def test_copy_behavior(t, non_blocking=False):
        # Without copy=True a no-op conversion must hand back the very same
        # tensor object; with copy=True it must always be a new one.
        self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
        self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
        self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
        self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
        self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
        self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))

        # Also try the alternative spelling of the tensor's own CUDA device
        # ('cuda' <-> 'cuda:N') where the two name the same device.
        devices = [t.device]
        if t.device.type == 'cuda':
            if t.device.index == -1:
                devices.append('cuda:{}'.format(torch.cuda.current_device()))
            elif t.device.index == torch.cuda.current_device():
                devices.append('cuda')
        for device in devices:
            self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
            self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
            self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
            self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
                                  non_blocking, device))

    t = torch.tensor(5)
    test_copy_behavior(t)

    self.assertEqual(t.device, s(t, "t.to('cpu')").device)
    self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
    self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
    self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
    self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
    self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
    self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
    self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
    self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())

    a = torch.tensor(5)
    if torch.cuda.is_available():
        for non_blocking in [True, False]:
            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
                b = torch.tensor(5., device=cuda)
                test_copy_behavior(b, non_blocking)
                # Forward the loop's ``non_blocking`` into the script;
                # otherwise the template always renders it as False.
                self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device",
                                             non_blocking=non_blocking, cuda=cuda))
                self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device",
                                             non_blocking=non_blocking))
                # CPU -> CUDA direction: move the CPU tensor ``a`` onto ``cuda``.
                self.assertEqual(b.device, s(a, "t.to(cuda, non_blocking=non_blocking).device",
                                             non_blocking=non_blocking, cuda=cuda))
                self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)",
                                             non_blocking=non_blocking).dtype)
                self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)",
                                             non_blocking=non_blocking).device)
                self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
                self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
+
@unittest.skipIf(not RUN_CUDA, "No CUDA")
@skipIfRocm
def test_tensor_number_math_cuda(self):
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/jit_exception.h>
+#include <aten/src/ATen/Context.h>
#include <ATen/ExpandUtils.h>
#include <ATen/WrapDimUtils.h>
}
}
+// reference function THPVariable_to in python_variable_methods.cpp
+static at::Tensor to_dispatch(at::Tensor self, c10::optional<at::Device> device,
+ c10::optional<at::ScalarType> scalarType, bool non_blocking, bool copy) {
+ if (device && device->is_cuda()) {
+ at::globalContext().lazyInitCUDA();
+ }
+ if (!device && !scalarType && !copy) {
+ return self;
+ } else if (!device) {
+ return self.to(*scalarType, non_blocking, copy);
+ } else if (!scalarType) {
+ return self.to(*device, non_blocking, copy);
+ } else {
+ return self.to(*device, *scalarType, non_blocking, copy);
+ }
+}
+
RegisterOperators reg({
Operator(
prim::FusionGroup,
return 0;
};
}),
+ // reference function parse_to_conversion in python_arg_parsing.h
+ Operator(
+ "aten::to(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a)",
+ [](const Node* node) -> Operation {
+ return [](Stack& stack) {
+ bool non_blocking;
+ bool copy;
+ pop(stack, non_blocking, copy);
+ c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>();
+ c10::optional<c10::Device> device = pop(stack).toOptional<c10::Device>();
+ at::Tensor self = pop(stack).toTensor();
+ push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+ return 0;
+ };
+ }),
+ Operator(
+ "aten::to(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a)",
+ [](const Node* node) -> Operation {
+ return [](Stack& stack) {
+ bool non_blocking;
+ bool copy;
+ pop(stack, non_blocking, copy);
+ c10::optional<at::ScalarType> scalarType = pop(stack).toOptional<at::ScalarType>();
+ c10::optional<c10::Device> device = c10::nullopt;
+ at::Tensor self = pop(stack).toTensor();
+ push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+ return 0;
+ };
+ }),
+ Operator(
+ "aten::to(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a)",
+ [](const Node* node) -> Operation {
+ return [](Stack& stack) {
+ at::Tensor self;
+ bool non_blocking;
+ bool copy;
+ pop(stack, self, non_blocking, copy);
+ c10::optional<c10::Device> device = c10::nullopt;
+ c10::optional<at::ScalarType> scalarType = c10::nullopt;
+ push(stack, to_dispatch(self, device, scalarType, non_blocking, copy));
+ return 0;
+ };
+ }),
Operator(
"aten::eq(Device a, Device b) -> bool",
[](const Node* node) -> Operation {