From b832b99afb241a8b6ea9fc34698d1f3bfd451f00 Mon Sep 17 00:00:00 2001 From: Iurii Zdebskyi Date: Tue, 2 Apr 2019 16:10:43 -0700 Subject: [PATCH] Bool Tensor for CUDA (#18166) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18166 ghimport-source-id: a8e2ba2d966e49747a55701c4f6863c5e24d6f14 Stack from [ghstack](https://github.com/ezyang/ghstack): * **#18166 Bool Tensor for CUDA** * #18165 Resolved comments from Bool Tensor for CPU PR ------ This PR enables bool tensor creation and some basic operations for the CUDA backend. This is a part of Bool Tensor feature implementation work. The whole plan looks like this: 1. Storage Implementation [Done] 2. Tensor Creation. a) CPU [Done] b) CUDA [This PR] 3. Tensor Conversions. 4. Tensor Indexing. 5. Tensor Operations. 6. Back compatibility related changes. Change: Enable bool tensor in CUDA with the following operations: torch.zeros torch.tensor torch.ones torch.rand/rand_like/randint/randint_like torch.full torch.full_like torch.empty torch.empty_like Tested via unit tests and local scripts. Differential Revision: D14605104 fbshipit-source-id: b7d7340a7d70edd03a109222d271e68becba762c --- aten/src/ATen/Declarations.cwrap | 12 +- aten/src/ATen/function_wrapper.py | 1 + aten/src/ATen/gen.py | 2 +- aten/src/ATen/native/TensorCompare.cpp | 38 +++-- aten/src/ATen/native/cpu/CopyKernel.cpp | 4 +- aten/src/ATen/native/cuda/CUDAScalar.cu | 4 +- aten/src/ATen/native/cuda/Copy.cu | 22 ++- aten/src/ATen/native/cuda/TensorCompare.cu | 43 ++++-- aten/src/ATen/native/native_functions.yaml | 13 ++ aten/src/ATen/native_parse.py | 1 + aten/src/ATen/preprocess_declarations.py | 7 +- aten/src/THC/THCTensorMath.cu | 3 + aten/src/THC/THCTensorMath.h | 3 + aten/src/THC/THCTensorRandom.cu | 3 + aten/src/THC/THCTensorRandom.h | 3 + aten/src/THC/generic/THCTensorMath.cu | 4 + aten/src/THC/generic/THCTensorMath.h | 10 +- c10/util/Half.h | 5 +- test/test_torch.py | 240 +++++++++++++++-------------- torch/testing/__init__.py | 16 +- 20 files changed, 268 insertions(+), 166 deletions(-) diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index b59bacf..c438886 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -4,6 +4,7 @@ variants: function cpu_half: True cpu_bool: True + cuda_bool: True device_guard: False return: argument 0 options: @@ -41,6 +42,7 @@ variants: function cpu_half: True cpu_bool: True + cuda_bool: True options: - arguments: - THTensor* self @@ -57,6 +59,7 @@ - function cpu_half: True cpu_bool: True + cuda_bool: True device_guard: False return: bool arguments: @@ -124,6 +127,7 @@ - function cpu_half: True cpu_bool: True + cuda_bool: True arguments: - THTensor* self ]] @@ -243,6 +247,7 @@ - function cpu_half: True cpu_bool: True + cuda_bool: True device_guard: False return: argument 0 arguments: @@ -1669,6 +1674,7 @@ return: self cpu_half: True cpu_bool: True + cuda_bool: True variants: - function arguments: @@ -2461,8 +2467,9 @@ - CPU - CUDA return: self - cpu_bool: True variants: function + cpu_bool: True + cuda_bool: True options: - cname: random arguments: @@ -2726,6 +2733,7 @@ return: THTensor* cpu_half: True cpu_bool: True + cuda_bool: True variants: - function options: @@ -2750,6 +2758,8 @@ cname: catArray variants: [function] cpu_half: True + cpu_bool: True + cuda_bool: True return: self arguments: - arg: THTensor* self diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index 61eff13..3694803 100644 ---
a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -494,6 +494,7 @@ FunctionOption = TypedDict('FunctionOption', { 'cpu_half': bool, 'deprecated': bool, 'cpu_bool': bool, + 'cuda_bool': bool, # See Note [field_name versus name] 'field_name': str, 'formals_list': List[AtFormal], diff --git a/aten/src/ATen/gen.py b/aten/src/ATen/gen.py index 8b0cf0d..356e43f 100644 --- a/aten/src/ATen/gen.py +++ b/aten/src/ATen/gen.py @@ -180,7 +180,7 @@ extension_backends = ['MSNPU', 'XLA'] # scalar_name, c_type, accreal, is_floating_type scalar_types = [ - ('Bool', 'uint8_t', 'BoolAccrealNotDefined', False), + ('Bool', 'bool', 'BoolAccrealNotDefined', False), ('Byte', 'uint8_t', 'Long', False), ('Char', 'int8_t', 'Long', False), ('Double', 'double', 'Double', True), diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index ec7ddb5..9d0c2be 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -15,17 +15,31 @@ void where_cpu( const at::Tensor& condition, const at::Tensor& self, const at::Tensor& other) { - at::CPU_tensor_apply4( - ret, - condition, - self, - other, - [](scalar_t& ret_val, - const uint8_t& cond_val, - const scalar_t& self_val, - const scalar_t& other_val) { - ret_val = cond_val ? self_val : other_val; - }); + if (condition.scalar_type() == at::ScalarType::Byte) { + at::CPU_tensor_apply4( + ret, + condition, + self, + other, + [](scalar_t& ret_val, + const uint8_t& cond_val, + const scalar_t& self_val, + const scalar_t& other_val) { + ret_val = cond_val ? self_val : other_val; + }); + } else { + at::CPU_tensor_apply4( + ret, + condition, + self, + other, + [](scalar_t& ret_val, + const bool& cond_val, + const scalar_t& self_val, + const scalar_t& other_val) { + ret_val = cond_val ? 
self_val : other_val; + }); + } } } // namespace @@ -80,7 +94,7 @@ bool is_nonzero(const Tensor& self) { } Tensor where(const Tensor& condition, const Tensor& self, const Tensor& other) { - if (condition.scalar_type() != ScalarType::Byte) { + if (condition.scalar_type() != ScalarType::Byte && condition.scalar_type() != ScalarType::Bool) { AT_ERROR("Expected condition to have ScalarType Byte, but got ScalarType ", toString(condition.scalar_type())); } diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index d8cf81d..4f157bc 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -14,8 +14,8 @@ namespace { constexpr int64_t COPY_GRAIN_SIZE = 20000; static void copy_kernel_impl(Tensor& dst, const Tensor& src) { - AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, dst.scalar_type(), "copy_kernel_impl", [&]() { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::Bool, dst.scalar_type(), "copy_kernel_impl", [&]() { scalar_t* self_ptr = dst.data(); scalar_t* src_ptr = src.data(); diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 68079ad..691d580 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -9,8 +9,8 @@ namespace native { Scalar _local_scalar_dense_cuda(const Tensor& self) { Scalar r; - AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, self.scalar_type(), "_local_scalar_dense_cuda", [&] { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_local_scalar_dense_cuda", [&] { scalar_t value; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream)); diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 35dfb9b..9ddab33 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -25,7 +25,17 @@ struct CopyOp { #else dst_val = static_cast(static_cast>(src_val)); #endif - }); + }); + } +}; + +template +struct CopyOp { + static void apply(Tensor& dst, const Tensor& src) { + CUDA_tensor_apply2( + dst, src, [] __device__(dst_T & dst_val, const bool& src_val) { + dst_val = static_cast(static_cast>(src_val)); + }); } }; @@ -169,7 +179,7 @@ void copy_from_cpu(Tensor& dst, const Tensor& src) { cudaMemcpyHostToDevice, stream)); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "copy_from_cpu", [&]() { copy_device_to_device(dst, dst_contig); }); } @@ -202,7 +212,7 @@ void copy_from_cpu_async_(Tensor& dst, const Tensor& src) { CUDAGuard device_guard(dst.device()); CUDAStream stream = getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_from_cpu_async", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "copy_from_cpu_async", [&]() { AT_CUDA_CHECK(cudaMemcpyAsync( dst.data(), src.data(), @@ -225,7 +235,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) { CUDAGuard device_guard(src.device()); CUDAStream stream = getCurrentCUDAStream(); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "copy_to_cpu_async", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, 
src.scalar_type(), "copy_to_cpu_async", [&]() { AT_CUDA_CHECK(cudaMemcpyAsync( dst.data(), src.data(), @@ -240,7 +250,7 @@ void copy_to_cpu_async_(Tensor& dst, const Tensor& src) { template void _copy__cuda(Tensor& dst, const Tensor& src, bool non_blocking) { AT_CHECK(dst.numel() == src.numel(), "sizes do not match"); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "_copy__cuda", [&]() { if (dst.is_cuda() && src.is_cuda()) { copy_device_to_device(dst, src); } else if (dst.is_cuda()) { @@ -279,7 +289,7 @@ namespace at { namespace native { Tensor& _s_copy__cuda(Tensor& self, const Tensor& src, bool non_blocking) { - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), "_copy__cuda", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "_copy__cuda", [&]() { ::_copy__cuda(self, src, non_blocking); }); return self; diff --git a/aten/src/ATen/native/cuda/TensorCompare.cu b/aten/src/ATen/native/cuda/TensorCompare.cu index d4ccc70..8cd8e17 100644 --- a/aten/src/ATen/native/cuda/TensorCompare.cu +++ b/aten/src/ATen/native/cuda/TensorCompare.cu @@ -10,20 +10,35 @@ void where_cuda( const at::Tensor& condition, const at::Tensor& self, const at::Tensor& other) { - // Yes this name is repetitive, but the CPU version is called - // CPU_tensor_apply4 and we don't have a CPU namespace or directory. - at::cuda::CUDA_tensor_apply4( - ret, - condition, - self, - other, - [] __device__( - scalar_t & ret_val, - const uint8_t& cond_val, - const scalar_t& self_val, - const scalar_t& other_val) { - ret_val = cond_val ? self_val : other_val; - }); + if (condition.scalar_type() == at::ScalarType::Byte) { + // Yes this name is repetitive, but the CPU version is called + // CPU_tensor_apply4 and we don't have a CPU namespace or directory. + at::cuda::CUDA_tensor_apply4( + ret, + condition, + self, + other, + [] __device__( + scalar_t & ret_val, + const uint8_t& cond_val, + const scalar_t& self_val, + const scalar_t& other_val) { + ret_val = cond_val ? self_val : other_val; + }); + } else { + at::cuda::CUDA_tensor_apply4( + ret, + condition, + self, + other, + [] __device__( + scalar_t & ret_val, + const bool& cond_val, + const scalar_t& self_val, + const scalar_t& other_val) { + ret_val = cond_val ? self_val : other_val; + }); + } } } // namespace diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 3821597..a25bb58 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -556,6 +556,7 @@ matches_jit_signature: True cpu_half: True cpu_bool: True + cuda_bool: True dispatch: CPU: _s_copy__cpu CUDA: _s_copy__cuda @@ -563,6 +564,8 @@ - func: _s_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor matches_jit_signature: True cpu_half: True + cpu_bool: True + cuda_bool: True dispatch: CUDA: _s_copy_from_cuda @@ -570,6 +573,7 @@ matches_jit_signature: True cpu_half: True cpu_bool: True + cuda_bool: True dispatch: CPU: _copy_same_type__cpu @@ -843,6 +847,7 @@ matches_jit_signature: True cpu_half: True cpu_bool: True + cuda_bool: True dispatch: CPU: empty_cpu CUDA: empty_cuda @@ -853,6 +858,7 @@ matches_jit_signature: True variants: method cpu_bool: True + cuda_bool: True cpu_half: True device_guard: False dispatch: @@ -873,6 +879,8 @@ - func: empty_strided(int[] size, int[] stride, *, ScalarType? 
dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor cpu_half: True + cpu_bool: True + cuda_bool: True matches_jit_signature: True dispatch: CPU: empty_strided_cpu @@ -2536,6 +2544,7 @@ variants: function, method cpu_half: True cpu_bool: True + cuda_bool: True dispatch: CPU: clone CUDA: clone @@ -2573,6 +2582,7 @@ variants: method, function cpu_half: True cpu_bool: True + cuda_bool: True dispatch: CPU: zero_ CUDA: zero_ @@ -3025,6 +3035,7 @@ matches_jit_signature: True cpu_half: True cpu_bool: True + cuda_bool: True dispatch: CPU: _local_scalar_dense_cpu CUDA: _local_scalar_dense_cuda @@ -3146,6 +3157,8 @@ - func: is_set_to(Tensor self, Tensor tensor) -> bool matches_jit_signature: True variants: method + cpu_bool: True + cuda_bool: True device_guard: False - func: masked_fill_(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py index e93d32e..6217769 100644 --- a/aten/src/ATen/native_parse.py +++ b/aten/src/ATen/native_parse.py @@ -382,6 +382,7 @@ def run(paths): declaration['matches_jit_signature'] = func.get('matches_jit_signature', False) declaration['cpu_half'] = func.get('cpu_half', False) declaration['cpu_bool'] = func.get('cpu_bool', False) + declaration['cuda_bool'] = func.get('cuda_bool', False) declaration['deprecated'] = func.get('deprecated', False) declaration['device_guard'] = func.get('device_guard', True) declaration['arguments'] = func.get('arguments', arguments) diff --git a/aten/src/ATen/preprocess_declarations.py b/aten/src/ATen/preprocess_declarations.py index f167770..59dcf6b 100644 --- a/aten/src/ATen/preprocess_declarations.py +++ b/aten/src/ATen/preprocess_declarations.py @@ -57,15 +57,16 @@ def process_types_and_backends(option): if arg['type'] == 'THSTensor*': pairs.discard(('CUDA', 'Half')) - # special case remove Half and Bool for cpu unless it is explicitly enabled, + # special case remove Half for cpu unless it is explicitly enabled if not option.get('cpu_half', False): pairs.discard(('CPU', 'Half')) + # special cases remove bool for cpu and cuda unless it is explicitly enabled if not option.get('cpu_bool', False): pairs.discard(('CPU', 'Bool')) - # TODO: remove this hack once support for a bool tensor for CUDA is enabled - pairs.discard(('CUDA', 'Bool')) + if not option.get('cuda_bool', False): + pairs.discard(('CUDA', 'Bool')) # sort the result for easy reading option['backend_type_pairs'] = sorted([p for p in pairs]) diff --git a/aten/src/THC/THCTensorMath.cu b/aten/src/THC/THCTensorMath.cu index ed6fa56..d44645d 100644 --- a/aten/src/THC/THCTensorMath.cu +++ b/aten/src/THC/THCTensorMath.cu @@ -111,3 +111,6 @@ struct NonZeroOp #include #include + +#include +#include diff --git a/aten/src/THC/THCTensorMath.h b/aten/src/THC/THCTensorMath.h index 87ca4a2..078b1cd 100644 --- a/aten/src/THC/THCTensorMath.h +++ b/aten/src/THC/THCTensorMath.h @@ -7,6 +7,9 @@ #include #include +#include +#include + #include #include diff --git a/aten/src/THC/THCTensorRandom.cu b/aten/src/THC/THCTensorRandom.cu index 58bbabc..db7f644 100644 --- a/aten/src/THC/THCTensorRandom.cu +++ b/aten/src/THC/THCTensorRandom.cu @@ -169,5 +169,8 @@ GENERATE_KERNEL2(generate_cauchy, at::Half, double median, double sigma, float, #include #include +#include +#include + #undef GENERATE_KERNEL1 #undef GENERATE_KERNEL2 diff --git a/aten/src/THC/THCTensorRandom.h b/aten/src/THC/THCTensorRandom.h index 265b7ce..78d1add 100644 --- a/aten/src/THC/THCTensorRandom.h +++ 
b/aten/src/THC/THCTensorRandom.h @@ -6,6 +6,9 @@ #include #include +#include +#include + #include #include diff --git a/aten/src/THC/generic/THCTensorMath.cu b/aten/src/THC/generic/THCTensorMath.cu index d3dbb9a..286bbe7 100644 --- a/aten/src/THC/generic/THCTensorMath.cu +++ b/aten/src/THC/generic/THCTensorMath.cu @@ -242,6 +242,8 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result, } } +#if !defined(THC_REAL_IS_BOOL) /* non bool only part */ + void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self) { @@ -365,3 +367,5 @@ accreal THCTensor_(trace)(THCState *state, THCTensor *src_) { } #endif + +#endif diff --git a/aten/src/THC/generic/THCTensorMath.h b/aten/src/THC/generic/THCTensorMath.h index fd1d420..f6c2d8a 100644 --- a/aten/src/THC/generic/THCTensorMath.h +++ b/aten/src/THC/generic/THCTensorMath.h @@ -4,15 +4,17 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value); THC_API void THCTensor_(zero)(THCState *state, THCTensor *self); - -THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension); THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension); -THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self); +THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t); + +#if !defined(THC_REAL_IS_BOOL) /* non bool only part */ +THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self); THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k); THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k); - THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self); #endif + +#endif diff --git a/c10/util/Half.h b/c10/util/Half.h index 3bb6f48..82df870 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -420,7 +420,7 @@ struct Converter< // In some versions of MSVC, there will be a compiler error when building. // C4146: unary minus operator applied to unsigned type, result still unsigned // C4804: unsafe use of type 'bool' in operation -// It can be addressed by disabling the following warning. +// It can be addressed by disabling the following warning. #ifdef _MSC_VER #pragma warning( push ) #pragma warning( disable : 4146 ) @@ -482,7 +482,8 @@ typename std::enable_if::value, bool>::type overflows( template To checked_convert(From f, const char* name) { - if (overflows(f)) { + // Converting to bool can't overflow so we exclude this case from checking. 
+ if (!std::is_same::value && overflows(f)) { std::ostringstream oss; oss << "value cannot be converted to type " << name << " without overflow: " << f; diff --git a/test/test_torch.py b/test/test_torch.py index 7d44307..2f8d1fc 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -365,8 +365,7 @@ class _TestTorchMixin(object): self.assertEqual(res1, res2) def test_logical_any(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.zeros([2, 3, 400], dtype=torch.uint8, device=device) self.assertEqual( @@ -405,8 +404,7 @@ class _TestTorchMixin(object): self.assertEqual(y, x.any(2, keepdim=True)) def test_logical_all(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.ones([2, 3, 400], dtype=torch.uint8, device=device) self.assertEqual( @@ -1129,9 +1127,8 @@ class _TestTorchMixin(object): ('logsumexp', torch.logsumexp, -inf), ] - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] shape = (2, 0, 4) - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.randn(shape, device=device) for item in fns_to_test: @@ -1173,8 +1170,7 @@ class _TestTorchMixin(object): self.assertEqual(torch.ones((), device=device), xb.all()) def test_pairwise_distance_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): shape = (2, 0) x = torch.randn(shape, device=device) y = torch.randn(shape, device=device) @@ -1189,8 +1185,7 @@ class _TestTorchMixin(object): self.assertEqual(torch.zeros((0, 1), device=device), torch.pairwise_distance(x, y, keepdim=True)) def test_pdist_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): shape = (0, 2) x = torch.randn(shape, device=device) self.assertEqual(torch.empty(0, device=device), torch.pdist(x)) @@ -1213,8 +1208,7 @@ class _TestTorchMixin(object): self.assertEqual(expected.shape, actual.shape) self.assertTrue(torch.allclose(expected, actual)) - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): for shape in [(4, 5), (3, 2), (2, 1)]: for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]: for trans in [False, True]: @@ -1227,8 +1221,7 @@ class _TestTorchMixin(object): test_pdist_single((1000, 2), device, 2, dtype, False) def test_cdist_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.randn((0, 5), device=device) y = torch.randn((4, 5), device=device) self.assertEqual(torch.empty(0, 4, device=device), torch.cdist(x, y)) @@ -1246,8 +1239,7 @@ class _TestTorchMixin(object): self.assertEqual(torch.empty(2, 0, device=device), torch.cdist(x, y)) def test_cdist_norm(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): for r1 in [3, 4, 5, 6]: for m in [2, 3, 4, 10]: for r2 in [4, 6, 7, 8]: @@ -1259,8 +1251,7 @@ class _TestTorchMixin(object): self.assertTrue(torch.allclose(expected, actual)) def test_cdist_large(self): - devices = ['cpu'] if not 
torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.randn(1000, 10, device=device) y = torch.randn(1000, 10, device=device) actual = torch.cdist(x, y, p=2) @@ -1268,8 +1259,7 @@ class _TestTorchMixin(object): self.assertTrue(torch.allclose(expected, actual)) def test_cdist_non_contiguous(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.randn(5, 7, device=device).t() y = torch.randn(5, 3, device=device).t() actual = torch.cdist(x, y, p=2) @@ -1551,7 +1541,7 @@ class _TestTorchMixin(object): self._test_neg(self, lambda t: t) def test_threshold(self): - for dtype in torch.testing.get_all_dtypes(): + for dtype in torch.testing.get_all_math_dtypes('cpu'): if dtype != torch.uint8 and dtype != torch.float16: # 100 is wide enough to use AVX2 instructions for all types x = torch.randn(100).sign().to(dtype=dtype) @@ -1584,7 +1574,7 @@ class _TestTorchMixin(object): self.assertEqual(res1, res2) def test_floordiv(self): - for dtype in torch.testing.get_all_dtypes(): + for dtype in torch.testing.get_all_math_dtypes('cpu'): if dtype is torch.float16: continue x = torch.randn(100).mul(10).to(dtype) @@ -1594,7 +1584,7 @@ class _TestTorchMixin(object): self.assertEqual(y, z) def test_rdiv(self): - for dtype in torch.testing.get_all_dtypes(): + for dtype in torch.testing.get_all_math_dtypes('cpu'): if dtype is torch.float16: continue x = torch.rand(100).add(1).mul(4).to(dtype) @@ -2323,7 +2313,7 @@ class _TestTorchMixin(object): # 'out' is favored over dtype, check error self.assertRaises(RuntimeError, lambda: fn(x, out=out, dtype=other_dtype)) - for dtype in [dtype for dtype in torch.testing.get_all_dtypes() if dtype != torch.float16]: + for dtype in [dtype for dtype in torch.testing.get_all_math_dtypes('cpu') if dtype != torch.float16]: x = torch.ones(shape, dtype=dtype) expected_dtype = dtype if dtype.is_floating_point else torch.int64 self.assertIs(expected_dtype, fn(x).dtype) @@ -2722,10 +2712,10 @@ class _TestTorchMixin(object): self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device) def test_empty_full(self): - do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cpu')) + do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cpu')) if torch.cuda.device_count() > 0: - do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, None) - do_test_empty_full(self, torch.testing.get_all_dtypes(), torch.strided, torch.device('cuda:0')) + do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, None) + do_test_empty_full(self, torch.testing.get_all_math_dtypes('cpu'), torch.strided, torch.device('cuda:0')) def test_dtype_out_match(self): d = torch.autograd.Variable(torch.DoubleTensor(2, 3)) @@ -2923,60 +2913,94 @@ class _TestTorchMixin(object): self.assertTrue(x.is_cuda) torch.set_default_tensor_type(saved_type) - # This is a temporary test for a boolean tensors on CPU. 
Once the CUDA part - # will be done, these test cases will be moved down to test_tensor_factories_empty test - def test_tensor_factories_bool(self): - expectedShape = (1, 2) - test = torch.empty(expectedShape, dtype=torch.bool) - self.assertEqual(expectedShape, test.shape) - - test2 = torch.empty_like(test, dtype=torch.bool) - self.assertEqual(test.shape, test2.shape) - - test = torch.full(expectedShape, True, dtype=torch.bool) - self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool)) - - test2 = torch.full_like(test, True, dtype=torch.bool) - self.assertEqual(test, test2) - - test = torch.zeros(expectedShape, dtype=torch.bool) - self.assertEqual(test, torch.tensor([[False, False]], dtype=torch.bool)) - - test2 = torch.zeros_like(test, dtype=torch.bool) - self.assertEqual(test, test2) - - test = torch.ones(expectedShape, dtype=torch.bool) - self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool)) - - test2 = torch.ones_like(test, dtype=torch.bool) - self.assertEqual(test, test2) + def test_unfold_all_devices_and_dtypes(self): + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + if dt == torch.half and device == 'cpu': + # fix once random is implemented for Half on CPU + self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device)) + else: + x = torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device) + self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) - test = torch.randint(10, expectedShape, dtype=torch.bool) - self.assertEqual(expectedShape, test.shape) - self.assertEqual(torch.bool, test.dtype) + def test_copy_all_dtypes_and_devices(self): + from copy import copy + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([1, 2, 3, 4], dtype=dt, device=device) + x_clone = x.clone() + + y = copy(x) + y.fill_(1) + + # copy is a shallow copy, only copies the tensor view, + # not the data + self.assertEqual(x, y) + + def test_resize_all_dtypes_and_devices(self): + shape = (2, 2) + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) + x.resize_(shape) + self.assertEqual(shape, x.shape) + + def test_fill_all_dtypes_and_devices(self): + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor((1, 1), dtype=dt, device=device) + x.fill_(1) + + self.assertEqual(x, torch.tensor([1, 1], dtype=dt, device=device)) + self.assertEqual(dt, x.dtype) + + def test_clone_all_dtypes_and_devices(self): + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor((1, 1), dtype=dt, device=device) + y = x.clone() + self.assertEqual(x, y) + + def test_cat_all_dtypes_and_devices(self): + for device in torch.testing.get_all_device_types(): + for dt in torch.testing.get_all_dtypes(): + x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device) + expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device) + self.assertEqual(torch.cat((x, x), 0), expected1) + + expected2 = torch.tensor([[1, 2, 1, 2], [3, 4, 3, 4]], dtype=dt, device=device) + self.assertEqual(torch.cat((x, x), 1), expected2) def test_tensor_factories_empty(self): # ensure we can create empty tensors from each factory function shapes = [(5, 0, 1), (0,), (0, 0, 1, 0, 2, 0, 0)] - devices = ['cpu'] if not torch.cuda.is_available() else 
['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): for shape in shapes: - self.assertEqual(shape, torch.zeros(shape, device=device).shape) - self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device)).shape) - self.assertEqual(shape, torch.empty(shape, device=device).shape) - self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device)).shape) - self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device).shape) - self.assertEqual(shape, torch.full(shape, 3, device=device).shape) - self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device), 3).shape) - self.assertEqual(shape, torch.ones(shape, device=device).shape) - self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device)).shape) - self.assertEqual(shape, torch.rand(shape, device=device).shape) - self.assertEqual(shape, torch.rand_like(torch.zeros(shape, device=device)).shape) - self.assertEqual(shape, torch.randn(shape, device=device).shape) - self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device)).shape) - self.assertEqual(shape, torch.randint(6, shape, device=device).shape) - self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device), 6).shape) + for dt in torch.testing.get_all_dtypes(): + self.assertEqual(shape, torch.zeros(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.zeros_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertEqual(shape, torch.full(shape, 3, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.full_like(torch.zeros(shape, device=device, dtype=dt), 3).shape) + self.assertEqual(shape, torch.ones(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.ones_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertEqual(shape, torch.empty(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape) + self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device, dtype=dt).shape) + + if dt == torch.half and device == "cpu": + # update once random is implemented for half on CPU + self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt).shape) + else: + self.assertEqual(shape, torch.randint(6, shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 6).shape) + + if dt != torch.double and dt != torch.float and dt != torch.half: + self.assertRaises(RuntimeError, lambda: torch.rand(shape, device=device, dtype=dt).shape) + + if dt == torch.double or dt == torch.float: + self.assertEqual(shape, torch.randn(shape, device=device, dtype=dt).shape) + self.assertEqual(shape, torch.randn_like(torch.zeros(shape, device=device, dtype=dt)).shape) self.assertEqual((0,), torch.arange(0, device=device).shape) self.assertEqual((0, 0), torch.eye(0, device=device).shape) @@ -3504,8 +3528,7 @@ class _TestTorchMixin(object): self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(0, float('inf'))) self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(float('inf'))) - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): self.assertRaisesRegex(RuntimeError, msg, lambda: torch.arange(-5, float('nan'), device=device)) # check with step size self.assertRaisesRegex(RuntimeError, msg, lambda: 
torch.arange(0, float('-inf'), -1, device=device)) @@ -3793,9 +3816,8 @@ class _TestTorchMixin(object): def test_empty_tensor_props(self): sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)] - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] for size in sizes: - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.empty(tuple(size), device=device) self.assertEqual(size, x.shape) self.assertTrue(x.is_contiguous()) @@ -4074,8 +4096,7 @@ class _TestTorchMixin(object): @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') def test_tensordot(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for d in devices: + for d in torch.testing.get_all_device_types(): a = torch.arange(60., device=d).reshape(3, 4, 5) b = torch.arange(24., device=d).reshape(4, 3, 2) c = torch.tensordot(a, b, dims=([1, 0], [0, 1])).cpu() @@ -4489,8 +4510,7 @@ class _TestTorchMixin(object): self.assertEqual(x.narrow(-2, -1, 1), torch.Tensor([[6, 7, 8]])) def test_narrow_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.randn(2, 3, 4, device=device) for d in range(x.dim()): y = x.narrow(d, x.size(d), 0) @@ -4547,8 +4567,7 @@ class _TestTorchMixin(object): @skipIfRocm def test_linspace(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): _from = random.random() to = _from + random.random() res1 = torch.linspace(_from, to, 137, device=device) @@ -7207,8 +7226,7 @@ class _TestTorchMixin(object): check(src.transpose(1, 2), idx) def test_take_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): for input_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: for indices_shape in [(0,), (0, 1, 2, 0)]: input = torch.empty(input_shape, device=device) @@ -7237,8 +7255,7 @@ class _TestTorchMixin(object): self.assertEqual(dst.tolist(), [[5, 7], [1, 1]]) def test_put_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): for dst_shape in [(0,), (0, 1, 2, 0), (1, 2, 3)]: for indices_shape in [(0,), (0, 1, 2, 0)]: for accumulate in [False, True]: @@ -7766,8 +7783,7 @@ class _TestTorchMixin(object): self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) def test_tensor_shape_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.randn((0, 1, 3, 0), device=device) # flatten self.assertEqual((0,), torch.flatten(x, 0, 3).shape) @@ -7786,10 +7802,6 @@ class _TestTorchMixin(object): # select self.assertEqual((0, 1, 0), torch.select(x, 2, 2).shape) - # unfold - self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape) - y = torch.randn((0, 1, 3), device=device) - self.assertEqual((1, 1, 3, 0), y.unfold(0, 0, 4).shape) # repeat, permute self.assertEqual((9, 0, 5, 6, 0), x.repeat(9, 7, 5, 2, 3).shape) @@ -7832,8 +7844,7 @@ class _TestTorchMixin(object): # functions that operate over a dimension but don't reduce. 
@skipIfRocm def test_dim_function_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): shape = (0, 1, 2, 0) x = torch.randn(shape, device=device) @@ -7964,8 +7975,7 @@ class _TestTorchMixin(object): @skipIfRocm def test_blas_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): def fn(torchfn, *args): return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape @@ -8035,8 +8045,7 @@ class _TestTorchMixin(object): @skipIfRocm def test_blas_alpha_beta_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): # ensure beta is respected value = 11 input = torch.full((2,), value, device=device) @@ -8066,8 +8075,7 @@ class _TestTorchMixin(object): # numpy/sci often has a direct wrapper (e.g. lu_factor) and a wrapper that "does the right thing" # (e.g. lu). We often name our functions identically to the lapack function, so it will take work # to name / migrate-to better wrappers. - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): # need to init cuda to check has_magma empty = torch.randn((0, 0), device=device) @@ -8203,6 +8211,10 @@ class _TestTorchMixin(object): "Tensors with no storages should not appear to be set " "to each other") + t1 = torch.tensor([True, True], dtype=torch.bool) + t2 = torch.tensor([0], dtype=torch.bool).set_(t1) + self.assertTrue(t1.is_set_to(t2)) + def test_tensor_set(self): t1 = torch.Tensor() t2 = torch.Tensor(3, 4, 9, 10).uniform_() @@ -8234,6 +8246,11 @@ class _TestTorchMixin(object): self.assertEqual(t1.size(), size) self.assertEqual(t1.stride(), stride) + t1 = torch.tensor([True, True], dtype=torch.bool) + t2 = torch.tensor([False, False], dtype=torch.bool) + t1.set_(t2) + self.assertEqual(t1.storage()._cdata, t2.storage()._cdata) + def test_equal(self): # Contiguous, 1D t1 = torch.Tensor((3, 4, 9, 10)) @@ -8448,8 +8465,7 @@ class _TestTorchMixin(object): self._test_flip(self, use_cuda=False) def test_roll(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): numbers = torch.arange(1, 9, device=device) single_roll = numbers.roll(1, 0) @@ -8616,8 +8632,7 @@ class _TestTorchMixin(object): self.assertNotEqual(tensor[dst1[i, 0], dst1[i, 1], dst1[i, 2]].item(), 0) def test_nonzero_empty(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): x = torch.randn(0, 2, 0, 5, 0, device=device) y = torch.nonzero(x) self.assertEqual(0, y.numel()) @@ -8670,16 +8685,6 @@ class _TestTorchMixin(object): self.assertEqual(torch.nn.Parameter, type(s2['weight'])) self.assertEqual(torch.nn.Parameter, type(s2['bias'])) - def test_copy(self): - from copy import copy - a = torch.randn(5, 5) - a_clone = a.clone() - b = copy(a) - b.fill_(1) - # copy is a shallow copy, only copies the tensor view, - # not the data - self.assertEqual(a, b) - def test_pickle(self): if sys.version_info[0] == 2: import cPickle as pickle @@ -9820,8 +9825,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.], self.assertEqual(torch.empty_like(a).type(), 
a.type()) def test_empty_strided(self): - devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] - for device in devices: + for device in torch.testing.get_all_device_types(): for shape in [(2, 3, 4), (0, 2, 0)]: # some of these cases are pretty strange, just verifying that if as_strided # allows them then empty_strided can as well. diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index eb2c920..894d823 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -87,10 +87,24 @@ def make_non_contiguous(tensor): def get_all_dtypes(): - return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, + return [torch.uint8, torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.float16, torch.float32, torch.float64] +def get_all_math_dtypes(device): + dtypes = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, + torch.float32, torch.float64] + + # torch.float16 is a math dtype on cuda but not cpu. + if device == 'cpu': + return dtypes + else: + return dtypes + [torch.float16] + + +def get_all_device_types(): + return ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] + # 'dtype': (rtol, atol) _default_tolerances = { 'float64': (1e-5, 1e-8), # NumPy default -- 2.7.4
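
For context, the following standalone script exercises the factory paths this patch enables for torch.bool on CUDA. It is a minimal sketch, not part of the patch: it assumes a CUDA-capable build containing this revision, and the helper name smoke_test_cuda_bool_factories is invented here. Only shapes and dtypes are asserted, since indexing and most math ops on bool tensors are later steps in the plan above.

import torch

def smoke_test_cuda_bool_factories(device='cuda'):
    # Exercise the newly enabled torch.bool factories on CUDA.
    shape = (2, 3)
    for factory in (torch.zeros, torch.ones, torch.empty):
        t = factory(shape, dtype=torch.bool, device=device)
        assert t.shape == shape and t.dtype == torch.bool

    # torch.tensor path (device copy of CPU-constructed bool data)
    t = torch.tensor([[True, False, True], [False, True, False]],
                     dtype=torch.bool, device=device)
    assert t.shape == shape

    f = torch.full(shape, True, dtype=torch.bool, device=device)
    assert torch.full_like(f, False).dtype == torch.bool

    r = torch.randint(2, shape, dtype=torch.bool, device=device)
    assert torch.randint_like(r, 2).shape == shape

    # fill_, clone and cat gained cuda_bool support in this patch as well.
    f.fill_(False)
    assert torch.cat((f, f.clone()), 0).shape == (4, 3)

if torch.cuda.is_available():
    smoke_test_cuda_bool_factories()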
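
Similarly, a sketch of how the new torch.testing helpers are meant to be used in the test suite, replacing the repeated devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda'] boilerplate; the loop bodies here are illustrative only.

import torch

for device in torch.testing.get_all_device_types():  # ['cpu'] or ['cpu', 'cuda']
    for dt in torch.testing.get_all_dtypes():  # now includes torch.bool
        x = torch.zeros((2, 2), dtype=dt, device=device)
        assert x.dtype == dt
    for dt in torch.testing.get_all_math_dtypes(device):
        # bool is never a math dtype; float16 only qualifies on cuda
        y = torch.ones((2, 2), dtype=dt, device=device)
        assert y.dtype == dt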