Resolving comments from Bool Tensor for CPU PR (#18165)
author    Iurii Zdebskyi <iuriiz@fb.com>
          Tue, 26 Mar 2019 16:55:50 +0000 (09:55 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Tue, 26 Mar 2019 16:59:34 +0000 (09:59 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18165
ghimport-source-id: 55cb3fb63a25c2faab1725b4ec14c688bf45bd38

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18166 Bool Tensor for CUDA
* **#18165 Resolved comments from Bool Tensor for CPU PR**
-------
This is a follow-up PR that resolves some additional feedback on one of the previous Bool Tensor PRs.

gchanan, here is a list of almost all the comments from the original PR, with the respective fixes and replies:

**[utils/python_scalars.h]** why is this converting from uint8_t and not bool? (comment?)
When I was adding this, I was testing by creating a tensor and then calling its .tolist(); it worked equally well for bool and uint8_t, so I left uint8_t since I thought it made more sense given that we are calling PyBool_FromLong. Changing it to bool.
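
For reference, a minimal sketch of the fixed read path (the helper name here is made up; the real change is the one-line switch case in torch/csrc/utils/python_scalars.h at the bottom of this diff):

```cpp
#include <Python.h>

// Sketch: read the element through a bool* so the declared type matches
// the tensor's dtype, then box it as a Python bool. Reading it as uint8_t
// only worked by accident because both types occupy one byte here.
inline PyObject* load_bool_scalar(void* data) {
  return PyBool_FromLong(*static_cast<bool*>(data));
}
```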

**[ATen/Dispatch.h]** better name?
Fixed.

**[test/test_torch.py]** what about other factories, such as full? (and more).
There is a test that goes through the factory methods - test_tensor_factories_empty. I added some bool cases above it and added a comment that once the CUDA part is done, I will unite them so the test iterates not just between CUDA and CPU but also over all types. Adding all bool cases now; will unite in the CUDA PR.

**[generic/THTensorMath.h]** any changes in this file actually needed?
Bad merge. Fixed.

**[TH/THTensor.h]** this generates code for random, clampedRandom, and cappedRandom -- do we have tests for all of these with bool?
Added.

**[c10/core/ScalarType.h]** I'm not very confident about the lack of Bool here -- can you look at the call sites and see what makes sense to do here?
Added bool to the macro and created a similar one without bool for the single case that otherwise fails the build with errors:

    ./torch/csrc/jit/symbolic_variable.h:79:20: error: ambiguous overload for ‘operator*’ (operand types are ‘const torch::jit::SymbolicVariable’ and ‘torch::jit::Value*’)
      return (*this) * insertConstant(rhs);
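
A stripped-down illustration of that ambiguity (simplified stand-ins, not the real JIT classes): a `Value*` argument can reach `operator*` through two different sequences that each cost one user-defined conversion, so overload resolution cannot pick one.

```cpp
struct Value {};

struct Scalar {
  Scalar(bool) {}  // the implicit bool ctor that adding Bool to the macro generates
};

struct Var {
  Var() = default;
  Var(Value*) {}  // Value* is implicitly convertible to Var
  Var operator*(const Var&) const { return {}; }
  Var operator*(Scalar) const { return {}; }
};

int main() {
  Var v;
  Value* p = nullptr;
  // Both sequences cost one user-defined conversion, hence the ambiguity:
  //   Value* -> Var              (via Var(Value*))
  //   Value* -> bool -> Scalar   (pointer-to-bool, then Scalar(bool))
  // Var r = v * p;  // error: ambiguous overload for 'operator*'
  (void)v; (void)p;
}
```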

Differential Revision: D14605105

fbshipit-source-id: abf82d50e8f8c50b386545ac068268651b28496d

12 files changed:
aten/src/ATen/Dispatch.h
aten/src/ATen/native/Copy.cpp
aten/src/ATen/native/TensorFactories.cpp
aten/src/TH/generic/THTensorMath.cpp
aten/src/TH/generic/THTensorMath.h
c10/core/Scalar.h
c10/core/ScalarType.h
caffe2/contrib/aten/aten_op_template.h
caffe2/operators/experimental/c10/cpu/cast_cpu.cc
test/common_utils.py
test/test_torch.py
torch/csrc/utils/python_scalars.h

aten/src/ATen/Dispatch.h
index 374092a..b22a60d 100644
 
 namespace detail {
 
+template <at::ScalarType N>
+struct ScalarTypeToCType;
+
+template<>
+struct ScalarTypeToCType<at::ScalarType::Half> {
+  using type = at::Half;
+
+  // This is a workaround for a CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type from being used directly
+  // due to an ambiguous reference which can't be resolved. For some reason the compiler can't pick between at::detail and at::cuda::detail.
+  // For repro example, please see: https://github.com/izdeby/playground/blob/0b0c0e9373b32830442ec26b2bc3535b21fb8d95/C%2B%2B/CudaBugRepro.cu
+  // TODO: remove once the bug is fixed.
+  static at::Half t;
+};
+
+template<>
+struct ScalarTypeToCType<at::ScalarType::Bool> {
+  using type = bool;
+
+  // This is a workaround for a CUDA bug which prevents ::detail::ScalarTypeToCType<T>::type from being used directly
+  // due to an ambiguous reference which can't be resolved. For some reason the compiler can't pick between at::detail and at::cuda::detail.
+  // For repro example, please see: https://github.com/izdeby/playground/blob/0b0c0e9373b32830442ec26b2bc3535b21fb8d95/C%2B%2B/CudaBugRepro.cu
+  // TODO: remove once the bug is fixed.
+  static bool t;
+};
+
 inline at::ScalarType scalar_type(at::ScalarType s) {
   return s;
 }
@@ -101,8 +126,9 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     }                                                                        \
   }()
 
-#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                               \
+#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)                      \
   [&] {                                                                      \
+    detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF();                     \
     const auto& the_type = TYPE;                                             \
     (void)the_type;                                                          \
     at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
@@ -114,14 +140,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
       default:                                                               \
         AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
     }                                                                        \
   }()
 
-#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)                      \
+#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                               \
   [&] {                                                                      \
-    detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF();                     \
     const auto& the_type = TYPE;                                             \
     (void)the_type;                                                          \
     at::ScalarType _st = ::detail::scalar_type(TYPE);                        \
@@ -133,7 +159,6 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)        \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)       \
       AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)      \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__)      \
       default:                                                               \
         AT_ERROR(#NAME, " not implemented for '", toString(_st), "'");       \
     }                                                                        \
@@ -200,53 +225,56 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
     }                                                                        \
   }()
 
+#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...)                                            \
+  [&] {                                                                                                   \
+    switch (TYPE) {                                                                                       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)                                    \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)                                     \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)                                   \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)                                     \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)                                     \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)                                    \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)                                   \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE, decltype(::detail::ScalarTypeToCType<SCALARTYPE>::t), __VA_ARGS__) \
+      default:                                                                                            \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");                                   \
+    }                                                                                                     \
+  }()
 
-template <at::ScalarType N>
-struct MyTemplate;
-
-template<>
-struct MyTemplate<at::ScalarType::Half> {
-  using type = at::Half;
-};
-
-template<>
-struct MyTemplate<at::ScalarType::Bool> {
-  using type = bool;
-};
-
-#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...)                    \
-  [&] {                                                                           \
-    switch (TYPE) {                                                               \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)            \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)             \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)           \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)             \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)             \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)            \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)           \
-      AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
-      default:                                                                    \
-        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");           \
-    }                                                                             \
+#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)                               \
+  [&] {                                                                                                     \
+    switch (TYPE) {                                                                                         \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)                                      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)                                       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)                                     \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)                                       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)                                       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)                                      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)                                     \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType<SCALARTYPE1>::t), __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType<SCALARTYPE2>::t), __VA_ARGS__) \
+      default:                                                                                              \
+        AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");                                     \
+    }                                                                                                       \
   }()
 
-#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)        \
-  [&] {                                                                                         \
-    switch (TYPE) {                                                                             \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)                          \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)                           \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)                         \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)                           \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)                           \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)                          \
-      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)                         \
-      AT_PRIVATE_CASE_TYPE(SCALARTYPE1, MyTemplate<SCALARTYPE1>::type, __VA_ARGS__)             \
-      AT_PRIVATE_CASE_TYPE(SCALARTYPE2, MyTemplate<SCALARTYPE2>::type, __VA_ARGS__)             \
-      AT_PRIVATE_CASE_TYPE(                                                                     \
-          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)                       \
-      AT_PRIVATE_CASE_TYPE(                                                                     \
-          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)                     \
-      default:                                                                                  \
-        AT_ERROR(#NAME, " not implemented for '", TYPE, "'");                                   \
-    }                                                                                           \
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...)                    \
+  [&] {                                                                                                     \
+    switch (TYPE) {                                                                                         \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__)                                      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__)                                       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__)                                     \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__)                                       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__)                                       \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__)                                      \
+      AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__)                                     \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType<SCALARTYPE1>::t), __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType<SCALARTYPE2>::t), __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE(                                                                                 \
+          at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__)                                   \
+      AT_PRIVATE_CASE_TYPE(                                                                                 \
+          at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__)                                 \
+      default:                                                                                              \
+        AT_ERROR(#NAME, " not implemented for '", TYPE, "'");                                               \
+    }                                                                                                       \
   }()
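
A hedged usage sketch of the new macro (the op name and lambda body are illustrative, and an `at::Tensor self` is assumed in scope): inside the lambda, `scalar_t` is bound to the C type of the matched case, with the two extra types resolved through `ScalarTypeToCType`.

```cpp
AT_DISPATCH_ALL_TYPES_AND2(
    at::ScalarType::Half, at::ScalarType::Bool,
    self.scalar_type(), "example_op_cpu", [&] {
      scalar_t* data = self.data<scalar_t>();  // bool* for a bool tensor
      // ... elementwise work on data ...
    });
```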
aten/src/ATen/native/Copy.cpp
index e96fa1c..848108a 100644
@@ -20,7 +20,7 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
 template <typename self_T>
 void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
   AT_CHECK(self.numel() == src.numel(), "sizes do not match");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "_copy__cpu", [&]() {
     _copy__cpu<self_T, scalar_t>(self, src);
   });
 }
@@ -42,8 +42,8 @@ Tensor& _s_copy__cpu(Tensor& self, const Tensor& src, bool non_blocking) {
     _s_copy_from(src, self, non_blocking);
     return self;
   }
-  AT_DISPATCH_ALL_TYPES_AND(
-    at::ScalarType::Half, self.scalar_type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool,
+      self.scalar_type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
   return self;
 }
 
aten/src/ATen/native/TensorFactories.cpp
index b5d6ec8..f4f0f42 100644
@@ -141,7 +141,7 @@ Tensor& empty_out(Tensor& result, IntArrayRef size) {
     return target_type.copy(self, non_blocking);                 \
   }
 
-AT_FORALL_SCALAR_TYPES(DEFINE_CAST_OP)
+AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CAST_OP)
 
 #undef DEFINE_CAST_OP
 
aten/src/TH/generic/THTensorMath.cpp
index 231e9c1..2e14cd0 100644
@@ -210,6 +210,26 @@ void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src)
   }
 }
 
+scalar_t THTensor_(powOne)(scalar_t x, scalar_t y) {
+#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_HALF)
+  return powf(x, y);
+#elif defined(TH_REAL_IS_DOUBLE)
+  return pow(x, y);
+#else
+  THArgCheck(y >= 0, 1,
+      "Integers to negative integer powers are not allowed");
+  scalar_t result = 1;
+  while (y) {
+    if (y & 1) {
+       result *= x;
+    }
+    y /= 2;
+    x *= x;
+  }
+  return result;
+#endif
+}
+
 void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value)
 {
   THTensor_(resizeAs)(r_, t);
@@ -253,26 +273,6 @@ void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value)
 #endif
 }
 
-scalar_t THTensor_(powOne)(scalar_t x, scalar_t y) {
-#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_HALF)
-  return powf(x, y);
-#elif defined(TH_REAL_IS_DOUBLE)
-  return pow(x, y);
-#else
-  THArgCheck(y >= 0, 1,
-      "Integers to negative integer powers are not allowed");
-  scalar_t result = 1;
-  while (y) {
-    if (y & 1) {
-       result *= x;
-    }
-    y /= 2;
-    x *= x;
-  }
-  return result;
-#endif
-}
-
 void THTensor_(cpow)(THTensor *r_, THTensor *t, THTensor *src)
 {
   THTensor_(resizeAs)(r_, t);
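
The integer branch of `THTensor_(powOne)` above is exponentiation by squaring; a standalone sketch of the same idea (the helper name is hypothetical):

```cpp
#include <cstdint>

// O(log y) multiplications: consume one bit of the exponent per iteration.
int64_t ipow(int64_t x, int64_t y) {
  int64_t result = 1;
  while (y) {
    if (y & 1) result *= x;  // low bit set: fold the current power in
    y /= 2;                  // move to the next bit of the exponent
    x *= x;                  // square the base
  }
  return result;             // e.g. ipow(3, 5) == 243
}
```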
aten/src/TH/generic/THTensorMath.h
index 771d915..27083a0 100644
@@ -87,11 +87,7 @@ TH_API void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src);
 TH_API void THTensor_(cmaxValue)(THTensor *r, THTensor *t, scalar_t value);
 TH_API void THTensor_(cminValue)(THTensor *r, THTensor *t, scalar_t value);
 
-TH_API void THTensor_(zerosLike)(THTensor *r_, THTensor *input);
-TH_API void THTensor_(onesLike)(THTensor *r_, THTensor *input);
 TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k);
-TH_API void THTensor_(eye)(THTensor *r_, int64_t n, int64_t m);
-TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n);
 
 TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder);
 TH_API void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted);
@@ -129,7 +125,6 @@ TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb);
 
 TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value);
 TH_API void THTensor_(tpow)(THTensor *r_, scalar_t value, THTensor *t);
-TH_API scalar_t THTensor_(powOne)(scalar_t x, scalar_t y);
 TH_API void THTensor_(abs)(THTensor *r_, THTensor *t);
 
 #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
@@ -166,7 +161,6 @@ TH_API void THTensor_(round)(THTensor *r_, THTensor *t);
 TH_API void THTensor_(trunc)(THTensor *r_, THTensor *t);
 TH_API void THTensor_(frac)(THTensor *r_, THTensor *t);
 
-TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim);
 TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim);
 TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim);
 TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, int keepdim);
c10/core/Scalar.h
index eb4e08a..94fe216 100644
@@ -35,6 +35,15 @@ class C10_API Scalar {
 
 #undef DEFINE_IMPLICIT_CTOR
 
+// Value* is implicitly convertible both to SymbolicVariable and to bool, which
+// causes an ambiguity error. A specialized constructor for bool resolves this problem.
+template <typename T,
+       typename std::enable_if<std::is_same<T, bool>::value, bool>::type* = nullptr>
+Scalar(T vv)
+: tag(Tag::HAS_i) {
+  v.i = convert<decltype(v.i), bool>(vv);
+}
+
 #define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \
   Scalar(type vv) : tag(Tag::HAS_##member) {             \
     v.member[0] = c10::convert<double>(vv.real());       \
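
Why the bool constructor above is a constrained template rather than a plain `Scalar(bool)`: template argument deduction matches the argument type exactly, so a `Value*` argument deduces `T = Value*` and the `enable_if` rejects it, whereas a plain `Scalar(bool)` would still accept pointers through the pointer-to-bool standard conversion. A minimal sketch of the pattern (names are illustrative):

```cpp
#include <cstdint>
#include <type_traits>

struct MiniScalar {
  // Participates in overload resolution only when T is exactly bool.
  template <typename T,
            typename std::enable_if<std::is_same<T, bool>::value,
                                    bool>::type* = nullptr>
  MiniScalar(T vv) : i(vv ? 1 : 0) {}
  int64_t i;
};

int main() {
  MiniScalar a(true);  // OK: T deduced as bool
  int* p = nullptr;
  // MiniScalar b(p);  // error: T deduced as int*, enable_if removes the ctor
  (void)a; (void)p;
}
```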
c10/core/ScalarType.h
index ba55e9c..65ce4e8 100644
@@ -53,6 +53,17 @@ _(at::Half,Half,d) \
 _(float,Float,d)   \
 _(double,Double,d)
 
+#define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \
+_(uint8_t,Byte,i)  \
+_(int8_t,Char,i)   \
+_(int16_t,Short,i) \
+_(int,Int,i)       \
+_(int64_t,Long,i)  \
+_(at::Half,Half,d) \
+_(float,Float,d)   \
+_(double,Double,d) \
+_(bool,Bool,i)
+
 #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \
 _(uint8_t,Byte,i) \
 _(int8_t,Char,i) \
caffe2/contrib/aten/aten_op_template.h
index 256b846..e2c0043 100644
@@ -15,7 +15,7 @@ static std::unordered_map<std::string, int> op_to_key = {
 
 namespace caffe2 {
 
-using at::Half; // for AT_FORALL_SCALAR_TYPES
+using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL
 
 template <class Context>
 class ATenOp : public Operator<Context> {
@@ -47,7 +47,7 @@ private:
       case at::k##aten_name: \
         return TypeMeta::Make<ctype>();
     switch(st) {
-      AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
+      AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
       default:
         CAFFE_THROW("Unknown ATen Type");
     }
@@ -103,7 +103,7 @@ private:
     }
   }
 
-  // the AT_FORALL_SCALAR_TYPES macro just gives a 'i' or 'd' argument
+  // the AT_FORALL_SCALAR_TYPES_AND_BOOL macro just gives a 'i' or 'd' argument
   // for each type to specify if it is stored as a integer or a double.
   // We need this workaround here to extract the value in the scalar losslessly
   // because in some cases like 'sum' Torch promotes float to double
@@ -123,7 +123,7 @@ private:
           auto value = extract_##native(scalar); \
           assignToValue<ctype>(dst, at::convert<ctype,decltype(value)>(value)); \
         } break;
-      AT_FORALL_SCALAR_TYPES(DEFINE_CASE)
+      AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE)
       #undef DEFINE_CASE
       default:
         CAFFE_THROW("Unknown ATen Type");
caffe2/operators/experimental/c10/cpu/cast_cpu.cc
index a2203c5..66ffcf5 100644
@@ -80,7 +80,7 @@ void cast_op_cpu(
     int64_t to) {
   switch (input.scalar_type()) {
 #define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl<ctype>(input, output, to);
-    AT_FORALL_SCALAR_TYPES(CASE)
+    AT_FORALL_SCALAR_TYPES_AND_BOOL(CASE)
 #undef CASE
     default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type()));
   }
test/common_utils.py
index a55381e..43131c0 100644
@@ -428,11 +428,15 @@ class TestCase(expecttest.TestCase):
                     else:
                         b = b.to(a)
 
-                    if x.dtype == torch.bool and y.dtype == torch.bool:
-                        self.assertEqual(x.tolist(), y.tolist())
-                    elif x.dtype == torch.bool or y.dtype == torch.bool:
+                    if (a.dtype == torch.bool) != (b.dtype == torch.bool):
                         raise TypeError("Was expecting both tensors to be bool type.")
                     else:
+                        if a.dtype == torch.bool and b.dtype == torch.bool:
+                            # we want to respect precision, but as bool doesn't support subtraction,
+                            # boolean tensors have to be converted to int
+                            a = a.to(torch.int)
+                            b = b.to(torch.int)
+
                         diff = a - b
                         if a.is_floating_point():
                             # check that NaNs are in the same locations
test/test_torch.py
index 6b6c2a2..5728dd8 100644
@@ -2905,16 +2905,35 @@ class _TestTorchMixin(object):
 
     # This is a temporary test for a boolean tensors on CPU. Once the CUDA part
     # will be done, these test cases will be moved down to test_tensor_factories_empty test
-    def test_tensor_factories_empty_bool(self):
+    def test_tensor_factories_bool(self):
         expectedShape = (1, 2)
         test = torch.empty(expectedShape, dtype=torch.bool)
         self.assertEqual(expectedShape, test.shape)
-        self.assertEqual(expectedShape, torch.empty_like(test).shape)
+
+        test2 = torch.empty_like(test, dtype=torch.bool)
+        self.assertEqual(test.shape, test2.shape)
 
         test = torch.full(expectedShape, True, dtype=torch.bool)
         self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
+
+        test2 = torch.full_like(test, True, dtype=torch.bool)
+        self.assertEqual(test, test2)
+
+        test = torch.zeros(expectedShape, dtype=torch.bool)
+        self.assertEqual(test, torch.tensor([[False, False]], dtype=torch.bool))
+
+        test2 = torch.zeros_like(test, dtype=torch.bool)
+        self.assertEqual(test, test2)
+
+        test = torch.ones(expectedShape, dtype=torch.bool)
+        self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool))
+
+        test2 = torch.ones_like(test, dtype=torch.bool)
+        self.assertEqual(test, test2)
+
+        test = torch.randint(10, expectedShape, dtype=torch.bool)
         self.assertEqual(expectedShape, test.shape)
-        self.assertEqual(expectedShape, torch.full_like(test, True).shape)
+        self.assertEqual(torch.bool, test.dtype)
 
     def test_tensor_factories_empty(self):
         # ensure we can create empty tensors from each factory function
torch/csrc/utils/python_scalars.h
index ca0239d..c1bd36c 100644
@@ -39,7 +39,7 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) {
     case at::kDouble: return PyFloat_FromDouble(*(double*)data);
     case at::kComplexFloat: return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<float>*)data));
     case at::kComplexDouble: return PyComplex_FromCComplex(*reinterpret_cast<Py_complex *>((std::complex<double>*)data));
-    case at::kBool: return PyBool_FromLong(*(uint8_t*)data);
+    case at::kBool: return PyBool_FromLong(*(bool*)data);
     default: throw std::runtime_error("invalid type");
   }
 }