Added bool and half support for resize_as_ and view methods (#18821)
authorIurii Zdebskyi <iuriiz@fb.com>
Thu, 4 Apr 2019 20:01:10 +0000 (13:01 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 20:09:10 +0000 (13:09 -0700)
Summary:
Enabled **resize_as_** and **view** methods for bool and half tensors.
tested via unit tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18821

Reviewed By: ezyang

Differential Revision: D14762852

Pulled By: izdeby

fbshipit-source-id: 4312079fb4e893fea6f71ff4f163094b2674f1e8

aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/native_functions.yaml
test/test_torch.py

index a1bf2b3..ec3a7e2 100644 (file)
 [[
   name: _th_view
   cname: newView
+  cpu_half: True
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   device_guard: False
 [[
   name: _th_resize_as_
   cname: resizeAs
+  cpu_half: True
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: self
index e0c497a..7d57faa 100644 (file)
 
 - func: resize_as_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
   matches_jit_signature: True
+  cpu_bool: True
+  cuda_bool: True
+  cpu_half: True
   variants: function, method
   dispatch:
     CPU: resize_as_
   variants: function, method
 
 - func: view(Tensor(a) self, int[] size) -> Tensor(a)
+  cpu_half: True
+  cpu_bool: True
+  cuda_bool: True
   matches_jit_signature: True
   variants: method
   device_guard: False
index 443b8db..bf486c6 100644 (file)
@@ -3046,6 +3046,20 @@ class _TestTorchMixin(object):
                 x.resize_(shape)
                 self.assertEqual(shape, x.shape)
 
+    def test_resize_as_all_dtypes_and_devices(self):
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+                y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
+                x.resize_as_(y)
+                self.assertEqual(y.shape, x.shape)
+
+    def test_view_all_dtypes_and_devices(self):
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+                self.assertEqual(x.view(6).shape, [6])
+
     def test_fill_all_dtypes_and_devices(self):
         for device in torch.testing.get_all_device_types():
             for dt in torch.testing.get_all_dtypes():