        self.checkScript(to_device, (torch.ones(3, 4),))

+    def test_tensor_to_cpu(self):
+        def to_cpu(x):
+            return x.cpu()
+
+        x = torch.ones(3, 4)
+        script_fn = torch.jit.script(to_cpu)
+        self.assertEqual(to_cpu(x).device, script_fn(x).device)
+        self.checkScript(to_cpu, (x,))
+
+    @unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
+    def test_tensor_to_cuda(self):
+        def to_cuda(x):
+            return x.cuda()
+
+        x = torch.ones(3, 4)
+        script_fn = torch.jit.script(to_cuda)
+        self.assertEqual(to_cuda(x).device, script_fn(x).device)
+        self.checkScript(to_cuda, (x,))
+
    def test_generic_list_errors(self):
        with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
            @torch.jit.script
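
For reference, a minimal usage sketch of what these tests exercise once the operators below are registered (`roundtrip` is an illustrative name, not part of this patch, and the `.cuda()` step assumes a CUDA build):

    import torch

    @torch.jit.script
    def roundtrip(x):
        # .cuda() and .cpu() now compile in script mode
        return x.cuda().cpu()

    print(roundtrip(torch.ones(3, 4)).device)  # prints: cpu
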
          };
        }),
+    Operator(
+        "aten::cpu(Tensor(a) self) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            at::Tensor a;
+            pop(stack, a);         // pop the input tensor off the interpreter stack
+            push(stack, a.cpu());  // push it back on the CPU (a copy if it was elsewhere)
+            return 0;
+          };
+        }),
+    Operator(
+        "aten::cuda(Tensor(a) self) -> Tensor(a)",
+        [](const Node* node) -> Operation {
+          return [](Stack& stack) {
+            at::Tensor a;
+            pop(stack, a);
+            push(stack, a.cuda());  // same convention, moving to the current CUDA device
+            return 0;
+          };
+        }),
+    Operator(
        "prim::Undefined() -> Tensor",
        [](const Node* node) {
          return [](Stack& stack) {
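
Note: both new operators follow the stack-machine convention used throughout this file: inputs are popped off the Stack, outputs are pushed back, and the Operation returns 0 on success. A sketch of the same pattern for a hypothetical identity op (the schema and name are illustrative only, not part of this patch):

    Operator(
        "aten::my_identity(Tensor(a) self) -> Tensor(a)",  // hypothetical schema
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            at::Tensor a;
            pop(stack, a);   // pop the single input
            push(stack, a);  // push it back unchanged
            return 0;        // 0 signals normal completion
          };
        }),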