Bool Tensor for CUDA (#18166)
authorIurii Zdebskyi <iuriiz@fb.com>
Tue, 2 Apr 2019 23:10:43 +0000 (16:10 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 2 Apr 2019 23:17:05 +0000 (16:17 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18166
ghimport-source-id: a8e2ba2d966e49747a55701c4f6863c5e24d6f14

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18166 Bool Tensor for CUDA**
* #18165 Resolved comments from Bool Tensor for CPU PR
------

This PR enables bool tensor creation and some basic operations for the CPU backend. This is a part of Bool Tensor feature implementation work. The whole plan looks like this:
1. Storage Implementation [Done]
2. Tensor Creation.
a) CPU [Done]
b) CUDA [This PR]
3. Tensor Conversions.
4. Tensor Indexing.
5. Tensor Operations.
6. Back compatibility related changes.

Change:
Enable bool tensor in CUDA with the following operations:

    torch.zeros
    torch.tensor
    torch.ones
    torch.rand/rand_like/randint/randint_like
    torch.full
    torch.full_like
    torch.empty
    torch.empty_like

Tested via unit tests and local scripts.

Differential Revision: D14605104

fbshipit-source-id: b7d7340a7d70edd03a109222d271e68becba762c

20 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/function_wrapper.py
aten/src/ATen/gen.py
aten/src/ATen/native/TensorCompare.cpp
aten/src/ATen/native/cpu/CopyKernel.cpp
aten/src/ATen/native/cuda/CUDAScalar.cu
aten/src/ATen/native/cuda/Copy.cu
aten/src/ATen/native/cuda/TensorCompare.cu
aten/src/ATen/native/native_functions.yaml
aten/src/ATen/native_parse.py
aten/src/ATen/preprocess_declarations.py
aten/src/THC/THCTensorMath.cu
aten/src/THC/THCTensorMath.h
aten/src/THC/THCTensorRandom.cu
aten/src/THC/THCTensorRandom.h
aten/src/THC/generic/THCTensorMath.cu
aten/src/THC/generic/THCTensorMath.h
c10/util/Half.h
test/test_torch.py
torch/testing/__init__.py

index b59bacf..c438886 100644 (file)
@@ -4,6 +4,7 @@
   variants: function
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   device_guard: False
   return: argument 0
   options:
@@ -41,6 +42,7 @@
   variants: function
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   options:
     - arguments:
       - THTensor* self
@@ -57,6 +59,7 @@
     - function
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   device_guard: False
   return: bool
   arguments:
     - function
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   arguments:
     - THTensor* self
 ]]
     - function
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   device_guard: False
   return: argument 0
   arguments:
   return: self
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   arguments:
     - CPU
     - CUDA
   return: self
-  cpu_bool: True
   variants: function
+  cpu_bool: True
+  cuda_bool: True
   options:
     - cname: random
       arguments:
   return: THTensor*
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   options:
   cname: catArray
   variants: [function]
   cpu_half: True
+  cpu_bool: True
+  cuda_bool: True
   return: self
   arguments:
     - arg: THTensor* self
index 61eff13..3694803 100644 (file)
@@ -494,6 +494,7 @@ FunctionOption = TypedDict('FunctionOption', {
     'cpu_half': bool,
     'deprecated': bool,
     'cpu_bool': bool,
+    'cuda_bool': bool,
     # See Note [field_name versus name]
     'field_name': str,
     'formals_list': List[AtFormal],
index 8b0cf0d..356e43f 100644 (file)
@@ -180,7 +180,7 @@ extension_backends = ['MSNPU', 'XLA']
 
 # scalar_name, c_type, accreal, is_floating_type
 scalar_types = [
-    ('Bool', 'uint8_t', 'BoolAccrealNotDefined', False),
+    ('Bool', 'bool', 'BoolAccrealNotDefined', False),
     ('Byte', 'uint8_t', 'Long', False),
     ('Char', 'int8_t', 'Long', False),
     ('Double', 'double', 'Double', True),
index ec7ddb5..9d0c2be 100644 (file)
@@ -15,17 +15,31 @@ void where_cpu(
     const at::Tensor& condition,
     const at::Tensor& self,
     const at::Tensor& other) {
-  at::CPU_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
-      ret,
-      condition,
-      self,
-      other,
-      [](scalar_t& ret_val,
-         const uint8_t& cond_val,
-         const scalar_t& self_val,
-         const scalar_t& other_val) {
-        ret_val = cond_val ? self_val : other_val;
-      });
+  if (condition.scalar_type() == at::ScalarType::Byte) {
+    at::CPU_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
+        ret,
+        condition,
+        self,
+        other,
+        [](scalar_t& ret_val,
+           const uint8_t& cond_val,
+           const scalar_t& self_val,
+           const scalar_t& other_val) {
+          ret_val = cond_val ? self_val : other_val;
+        });
+    } else {
+      at::CPU_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
+          ret,
+          condition,
+          self,
+          other,
+          [](scalar_t& ret_val,
+             const bool& cond_val,
+             const scalar_t& self_val,
+             const scalar_t& other_val) {
+            ret_val = cond_val ? self_val : other_val;
+          });
+    }
 }
 } // namespace
 
@@ -80,7 +94,7 @@ bool is_nonzero(const Tensor& self) {
 }
 
 Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) {
-  if (condition.scalar_type() != ScalarType::Byte) {
+  if (condition.scalar_type() != ScalarType::Byte && condition.scalar_type() != ScalarType::Bool) {
     AT_ERROR("Expected condition to have ScalarType Byte, but got ScalarType ",
                   toString(condition.scalar_type()));
   }
index d8cf81d..4f157bc 100644 (file)
@@ -14,8 +14,8 @@ namespace {
 constexpr int64_t COPY_GRAIN_SIZE = 20000;
 
 static void copy_kernel_impl(Tensor& dst, const Tensor& src) {
-  AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, dst.scalar_type(), "copy_kernel_impl", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(
+    at::ScalarType::Half, at::ScalarType::Bool, dst.scalar_type(), "copy_kernel_impl", [&]() {
       scalar_t* self_ptr = dst.data<scalar_t>();
       scalar_t* src_ptr = src.data<scalar_t>();
 
index 68079ad..691d580 100644 (file)
@@ -9,8 +9,8 @@ namespace native {
 
 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(
+    at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
         scalar_t value;
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();
         AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
index 35dfb9b..9ddab33 100644 (file)
@@ -25,7 +25,17 @@ struct CopyOp {
 #else
           dst_val = static_cast<dst_T>(static_cast<native::inter_copy_type_t<dst_T>>(src_val));
 #endif
-        });
+      });
+  }
+};
+
+template<typename dst_T>
+struct CopyOp<dst_T, bool> {
+  static void apply(Tensor& dst, const Tensor& src) {
+    CUDA_tensor_apply2<dst_T, bool>(
+      dst, src, [] __device__(dst_T & dst_val, const bool& src_val) {
+        dst_val = static_cast<dst_T>(static_cast<native::inter_copy_type_t<dst_T>>(src_val));
+      });
   }
 };
 
@@ -169,7 +179,7 @@ void copy_from_cpu(Tensor& dst, const Tensor& src) {
       cudaMemcpyHostToDevice,
       stream));
   AT_CUDA_CHECK(cudaStreamSynchronize(stream));
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "copy_from_cpu", [&]() {
     copy_device_to_device<scalar_t, scalar_t>(dst, dst_contig);
   });
 }
@@ -202,7 +212,7 @@ void copy_from_cpu_async_(Tensor& dst, const Tensor& src) {
   CUDAGuard device_guard(dst.device());
   CUDAStream stream = getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "copy_from_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data<scalar_t>(),
         src.data<scalar_t>(),
@@ -225,7 +235,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
   CUDAGuard device_guard(src.device());
   CUDAStream stream = getCurrentCUDAStream();
 
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_to_cpu_async", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "copy_to_cpu_async", [&]() {
     AT_CUDA_CHECK(cudaMemcpyAsync(
         dst.data<scalar_t>(),
         src.data<scalar_t>(),
@@ -240,7 +250,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) {
 template <typename dst_T>
 void _copy__cuda(Tensor& dst, const Tensor& src, bool non_blocking) {
   AT_CHECK(dst.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "_copy__cuda", [&]() {
     if (dst.is_cuda() && src.is_cuda()) {
       copy_device_to_device<dst_T, scalar_t>(dst, src);
     } else if (dst.is_cuda()) {
@@ -279,7 +289,7 @@ namespace at {
 namespace native {
 
 Tensor& _s_copy__cuda(Tensor& self, const Tensor& src, bool non_blocking) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "_copy__cuda", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_copy__cuda", [&]() {
     ::_copy__cuda<scalar_t>(self, src, non_blocking);
   });
   return self;
index d4ccc70..8cd8e17 100644 (file)
@@ -10,20 +10,35 @@ void where_cuda(
     const at::Tensor& condition,
     const at::Tensor& self,
     const at::Tensor& other) {
-  // Yes this name is repetitive, but the CPU version is called
-  // CPU_tensor_apply4 and we don't have a CPU namespace or directory.
-  at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
-      ret,
-      condition,
-      self,
-      other,
-      [] __device__(
-          scalar_t & ret_val,
-          const uint8_t& cond_val,
-          const scalar_t& self_val,
-          const scalar_t& other_val) {
-        ret_val = cond_val ? self_val : other_val;
-      });
+  if (condition.scalar_type() == at::ScalarType::Byte) {
+    // Yes this name is repetitive, but the CPU version is called
+    // CPU_tensor_apply4 and we don't have a CPU namespace or directory.
+    at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
+        ret,
+        condition,
+        self,
+        other,
+        [] __device__(
+            scalar_t & ret_val,
+            const uint8_t& cond_val,
+            const scalar_t& self_val,
+            const scalar_t& other_val) {
+          ret_val = cond_val ? self_val : other_val;
+        });
+   } else {
+     at::cuda::CUDA_tensor_apply4<scalar_t, bool, scalar_t, scalar_t>(
+         ret,
+         condition,
+         self,
+         other,
+         [] __device__(
+             scalar_t & ret_val,
+             const bool& cond_val,
+             const scalar_t& self_val,
+             const scalar_t& other_val) {
+           ret_val = cond_val ? self_val : other_val;
+         });
+   }
 }
 } // namespace
 
index 3821597..a25bb58 100644 (file)
   matches_jit_signature: True
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   dispatch:
     CPU: _s_copy__cpu
     CUDA: _s_copy__cuda
 - func: _s_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor
   matches_jit_signature: True
   cpu_half: True
+  cpu_bool: True
+  cuda_bool: True
   dispatch:
     CUDA: _s_copy_from_cuda
 
   matches_jit_signature: True
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   dispatch:
     CPU: _copy_same_type__cpu
 
   matches_jit_signature: True
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   dispatch:
     CPU: empty_cpu
     CUDA: empty_cuda
   matches_jit_signature: True
   variants: method
   cpu_bool: True
+  cuda_bool: True
   cpu_half: True
   device_guard: False
   dispatch:
 
 - func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   cpu_half: True
+  cpu_bool: True
+  cuda_bool: True
   matches_jit_signature: True
   dispatch:
     CPU: empty_strided_cpu
   variants: function, method
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   dispatch:
     CPU: clone
     CUDA: clone
   variants: method, function
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   dispatch:
     CPU: zero_
     CUDA: zero_
   matches_jit_signature: True
   cpu_half: True
   cpu_bool: True
+  cuda_bool: True
   dispatch:
     CPU: _local_scalar_dense_cpu
     CUDA: _local_scalar_dense_cuda
 - func: is_set_to(Tensor self, Tensor tensor) -> bool
   matches_jit_signature: True
   variants: method
+  cpu_bool: True
+  cuda_bool: True
   device_guard: False
 
 - func: masked_fill_(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)
index e93d32e..6217769 100644 (file)
@@ -382,6 +382,7 @@ def run(paths):
                 declaration['matches_jit_signature'] = func.get('matches_jit_signature', False)
                 declaration['cpu_half'] = func.get('cpu_half', False)
                 declaration['cpu_bool'] = func.get('cpu_bool', False)
+                declaration['cuda_bool'] = func.get('cuda_bool', False)
                 declaration['deprecated'] = func.get('deprecated', False)
                 declaration['device_guard'] = func.get('device_guard', True)
                 declaration['arguments'] = func.get('arguments', arguments)
index f167770..59dcf6b 100644 (file)
@@ -57,15 +57,16 @@ def process_types_and_backends(option):
         if arg['type'] == 'THSTensor*':
             pairs.discard(('CUDA', 'Half'))
 
-    # special case remove Half and Bool for cpu unless it is explicitly enabled,
+    # special case remove Half for cpu unless it is explicitly enabled
     if not option.get('cpu_half', False):
         pairs.discard(('CPU', 'Half'))
 
+    # special cases remove bool for cpu and cuda unless it is explicitly enabled
     if not option.get('cpu_bool', False):
         pairs.discard(('CPU', 'Bool'))
 
-    # TODO: remove this hack once support for a bool tensor for CUDA is enabled
-    pairs.discard(('CUDA', 'Bool'))
+    if not option.get('cuda_bool', False):
+        pairs.discard(('CUDA', 'Bool'))
 
     # sort the result for easy reading
     option['backend_type_pairs'] = sorted([p for p in pairs])
index ed6fa56..d44645d 100644 (file)
@@ -111,3 +111,6 @@ struct NonZeroOp
 
 #include <THC/generic/THCTensorMath.cu>
 #include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCTensorMath.cu>
+#include <THC/THCGenerateBoolType.h>
index 87ca4a2..078b1cd 100644 (file)
@@ -7,6 +7,9 @@
 #include <THC/generic/THCTensorMath.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensorMath.h>
+#include <THC/THCGenerateBoolType.h>
+
 #include <THC/generic/THCTensorMathBlas.h>
 #include <THC/THCGenerateAllTypes.h>
 
index 58bbabc..db7f644 100644 (file)
@@ -169,5 +169,8 @@ GENERATE_KERNEL2(generate_cauchy, at::Half, double median, double sigma, float,
 #include <THC/generic/THCTensorRandom.cu>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensorRandom.cu>
+#include <THC/THCGenerateBoolType.h>
+
 #undef GENERATE_KERNEL1
 #undef GENERATE_KERNEL2
index 265b7ce..78d1add 100644 (file)
@@ -6,6 +6,9 @@
 #include <THC/generic/THCTensorRandom.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensorRandom.h>
+#include <THC/THCGenerateBoolType.h>
+
 #include <curand.h>
 #include <curand_kernel.h>
 
index d3dbb9a..286bbe7 100644 (file)
@@ -242,6 +242,8 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
   }
 }
 
+#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
+
 void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
                           THCTensor *self)
 {
@@ -365,3 +367,5 @@ accreal THCTensor_(trace)(THCState *state, THCTensor *src_) {
 }
 
 #endif
+
+#endif
index fd1d420..f6c2d8a 100644 (file)
@@ -4,15 +4,17 @@
 
 THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value);
 THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
-
-THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
 THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
 THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
-THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
+THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
+
+#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
 
+THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
-
 THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self);
 
 #endif
+
+#endif
index 3bb6f48..82df870 100644 (file)
@@ -420,7 +420,7 @@ struct Converter<
 // In some versions of MSVC, there will be a compiler error when building.
 // C4146: unary minus operator applied to unsigned type, result still unsigned
 // C4804: unsafe use of type 'bool' in operation
-// It can be addressed by disabling the following warning. 
+// It can be addressed by disabling the following warning.
 #ifdef _MSC_VER
 #pragma warning( push )
 #pragma warning( disable : 4146 )
@@ -482,7 +482,8 @@ typename std::enable_if<is_complex_t<From>::value, bool>::type overflows(
 
 template <typename To, typename From>
 To checked_convert(From f, const char* name) {
-  if (overflows<To, From>(f)) {
+  // Converting to bool can't overflow so we exclude this case from checking.
+  if (!std::is_same<To, bool>::value && overflows<To, From>(f)) {
     std::ostringstream oss;
     oss << "value cannot be converted to type " << name
         << " without overflow: " << f;
index 7d44307..2f8d1fc 100644 (file)
@@ -365,8 +365,7 @@ class _TestTorchMixin(object):
             self.assertEqual(res1, res2)
 
     def test_logical_any(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device)
 
             self.assertEqual(
@@ -405,8 +404,7 @@ class _TestTorchMixin(object):
             self.assertEqual(y, x.any(2, keepdim=True))
 
     def test_logical_all(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device)
 
             self.assertEqual(
@@ -1129,9 +1127,8 @@ class _TestTorchMixin(object):
             ('logsumexp', torch.logsumexp, -inf),
         ]
 
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
         shape = (2, 0, 4)
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.randn(shape, device=device)
 
             for item in fns_to_test:
@@ -1173,8 +1170,7 @@ class _TestTorchMixin(object):
             self.assertEqual(torch.ones((), device=device), xb.all())
 
     def test_pairwise_distance_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             shape = (2, 0)
             x = torch.randn(shape, device=device)
             y = torch.randn(shape, device=device)
@@ -1189,8 +1185,7 @@ class _TestTorchMixin(object):
             self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True))
 
     def test_pdist_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             shape = (0, 2)
             x = torch.randn(shape, device=device)
             self.assertEqual(torch.empty(0, device=device), torch.pdist(x))
@@ -1213,8 +1208,7 @@ class _TestTorchMixin(object):
             self.assertEqual(expected.shape, actual.shape)
             self.assertTrue(torch.allclose(expected, actual))
 
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             for shape in [(4, 5), (3, 2), (2, 1)]:
                 for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
                     for trans in [False, True]:
@@ -1227,8 +1221,7 @@ class _TestTorchMixin(object):
                 test_pdist_single((1000, 2), device, 2, dtype, False)
 
     def test_cdist_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.randn((0, 5), device=device)
             y = torch.randn((4, 5), device=device)
             self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y))
@@ -1246,8 +1239,7 @@ class _TestTorchMixin(object):
             self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y))
 
     def test_cdist_norm(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             for r1 in [3, 4, 5, 6]:
                 for m in [2, 3, 4, 10]:
                     for r2 in [4, 6, 7, 8]:
@@ -1259,8 +1251,7 @@ class _TestTorchMixin(object):
                             self.assertTrue(torch.allclose(expected, actual))
 
     def test_cdist_large(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.randn(1000, 10, device=device)
             y = torch.randn(1000, 10, device=device)
             actual = torch.cdist(x, y, p=2)
@@ -1268,8 +1259,7 @@ class _TestTorchMixin(object):
             self.assertTrue(torch.allclose(expected, actual))
 
     def test_cdist_non_contiguous(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.randn(5, 7, device=device).t()
             y = torch.randn(5, 3, device=device).t()
             actual = torch.cdist(x, y, p=2)
@@ -1551,7 +1541,7 @@ class _TestTorchMixin(object):
         self._test_neg(self, lambda t: t)
 
     def test_threshold(self):
-        for dtype in torch.testing.get_all_dtypes():
+        for dtype in torch.testing.get_all_math_dtypes('cpu'):
             if dtype != torch.uint8 and dtype != torch.float16:
                 # 100 is wide enough to use AVX2 instructions for all types
                 x = torch.randn(100).sign().to(dtype=dtype)
@@ -1584,7 +1574,7 @@ class _TestTorchMixin(object):
         self.assertEqual(res1, res2)
 
     def test_floordiv(self):
-        for dtype in torch.testing.get_all_dtypes():
+        for dtype in torch.testing.get_all_math_dtypes('cpu'):
             if dtype is torch.float16:
                 continue
             x = torch.randn(100).mul(10).to(dtype)
@@ -1594,7 +1584,7 @@ class _TestTorchMixin(object):
             self.assertEqual(y, z)
 
     def test_rdiv(self):
-        for dtype in torch.testing.get_all_dtypes():
+        for dtype in torch.testing.get_all_math_dtypes('cpu'):
             if dtype is torch.float16:
                 continue
             x = torch.rand(100).add(1).mul(4).to(dtype)
@@ -2323,7 +2313,7 @@ class _TestTorchMixin(object):
             # 'out' is favored over dtype, check error
             self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype))
 
-        for dtype in [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]:
+        for dtype in [dtype for dtype in torch.testing.get_all_math_dtypes('cpu') if dtype != torch.float16]:
             x = torch.ones(shape, dtype=dtype)
             expected_dtype = dtype if dtype.is_floating_point else torch.int64
             self.assertIs(expected_dtype, fn(x).dtype)
@@ -2722,10 +2712,10 @@ class _TestTorchMixin(object):
                     self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
 
     def test_empty_full(self):
-        do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu'))
+        do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cpu'))
         if torch.cuda.device_count() > 0:
-            do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, None)
-            do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cuda:0'))
+            do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, None)
+            do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cuda:0'))
 
     def test_dtype_out_match(self):
         d = torch.autograd.Variable(torch.DoubleTensor(2, 3))
@@ -2923,60 +2913,94 @@ class _TestTorchMixin(object):
         self.assertTrue(x.is_cuda)
         torch.set_default_tensor_type(saved_type)
 
-    # This is a temporary test for a boolean tensors on CPU. Once the CUDA part
-    # will be done, these test cases will be moved down to test_tensor_factories_empty test
-    def test_tensor_factories_bool(self):
-        expectedShape = (1, 2)
-        test = torch.empty(expectedShape, dtype=torch.bool)
-        self.assertEqual(expectedShape, test.shape)
-
-        test2 = torch.empty_like(test, dtype=torch.bool)
-        self.assertEqual(test.shape, test2.shape)
-
-        test = torch.full(expectedShape, True, dtype=torch.bool)
-        self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
-
-        test2 = torch.full_like(test, True, dtype=torch.bool)
-        self.assertEqual(test, test2)
-
-        test = torch.zeros(expectedShape, dtype=torch.bool)
-        self.assertEqual(test, torch.tensor([[False, False]], dtype=torch.bool))
-
-        test2 = torch.zeros_like(test, dtype=torch.bool)
-        self.assertEqual(test, test2)
-
-        test = torch.ones(expectedShape, dtype=torch.bool)
-        self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
-
-        test2 = torch.ones_like(test, dtype=torch.bool)
-        self.assertEqual(test, test2)
+    def test_unfold_all_devices_and_dtypes(self):
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                if dt == torch.half and device == 'cpu':
+                    # fix once random is implemented for Half on CPU
+                    self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device))
+                else:
+                    x = torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device)
+                    self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
 
-        test = torch.randint(10, expectedShape, dtype=torch.bool)
-        self.assertEqual(expectedShape, test.shape)
-        self.assertEqual(torch.bool, test.dtype)
+    def test_copy_all_dtypes_and_devices(self):
+        from copy import copy
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device)
+                x_clone = x.clone()
+
+                y = copy(x)
+                y.fill_(1)
+
+                # copy is a shallow copy, only copies the tensor view,
+                # not the data
+                self.assertEqual(x, y)
+
+    def test_resize_all_dtypes_and_devices(self):
+        shape = (2, 2)
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+                x.resize_(shape)
+                self.assertEqual(shape, x.shape)
+
+    def test_fill_all_dtypes_and_devices(self):
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                x = torch.tensor((1, 1), dtype=dt, device=device)
+                x.fill_(1)
+
+                self.assertEqual(x, torch.tensor([1, 1], dtype=dt, device=device))
+                self.assertEqual(dt, x.dtype)
+
+    def test_clone_all_dtypes_and_devices(self):
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                x = torch.tensor((1, 1), dtype=dt, device=device)
+                y = x.clone()
+                self.assertEqual(x, y)
+
+    def test_cat_all_dtypes_and_devices(self):
+        for device in torch.testing.get_all_device_types():
+            for dt in torch.testing.get_all_dtypes():
+                x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device)
+                expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device)
+                self.assertEqual(torch.cat((x, x), 0), expected1)
+
+                expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dt, device=device)
+                self.assertEqual(torch.cat((x, x), 1), expected2)
 
     def test_tensor_factories_empty(self):
         # ensure we can create empty tensors from each factory function
         shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)]
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
 
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             for shape in shapes:
-                self.assertEqual(shape, torch.zeros(shape, device=device).shape)
-                self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device)).shape)
-                self.assertEqual(shape, torch.empty(shape, device=device).shape)
-                self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device)).shape)
-                self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device).shape)
-                self.assertEqual(shape, torch.full(shape, 3, device=device).shape)
-                self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device), 3).shape)
-                self.assertEqual(shape, torch.ones(shape, device=device).shape)
-                self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device)).shape)
-                self.assertEqual(shape, torch.rand(shape, device=device).shape)
-                self.assertEqual(shape, torch.rand_like(torch.zeros(shape, device=device)).shape)
-                self.assertEqual(shape, torch.randn(shape, device=device).shape)
-                self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device)).shape)
-                self.assertEqual(shape, torch.randint(6, shape, device=device).shape)
-                self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device), 6).shape)
+                for dt in torch.testing.get_all_dtypes():
+                    self.assertEqual(shape, torch.zeros(shape, device=device, dtype=dt).shape)
+                    self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape)
+                    self.assertEqual(shape, torch.full(shape, 3, device=device, dtype=dt).shape)
+                    self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device, dtype=dt), 3).shape)
+                    self.assertEqual(shape, torch.ones(shape, device=device, dtype=dt).shape)
+                    self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device, dtype=dt)).shape)
+                    self.assertEqual(shape, torch.empty(shape, device=device, dtype=dt).shape)
+                    self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape)
+                    self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device, dtype=dt).shape)
+
+                    if dt == torch.half and device == "cpu":
+                        # update once random is implemented for half on CPU
+                        self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt).shape)
+                    else:
+                        self.assertEqual(shape, torch.randint(6, shape, device=device, dtype=dt).shape)
+                        self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 6).shape)
+
+                    if dt != torch.double and dt != torch.float and dt != torch.half:
+                        self.assertRaises(RuntimeError, lambda: torch.rand(shape, device=device, dtype=dt).shape)
+
+                    if dt == torch.double or dt == torch.float:
+                        self.assertEqual(shape, torch.randn(shape, device=device, dtype=dt).shape)
+                        self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device, dtype=dt)).shape)
 
             self.assertEqual((0,), torch.arange(0, device=device).shape)
             self.assertEqual((0, 0), torch.eye(0, device=device).shape)
@@ -3504,8 +3528,7 @@ class _TestTorchMixin(object):
         self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf')))
         self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf')))
 
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device))
             # check with step size
             self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('-inf'), -1, device=device))
@@ -3793,9 +3816,8 @@ class _TestTorchMixin(object):
 
     def test_empty_tensor_props(self):
         sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)]
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
         for size in sizes:
-            for device in devices:
+            for device in torch.testing.get_all_device_types():
                 x = torch.empty(tuple(size), device=device)
                 self.assertEqual(size, x.shape)
                 self.assertTrue(x.is_contiguous())
@@ -4074,8 +4096,7 @@ class _TestTorchMixin(object):
 
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     def test_tensordot(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for d in devices:
+        for d in torch.testing.get_all_device_types():
             a = torch.arange(60., device=d).reshape(3, 4, 5)
             b = torch.arange(24., device=d).reshape(4, 3, 2)
             c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu()
@@ -4489,8 +4510,7 @@ class _TestTorchMixin(object):
         self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]]))
 
     def test_narrow_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.randn(2, 3, 4, device=device)
             for d in range(x.dim()):
                 y = x.narrow(d, x.size(d), 0)
@@ -4547,8 +4567,7 @@ class _TestTorchMixin(object):
 
     @skipIfRocm
     def test_linspace(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             _from = random.random()
             to = _from + random.random()
             res1 = torch.linspace(_from, to, 137, device=device)
@@ -7207,8 +7226,7 @@ class _TestTorchMixin(object):
         check(src.transpose(1, 2), idx)
 
     def test_take_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
                 for indices_shape in [(0,), (0, 1, 2, 0)]:
                     input = torch.empty(input_shape, device=device)
@@ -7237,8 +7255,7 @@ class _TestTorchMixin(object):
         self.assertEqual(dst.tolist(), [[5, 7], [1, 1]])
 
     def test_put_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]:
                 for indices_shape in [(0,), (0, 1, 2, 0)]:
                     for accumulate in [False, True]:
@@ -7766,8 +7783,7 @@ class _TestTorchMixin(object):
         self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
 
     def test_tensor_shape_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.randn((0, 1, 3, 0), device=device)
             # flatten
             self.assertEqual((0,), torch.flatten(x, 0, 3).shape)
@@ -7786,10 +7802,6 @@ class _TestTorchMixin(object):
 
             # select
             self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape)
-            # unfold
-            self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
-            y = torch.randn((0, 1, 3), device=device)
-            self.assertEqual((1, 1, 3, 0), y.unfold(0, 0, 4).shape)
 
             # repeat, permute
             self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape)
@@ -7832,8 +7844,7 @@ class _TestTorchMixin(object):
     # functions that operate over a dimension but don't reduce.
     @skipIfRocm
     def test_dim_function_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             shape = (0, 1, 2, 0)
             x = torch.randn(shape, device=device)
 
@@ -7964,8 +7975,7 @@ class _TestTorchMixin(object):
 
     @skipIfRocm
     def test_blas_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
 
             def fn(torchfn, *args):
                 return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape
@@ -8035,8 +8045,7 @@ class _TestTorchMixin(object):
 
     @skipIfRocm
     def test_blas_alpha_beta_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             # ensure beta is respected
             value = 11
             input = torch.full((2,), value, device=device)
@@ -8066,8 +8075,7 @@ class _TestTorchMixin(object):
         # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing"
         # (e.g. lu).  We often name our functions identically to the lapack function, so it will take work
         # to name / migrate-to better wrappers.
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
 
             # need to init cuda to check has_magma
             empty = torch.randn((0, 0), device=device)
@@ -8203,6 +8211,10 @@ class _TestTorchMixin(object):
                          "Tensors with no storages should not appear to be set "
                          "to each other")
 
+        t1 = torch.tensor([True, True], dtype=torch.bool)
+        t2 = torch.tensor([0], dtype=torch.bool).set_(t1)
+        self.assertTrue(t1.is_set_to(t2))
+
     def test_tensor_set(self):
         t1 = torch.Tensor()
         t2 = torch.Tensor(3, 4, 9, 10).uniform_()
@@ -8234,6 +8246,11 @@ class _TestTorchMixin(object):
         self.assertEqual(t1.size(), size)
         self.assertEqual(t1.stride(), stride)
 
+        t1 = torch.tensor([True, True], dtype=torch.bool)
+        t2 = torch.tensor([False, False], dtype=torch.bool)
+        t1.set_(t2)
+        self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
+
     def test_equal(self):
         # Contiguous, 1D
         t1 = torch.Tensor((3, 4, 9, 10))
@@ -8448,8 +8465,7 @@ class _TestTorchMixin(object):
         self._test_flip(self, use_cuda=False)
 
     def test_roll(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             numbers = torch.arange(1, 9, device=device)
 
             single_roll = numbers.roll(1, 0)
@@ -8616,8 +8632,7 @@ class _TestTorchMixin(object):
                         self.assertNotEqual(tensor[dst1[i, 0], dst1[i, 1], dst1[i, 2]].item(), 0)
 
     def test_nonzero_empty(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             x = torch.randn(0, 2, 0, 5, 0, device=device)
             y = torch.nonzero(x)
             self.assertEqual(0, y.numel())
@@ -8670,16 +8685,6 @@ class _TestTorchMixin(object):
         self.assertEqual(torch.nn.Parameter, type(s2['weight']))
         self.assertEqual(torch.nn.Parameter, type(s2['bias']))
 
-    def test_copy(self):
-        from copy import copy
-        a = torch.randn(5, 5)
-        a_clone = a.clone()
-        b = copy(a)
-        b.fill_(1)
-        # copy is a shallow copy, only copies the tensor view,
-        # not the data
-        self.assertEqual(a, b)
-
     def test_pickle(self):
         if sys.version_info[0] == 2:
             import cPickle as pickle
@@ -9820,8 +9825,7 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             self.assertEqual(torch.empty_like(a).type(), a.type())
 
     def test_empty_strided(self):
-        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
-        for device in devices:
+        for device in torch.testing.get_all_device_types():
             for shape in [(2, 3, 4), (0, 2, 0)]:
                 # some of these cases are pretty strange, just verifying that if as_strided
                 # allows them then empty_strided can as well.
index eb2c920..894d823 100644 (file)
@@ -87,10 +87,24 @@ def make_non_contiguous(tensor):
 
 
 def get_all_dtypes():
-    return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
+    return [torch.uint8, torch.bool, torch.int8, torch.int16, torch.int32, torch.int64,
             torch.float16, torch.float32, torch.float64]
 
 
+def get_all_math_dtypes(device):
+    dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
+              torch.float32, torch.float64]
+
+    # torch.float16 is a math dtype on cuda but not cpu.
+    if device == 'cpu':
+        return dtypes
+    else:
+        return dtypes.append(torch.float16)
+
+
+def get_all_device_types():
+    return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+
 # 'dtype': (rtol, atol)
 _default_tolerances = {
     'float64': (1e-5, 1e-8),  # NumPy default