From b4d2df1fee35e9f2e8fb01297261e6c19d568e75 Mon Sep 17 00:00:00 2001 From: Iurii Zdebskyi Date: Thu, 4 Apr 2019 13:01:10 -0700 Subject: [PATCH] Added bool and half support for resize_as_ and view methods (#18821) Summary: Enabled **resize_as_** and **view** methods for bool and half tensors. tested via unit tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/18821 Reviewed By: ezyang Differential Revision: D14762852 Pulled By: izdeby fbshipit-source-id: 4312079fb4e893fea6f71ff4f163094b2674f1e8 --- aten/src/ATen/Declarations.cwrap | 6 ++++++ aten/src/ATen/native/native_functions.yaml | 6 ++++++ test/test_torch.py | 14 ++++++++++++++ 3 files changed, 26 insertions(+) diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index a1bf2b3..ec3a7e2 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -136,6 +136,9 @@ [[ name: _th_view cname: newView + cpu_half: True + cpu_bool: True + cuda_bool: True variants: - function device_guard: False @@ -148,6 +151,9 @@ [[ name: _th_resize_as_ cname: resizeAs + cpu_half: True + cpu_bool: True + cuda_bool: True variants: - function return: self diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index e0c497a..7d57faa 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2567,6 +2567,9 @@ - func: resize_as_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) matches_jit_signature: True + cpu_bool: True + cuda_bool: True + cpu_half: True variants: function, method dispatch: CPU: resize_as_ @@ -3219,6 +3222,9 @@ variants: function, method - func: view(Tensor(a) self, int[] size) -> Tensor(a) + cpu_half: True + cpu_bool: True + cuda_bool: True matches_jit_signature: True variants: method device_guard: False diff --git a/test/test_torch.py b/test/test_torch.py index 443b8db..bf486c6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -3046,6 +3046,20 @@ class _TestTorchMixin(object): x.resize_(shape) self.assertEqual(shape, x.shape) + def test_resize_as_all_dtypes_and_devices(self): + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) + x.resize_as_(y) + self.assertEqual(y.shape, x.shape) + + def test_view_all_dtypes_and_devices(self): + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + self.assertEqual(x.view(6).shape, [6]) + def test_fill_all_dtypes_and_devices(self): for device in torch.testing.get_all_device_types(): for dt in torch.testing.get_all_dtypes(): -- 2.7.4