Enable comp ops for bool tensor (#19109)
author Iurii Zdebskyi <iuriiz@fb.com>
Thu, 11 Apr 2019 21:25:21 +0000 (14:25 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 11 Apr 2019 21:37:10 +0000 (14:37 -0700)
Summary:
Enabled comparison ops for bool tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19109

Differential Revision: D14871187

Pulled By: izdeby

fbshipit-source-id: cf9951847d69124a93e5e21dd0a39c9568b1037d

12 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/TH/THTensorMath.cpp
aten/src/TH/THTensorMoreMath.cpp
aten/src/TH/generic/THTensorMath.cpp
aten/src/TH/generic/THTensorMath.h
aten/src/TH/generic/THTensorMoreMath.cpp
aten/src/THC/CMakeLists.txt
aten/src/THC/THCNumerics.cuh
aten/src/THC/THCTensorMath.h
aten/src/THC/generated/THCTensorMathCompareBool.cu [new file with mode: 0644]
aten/src/THC/generated/THCTensorMathCompareTBool.cu [new file with mode: 0644]
test/test_torch.py

index 43fb088..96fda31 100644 (file)
 ]]
 [[
   name: _th_lt
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
 ]]
 [[
   name: _th_lt_
+  cpu_bool: True
+  cuda_bool: True
   return: self
   variants: function
   options:
 ]]
 [[
   name: _th_gt
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
 ]]
 [[
   name: _th_gt_
+  cpu_bool: True
+  cuda_bool: True
   return: self
   variants: function
   options:
 ]]
 [[
   name: _th_le
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
 ]]
 [[
   name: _th_le_
+  cpu_bool: True
+  cuda_bool: True
   return: self
   variants: function
   options:
 ]]
 [[
   name: _th_ge
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
 ]]
 [[
   name: _th_ge_
+  cpu_bool: True
+  cuda_bool: True
   return: self
   variants: function
   options:
 ]]
 [[
   name: _th_eq
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
 ]]
 [[
   name: _th_eq_
+  cpu_bool: True
+  cuda_bool: True
   return: self
   variants: function
   options:
 ]]
 [[
   name: _th_ne
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
 ]]
 [[
   name: _th_ne_
+  cpu_bool: True
+  cuda_bool: True
   return: self
   variants: function
   options:
index 4984772..e0cfbfa 100644 (file)
@@ -5,3 +5,6 @@
 
 #include <TH/generic/THTensorMath.cpp>
 #include <TH/THGenerateAllTypes.h>
+
+#include <TH/generic/THTensorMath.cpp>
+#include <TH/THGenerateBoolType.h>
index ba28390..ddbc7dc 100644 (file)
@@ -5,3 +5,6 @@
 
 #include <TH/generic/THTensorMoreMath.cpp>
 #include <TH/THGenerateAllTypes.h>
+
+#include <TH/generic/THTensorMoreMath.cpp>
+#include <TH/THGenerateBoolType.h>
index 2e14cd0..f9a31f2 100644 (file)
@@ -21,6 +21,7 @@
 // sense (rather than just having cut the file down the middle, which is
 // what I did when I split these up originally).
 
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
 
 // Should wrap if the value (a) has a different sign than the divisor (b), but is not 0.
 static inline bool modulo_wrap(scalar_t a, scalar_t b) {
@@ -1197,4 +1198,6 @@ void THTensor_(addbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t al
   c10::raw::intrusive_ptr::decref(matrix2);
 }
 
+#endif /* !defined(TH_REAL_IS_BOOL) */
+
 #endif /* TH_GENERIC_FILE */
index a0766cb..55aaea8 100644 (file)
@@ -4,6 +4,34 @@
 
 TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
 
+TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, scalar_t value);
+
+TH_API void THTensor_(ltValueT)(THTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(leValueT)(THTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(gtValueT)(THTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(geValueT)(THTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(neValueT)(THTensor *r_, THTensor* t, scalar_t value);
+TH_API void THTensor_(eqValueT)(THTensor *r_, THTensor* t, scalar_t value);
+
+TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
+
+TH_API void THTensor_(ltTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(leTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(gtTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(geTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(neTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
+TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
+
 #if !defined(TH_REAL_IS_BOOL) /* non bool only part */
 
 TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value);
@@ -96,34 +124,6 @@ TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k);
 
 TH_API int THTensor_(equal)(THTensor *ta, THTensor *tb);
 
-TH_API void THTensor_(ltValue)(THByteTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(leValue)(THByteTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(gtValue)(THByteTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(geValue)(THByteTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(neValue)(THByteTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(eqValue)(THByteTensor *r_, THTensor* t, scalar_t value);
-
-TH_API void THTensor_(ltValueT)(THTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(leValueT)(THTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(gtValueT)(THTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(geValueT)(THTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(neValueT)(THTensor *r_, THTensor* t, scalar_t value);
-TH_API void THTensor_(eqValueT)(THTensor *r_, THTensor* t, scalar_t value);
-
-TH_API void THTensor_(ltTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(leTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(gtTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(geTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(neTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(eqTensor)(THByteTensor *r_, THTensor *ta, THTensor *tb);
-
-TH_API void THTensor_(ltTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(leTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(gtTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(geTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(neTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
-TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
-
 TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value);
 TH_API void THTensor_(tpow)(THTensor *r_, scalar_t value, THTensor *t);
 TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);
index 48fa373..a3310e4 100644 (file)
@@ -5,6 +5,41 @@
 #include <TH/generic/THTensorApply.hpp>
 #include <TH/THGenerator.hpp>
 
+#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP)                              \
+  void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \
+  { \
+    THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL);         \
+    TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t,                   \
+                    *r__data = (*t_data OP value) ? 1 : 0;); \
+  }                                                                    \
+  void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value)      \
+  {                                    \
+    THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL);           \
+    TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t,                                        \
+                    *r__data = (*t_data OP value) ? 1 : 0;); \
+  }                                                                    \
+  void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \
+  {                                    \
+    THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL);               \
+    TH_TENSOR_APPLY3(unsigned char, r_, scalar_t, ta, scalar_t, tb,            \
+                    *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
+  }                                                                    \
+  void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \
+  {                            \
+    THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL);         \
+    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb,                 \
+                    *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
+  }                                                                    \
+
+TENSOR_IMPLEMENT_LOGICAL(lt,<)
+TENSOR_IMPLEMENT_LOGICAL(gt,>)
+TENSOR_IMPLEMENT_LOGICAL(le,<=)
+TENSOR_IMPLEMENT_LOGICAL(ge,>=)
+TENSOR_IMPLEMENT_LOGICAL(eq,==)
+TENSOR_IMPLEMENT_LOGICAL(ne,!=)
+
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
+
 void THTensor_(baddbmm)(THTensor *result, scalar_t beta, THTensor *t, scalar_t alpha, THTensor *batch1, THTensor *batch2)
 {
   int64_t batch;
@@ -999,41 +1034,6 @@ int THTensor_(equal)(THTensor *ta, THTensor* tb)
   return equal;
 }
 
-#define TENSOR_IMPLEMENT_LOGICAL(NAME,OP)                                \
-  void THTensor_(NAME##Value)(THByteTensor *r_, THTensor* t, scalar_t value) \
-  {                                                                        \
-    THByteTensor_resizeNd(r_, t->dim(), THTensor_getSizePtr(t), NULL);                \
-    TH_TENSOR_APPLY2(unsigned char, r_, scalar_t, t,                        \
-                     *r__data = (*t_data OP value) ? 1 : 0;); \
-  }                                                                        \
-  void THTensor_(NAME##ValueT)(THTensor* r_, THTensor* t, scalar_t value)        \
-  {                                                                        \
-    THTensor_(resizeNd)(r_, t->dim(), THTensor_getSizePtr(t), NULL);                \
-    TH_TENSOR_APPLY2(scalar_t, r_, scalar_t, t,                                        \
-                     *r__data = (*t_data OP value) ? 1 : 0;); \
-  }                                                                        \
-  void THTensor_(NAME##Tensor)(THByteTensor *r_, THTensor *ta, THTensor *tb) \
-  {                                                                        \
-    THByteTensor_resizeNd(r_, ta->dim(), THTensor_getSizePtr(ta), NULL);                \
-    TH_TENSOR_APPLY3(unsigned char, r_, scalar_t, ta, scalar_t, tb,                \
-                     *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
-  }                                                                        \
-  void THTensor_(NAME##TensorT)(THTensor *r_, THTensor *ta, THTensor *tb) \
-  {                                                                        \
-    THTensor_(resizeNd)(r_, ta->dim(), THTensor_getSizePtr(ta), NULL);                \
-    TH_TENSOR_APPLY3(scalar_t, r_, scalar_t, ta, scalar_t, tb,                        \
-                     *r__data = (*ta_data OP *tb_data) ? 1 : 0;); \
-  }                                                                        \
-
-
-TENSOR_IMPLEMENT_LOGICAL(lt,<)
-TENSOR_IMPLEMENT_LOGICAL(gt,>)
-TENSOR_IMPLEMENT_LOGICAL(le,<=)
-TENSOR_IMPLEMENT_LOGICAL(ge,>=)
-TENSOR_IMPLEMENT_LOGICAL(eq,==)
-TENSOR_IMPLEMENT_LOGICAL(ne,!=)
-
-
 #ifdef _OPENMP
 
 #define LAB_IMPLEMENT_BASIC_FUNCTION_3_ARGS(NAME, CFUNC, OMP_THRESHOLD)             \
@@ -1681,4 +1681,6 @@ void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alpha, THT
 #endif /* floating point only part */
 #undef IS_NONZERO
 
+#endif /* !defined(TH_REAL_IS_BOOL) */
+
 #endif /* TH_GENERIC_FILE */
index 740e097..56e759d 100644 (file)
@@ -18,6 +18,14 @@ foreach(THC_TYPE Byte Char Short Int Long Half Float Double)
    endforeach()
 endforeach()
 
+foreach(THC_FILE TensorMathCompareT TensorMathCompare)
+   if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu")
+      FILE(WRITE "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu"
+        "#include <THC/THC${THC_FILE}.cuh>\n#include <THC/THCTensor.hpp>\n\n#include <THC/generic/THC${THC_FILE}.cu>\n#include <THC/THCGenerateBoolType.h>\n")
+   endif()
+   LIST(APPEND extra_src "${CMAKE_CURRENT_SOURCE_DIR}/generated/THC${THC_FILE}Bool.cu")
+endforeach()
+
 set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
   ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCGeneral.cpp
index 2547277..a7f689b 100644 (file)
@@ -65,6 +65,16 @@ struct THCNumerics<uint8_t> {
 };
 
 template <>
+struct THCNumerics<bool> {
+  static inline __host__ __device__ bool lt(bool a, bool b) { return a < b; }
+  static inline __host__ __device__ bool le(bool a, bool b) { return a <= b; }
+  static inline __host__ __device__ bool gt(bool a, bool b) { return a > b; }
+  static inline __host__ __device__ bool ge(bool a, bool b) { return a >= b; }
+  static inline __host__ __device__ bool eq(bool a, bool b) { return a == b; }
+  static inline __host__ __device__ bool ne(bool a, bool b) { return a != b; }
+};
+
+template <>
 struct THCNumerics<int8_t> {
   static inline __host__ __device__ int8_t min() { return at::numeric_limits<int8_t>::lowest(); }
   static inline __host__ __device__ int8_t max() { return at::numeric_limits<int8_t>::max(); }
index 078b1cd..acdce2f 100644 (file)
 #include <THC/generic/THCTensorMathCompare.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensorMathCompare.h>
+#include <THC/THCGenerateBoolType.h>
+
 #include <THC/generic/THCTensorMathCompareT.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensorMathCompareT.h>
+#include <THC/THCGenerateBoolType.h>
+
 #include <THC/generic/THCTensorMathScan.h>
 #include <THC/THCGenerateAllTypes.h>
 
diff --git a/aten/src/THC/generated/THCTensorMathCompareBool.cu b/aten/src/THC/generated/THCTensorMathCompareBool.cu
new file mode 100644 (file)
index 0000000..25c1585
--- /dev/null
@@ -0,0 +1,5 @@
+#include <THC/THCTensorMathCompare.cuh>
+#include <THC/THCTensor.hpp>
+
+#include <THC/generic/THCTensorMathCompare.cu>
+#include <THC/THCGenerateBoolType.h>
diff --git a/aten/src/THC/generated/THCTensorMathCompareTBool.cu b/aten/src/THC/generated/THCTensorMathCompareTBool.cu
new file mode 100644 (file)
index 0000000..07418ba
--- /dev/null
@@ -0,0 +1,5 @@
+#include <THC/THCTensorMathCompareT.cuh>
+#include <THC/THCTensor.hpp>
+
+#include <THC/generic/THCTensorMathCompareT.cu>
+#include <THC/THCGenerateBoolType.h>
index ebfd34e..6d7651f 100644 (file)
@@ -3014,6 +3014,20 @@ class _TestTorchMixin(object):
         self.assertTrue(x.is_cuda)
         torch.set_default_tensor_type(saved_type)
 
+    def test_bool_tensor_comparison_ops(self):
+        for device in torch.testing.get_all_device_types():
+            a = torch.tensor([True, False, True, False, True, False], dtype=torch.bool, device=device)
+            b = torch.tensor([True, False, True, True, True, True], dtype=torch.bool, device=device)
+            self.assertEqual(a == b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.uint8, device=device))
+            self.assertEqual(a != b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.uint8, device=device))
+            self.assertEqual(a < b, torch.tensor([0, 0, 0, 1, 0, 1], dtype=torch.uint8, device=device))
+            self.assertEqual(a > b, torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.uint8, device=device))
+            self.assertEqual(a >= b, torch.tensor([1, 1, 1, 0, 1, 0], dtype=torch.uint8, device=device))
+            self.assertEqual(a <= b, torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.uint8, device=device))
+            self.assertEqual(a > False, torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device))
+            self.assertEqual(a == torch.tensor(True, dtype=torch.bool, device=device), torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device))
+            self.assertEqual(a == torch.tensor(0, dtype=torch.bool, device=device), torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.uint8, device=device))
+
     def test_bool_tensor_value_change(self):
         for device in torch.testing.get_all_device_types():
             x = torch.tensor([True, False], dtype=torch.bool)