From 507fe66beaad18b30eda8c0de77afd075c55e602 Mon Sep 17 00:00:00 2001 From: Iurii Zdebskyi Date: Thu, 11 Apr 2019 14:25:21 -0700 Subject: [PATCH] Enable comp ops for bool tensor (#19109) Summary: Enabled comparison ops for bool tensors Pull Request resolved: https://github.com/pytorch/pytorch/pull/19109 Differential Revision: D14871187 Pulled By: izdeby fbshipit-source-id: cf9951847d69124a93e5e21dd0a39c9568b1037d --- aten/src/ATen/Declarations.cwrap | 24 ++++++++ aten/src/TH/THTensorMath.cpp | 3 + aten/src/TH/THTensorMoreMath.cpp | 3 + aten/src/TH/generic/THTensorMath.cpp | 3 + aten/src/TH/generic/THTensorMath.h | 56 ++++++++--------- aten/src/TH/generic/THTensorMoreMath.cpp | 72 +++++++++++----------- aten/src/THC/CMakeLists.txt | 8 +++ aten/src/THC/THCNumerics.cuh | 10 +++ aten/src/THC/THCTensorMath.h | 6 ++ aten/src/THC/generated/THCTensorMathCompareBool.cu | 5 ++ .../src/THC/generated/THCTensorMathCompareTBool.cu | 5 ++ test/test_torch.py | 14 +++++ 12 files changed, 146 insertions(+), 63 deletions(-) create mode 100644 aten/src/THC/generated/THCTensorMathCompareBool.cu create mode 100644 aten/src/THC/generated/THCTensorMathCompareTBool.cu diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 43fb088..96fda31 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -526,6 +526,8 @@ ]] [[ name: _th_lt + cpu_bool: True + cuda_bool: True variants: - function return: argument 0 @@ -546,6 +548,8 @@ ]] [[ name: _th_lt_ + cpu_bool: True + cuda_bool: True return: self variants: function options: @@ -563,6 +567,8 @@ ]] [[ name: _th_gt + cpu_bool: True + cuda_bool: True variants: - function return: argument 0 @@ -583,6 +589,8 @@ ]] [[ name: _th_gt_ + cpu_bool: True + cuda_bool: True return: self variants: function options: @@ -600,6 +608,8 @@ ]] [[ name: _th_le + cpu_bool: True + cuda_bool: True variants: - function return: argument 0 @@ -620,6 +630,8 @@ ]] [[ name: _th_le_ + cpu_bool: True + cuda_bool: 
True return: self variants: function options: @@ -637,6 +649,8 @@ ]] [[ name: _th_ge + cpu_bool: True + cuda_bool: True variants: - function return: argument 0 @@ -657,6 +671,8 @@ ]] [[ name: _th_ge_ + cpu_bool: True + cuda_bool: True return: self variants: function options: @@ -674,6 +690,8 @@ ]] [[ name: _th_eq + cpu_bool: True + cuda_bool: True variants: - function return: argument 0 @@ -694,6 +712,8 @@ ]] [[ name: _th_eq_ + cpu_bool: True + cuda_bool: True return: self variants: function options: @@ -711,6 +731,8 @@ ]] [[ name: _th_ne + cpu_bool: True + cuda_bool: True variants: - function return: argument 0 @@ -731,6 +753,8 @@ ]] [[ name: _th_ne_ + cpu_bool: True + cuda_bool: True return: self variants: function options: diff --git a/aten/src/TH/THTensorMath.cpp b/aten/src/TH/THTensorMath.cpp index 4984772..e0cfbfa 100644 --- a/aten/src/TH/THTensorMath.cpp +++ b/aten/src/TH/THTensorMath.cpp @@ -5,3 +5,6 @@ #include #include + +#include +#include diff --git a/aten/src/TH/THTensorMoreMath.cpp b/aten/src/TH/THTensorMoreMath.cpp index ba28390..ddbc7dc 100644 --- a/aten/src/TH/THTensorMoreMath.cpp +++ b/aten/src/TH/THTensorMoreMath.cpp @@ -5,3 +5,6 @@ #include #include + +#include +#include diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp index 2e14cd0..f9a31f2 100644 --- a/aten/src/TH/generic/THTensorMath.cpp +++ b/aten/src/TH/generic/THTensorMath.cpp @@ -21,6 +21,7 @@ // sense (rather than just having cut the file down the middle, which is // what I did when I split these up originally). +#if !defined(TH_REAL_IS_BOOL) /* non bool only part */ // Should wrap if the value (a) has a different sign than the divisor (b), but is not 0. 
static inline bool modulo_wrap(scalar_t a, scalar_t b) { @@ -1197,4 +1198,6 @@ void THTensor_(addbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t al c10::raw::intrusive_ptr::decref(matrix2); } +#endif /* !defined(TH_REAL_IS_BOOL) */ + #endif /* TH_GENERIC_FILE */ diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index a0766cb..55aaea8 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -4,6 +4,34 @@ TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor); +TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, scalar_t value); + +TH_API void THTensor_(ltValueT)(THTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(leValueT)(THTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(gtValueT)(THTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(geValueT)(THTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(neValueT)(THTensor *r_, THTensor* t, scalar_t value); +TH_API void THTensor_(eqValueT)(THTensor *r_, THTensor* t, scalar_t value); + +TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); + +TH_API 
void THTensor_(ltTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(leTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(gtTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(geTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(neTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); +TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); + #if !defined(TH_REAL_IS_BOOL) /* non bool only part */ TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value); @@ -96,34 +124,6 @@ TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k); TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb); -TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, scalar_t value); - -TH_API void THTensor_(ltValueT)(THTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(leValueT)(THTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(gtValueT)(THTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(geValueT)(THTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(neValueT)(THTensor *r_, THTensor* t, scalar_t value); -TH_API void THTensor_(eqValueT)(THTensor *r_, THTensor* t, scalar_t value); - -TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, 
THTensor *tb); -TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb); - -TH_API void THTensor_(ltTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(leTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(gtTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(geTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(neTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); -TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); - TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value); TH_API void THTensor_(tpow)(THTensor *r_, scalar_t value, THTensor *t); TH_API void THTensor_(abs)(THTensor *r_, THTensor *t); diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index 48fa373..a3310e4 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -5,6 +5,41 @@ #include #include +#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \ + void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \ + { \ + THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ + TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t, \ + *r__data = (*t_data OP value) ? 1 : 0;); \ + } \ + void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value) \ + { \ + THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ + TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, \ + *r__data = (*t_data OP value) ? 1 : 0;); \ + } \ + void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ + TH_TENSOR_APPLY3(unsigned char, r_, scalar_t, ta, scalar_t, tb, \ + *r__data = (*ta_data OP *tb_data) ? 
1 : 0;); \ + } \ + void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \ + { \ + THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ + TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \ + *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ + } \ + +TENSOR_IMPLEMENT_LOGICAL(lt,<) +TENSOR_IMPLEMENT_LOGICAL(gt,>) +TENSOR_IMPLEMENT_LOGICAL(le,<=) +TENSOR_IMPLEMENT_LOGICAL(ge,>=) +TENSOR_IMPLEMENT_LOGICAL(eq,==) +TENSOR_IMPLEMENT_LOGICAL(ne,!=) + +#if !defined(TH_REAL_IS_BOOL) /* non bool only part */ + void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2) { int64_t batch; @@ -999,41 +1034,6 @@ int THTensor_(equal)(THTensor *ta, THTensor* tb) return equal; } -#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP) \ - void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \ - { \ - THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ - TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t, \ - *r__data = (*t_data OP value) ? 1 : 0;); \ - } \ - void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value) \ - { \ - THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL); \ - TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t, \ - *r__data = (*t_data OP value) ? 1 : 0;); \ - } \ - void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \ - { \ - THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ - TH_TENSOR_APPLY3(unsigned char, r_, scalar_t, ta, scalar_t, tb, \ - *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \ - } \ - void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \ - { \ - THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL); \ - TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb, \ - *r__data = (*ta_data OP *tb_data) ? 
1 : 0;); \ - } \ - - -TENSOR_IMPLEMENT_LOGICAL(lt,<) -TENSOR_IMPLEMENT_LOGICAL(gt,>) -TENSOR_IMPLEMENT_LOGICAL(le,<=) -TENSOR_IMPLEMENT_LOGICAL(ge,>=) -TENSOR_IMPLEMENT_LOGICAL(eq,==) -TENSOR_IMPLEMENT_LOGICAL(ne,!=) - - #ifdef _OPENMP #define LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, OMP_THRESHOLD) \ @@ -1681,4 +1681,6 @@ void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alpha, THT #endif /* floating point only part */ #undef IS_NONZERO +#endif /* !defined(TH_REAL_IS_BOOL) */ + #endif /* TH_GENERIC_FILE */ diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt index 740e097..56e759d 100644 --- a/aten/src/THC/CMakeLists.txt +++ b/aten/src/THC/CMakeLists.txt @@ -18,6 +18,14 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double) endforeach() endforeach() +foreach(THC_FILE TensorMathCompareT TensorMathCompare) + if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu") + FILE(WRITE "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu" + "#include \n#include \n\n#include \n#include \n") + endif() + LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu") +endforeach() + set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp ${CMAKE_CURRENT_SOURCE_DIR}/THCGeneral.cpp diff --git a/aten/src/THC/THCNumerics.cuh b/aten/src/THC/THCNumerics.cuh index 2547277..a7f689b 100644 --- a/aten/src/THC/THCNumerics.cuh +++ b/aten/src/THC/THCNumerics.cuh @@ -65,6 +65,16 @@ struct THCNumerics { }; template <> +struct THCNumerics { + static inline __host__ __device__ bool lt(uint8_t a, uint8_t b) { return a < b; } + static inline __host__ __device__ bool le(uint8_t a, uint8_t b) { return a <= b; } + static inline __host__ __device__ bool gt(uint8_t a, uint8_t b) { return a > b; } + static inline __host__ __device__ bool ge(uint8_t a, uint8_t b) { return a >= b; } + static inline __host__ __device__ bool eq(uint8_t a, uint8_t b) { return a == b; } 
+ static inline __host__ __device__ bool ne(uint8_t a, uint8_t b) { return a != b; } +}; + +template <> struct THCNumerics { static inline __host__ __device__ int8_t min() { return at::numeric_limits::lowest(); } static inline __host__ __device__ int8_t max() { return at::numeric_limits::max(); } diff --git a/aten/src/THC/THCTensorMath.h b/aten/src/THC/THCTensorMath.h index 078b1cd..acdce2f 100644 --- a/aten/src/THC/THCTensorMath.h +++ b/aten/src/THC/THCTensorMath.h @@ -28,9 +28,15 @@ #include #include +#include +#include + #include #include +#include +#include + #include #include diff --git a/aten/src/THC/generated/THCTensorMathCompareBool.cu b/aten/src/THC/generated/THCTensorMathCompareBool.cu new file mode 100644 index 0000000..25c1585 --- /dev/null +++ b/aten/src/THC/generated/THCTensorMathCompareBool.cu @@ -0,0 +1,5 @@ +#include +#include + +#include +#include diff --git a/aten/src/THC/generated/THCTensorMathCompareTBool.cu b/aten/src/THC/generated/THCTensorMathCompareTBool.cu new file mode 100644 index 0000000..07418ba --- /dev/null +++ b/aten/src/THC/generated/THCTensorMathCompareTBool.cu @@ -0,0 +1,5 @@ +#include +#include + +#include +#include diff --git a/test/test_torch.py b/test/test_torch.py index ebfd34e..6d7651f 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -3014,6 +3014,20 @@ class _TestTorchMixin(object): self.assertTrue(x.is_cuda) torch.set_default_tensor_type(saved_type) + def test_bool_tensor_comparison_ops(self): + for device in torch.testing.get_all_device_types(): + a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device) + b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device) + self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.uint8)) + self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.uint8)) + self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.uint8)) + self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.uint8)) 
self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.uint8)) + self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.uint8)) + self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.uint8)) + self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.uint8)) + self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.uint8)) + def test_bool_tensor_value_change(self): for device in torch.testing.get_all_device_types(): x = torch.tensor([True, False], dtype=torch.bool) -- 2.7.4