From 1a742075ee97b9603001188eeec9c30c3fe8a161 Mon Sep 17 00:00:00 2001 From: Iurii Zdebskyi Date: Tue, 26 Mar 2019 09:55:50 -0700 Subject: [PATCH] Resolving comments from Bool Tensor for CPU PR (#18165) MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18165 ghimport-source-id: 55cb3fb63a25c2faab1725b4ec14c688bf45bd38 Stack from [ghstack](https://github.com/ezyang/ghstack): * #18166 Bool Tensor for CUDA * **#18165 Resolved comments from Bool Tensor for CPU PR** ------- ------------ This is a follow up PR that resolves some additional feedback on one the of previous Bool Tensor PRs. gchanan, here is a list of almost all the comments from the original PR with respective fixes and replies: **[utils/python_scalars.h]** why is this converting from uint8_t and not bool? (comment?) When i was adding this, i was testing by creating a tensor and then calling its .tolist(). it worked for bool and uint8_t equally good so i left uint8_t as thought it makes more sense as we are calling PyBool_FromLong. �Changing it to bool. **[ATen/Dispatch.h]**better name?. fixed. **[test/test_torch.py]** what about other factories, such as full? (and more). There is a test that goes through the factory methods - test_tensor_factories_empty. i added some bool cases above it and added a comment that once CUDA will be done, i will unite them and it will iterate not just between CUDA and CPU but also all types. ��Adding all bool cases now. Will unite in CUDA PR. **[generic/THTensorMath.h]** any changes in this file actually needed? Bad merge. Fixed. **[TH/THTensor.h]** this generates code for random, clampedRandom, and cappedRandom -- do we have tests for all of these with bool? Added **[c10/core/ScalarType.h]** I'm not very confident about the lack of Bool here -- can you look at the call sites and see what makes sense to do here? Added bool to the macro and created a similar one without for a single case which fails the build with errors: _./torch/csrc/jit/symbolic_variable.h:79:20: error: ambiguous overload for ‘operator*’ (operand types are ‘const torch::jit::SymbolicVariable’ and ‘torch::jit::Value*’) return (*this) * insertConstant(rhs);_ Differential Revision: D14605105 fbshipit-source-id: abf82d50e8f8c50b386545ac068268651b28496d --- aten/src/ATen/Dispatch.h | 128 +++++++++++++--------- aten/src/ATen/native/Copy.cpp | 6 +- aten/src/ATen/native/TensorFactories.cpp | 2 +- aten/src/TH/generic/THTensorMath.cpp | 40 +++---- aten/src/TH/generic/THTensorMath.h | 6 - c10/core/Scalar.h | 9 ++ c10/core/ScalarType.h | 11 ++ caffe2/contrib/aten/aten_op_template.h | 8 +- caffe2/operators/experimental/c10/cpu/cast_cpu.cc | 2 +- test/common_utils.py | 10 +- test/test_torch.py | 25 ++++- torch/csrc/utils/python_scalars.h | 2 +- 12 files changed, 157 insertions(+), 92 deletions(-) diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index 374092a..b22a60d 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -12,6 +12,31 @@ namespace detail { +template +struct ScalarTypeToCType; + +template<> +struct ScalarTypeToCType { + using type = at::Half; + + // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly + // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail. + // For repro example, please see: https://github.com/izdeby/playground/blob/0b0c0e9373b32830442ec26b2bc3535b21fb8d95/C%2B%2B/CudaBugRepro.cu + // TODO: remove once the bug is fixed. + static at::Half t; +}; + +template<> +struct ScalarTypeToCType { + using type = bool; + + // This is a workaround for the CUDA bug which prevents ::detail::ScalarTypeToCType::type being used directly + // due to ambiguous reference which can't to be resolved. For some reason it cant pick between at::detail and at::cuda::detail. + // For repro example, please see: https://github.com/izdeby/playground/blob/0b0c0e9373b32830442ec26b2bc3535b21fb8d95/C%2B%2B/CudaBugRepro.cu + // TODO: remove once the bug is fixed. + static bool t; +}; + inline at::ScalarType scalar_type(at::ScalarType s) { return s; } @@ -101,8 +126,9 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} } \ }() -#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ +#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \ [&] { \ + detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \ const auto& the_type = TYPE; \ (void)the_type; \ at::ScalarType _st = ::detail::scalar_type(TYPE); \ @@ -114,14 +140,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ }() -#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \ +#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ [&] { \ - detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(); \ const auto& the_type = TYPE; \ (void)the_type; \ at::ScalarType _st = ::detail::scalar_type(TYPE); \ @@ -133,7 +159,6 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \ default: \ AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \ } \ @@ -200,53 +225,56 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} } \ }() +#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + [&] { \ + switch (TYPE) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ + }() -template -struct MyTemplate; - -template<> -struct MyTemplate { - using type = at::Half; -}; - -template<> -struct MyTemplate { - using type = bool; -}; - -#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate::type, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ - } \ +#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + [&] { \ + switch (TYPE) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } \ }() -#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ - [&] { \ - switch (TYPE) { \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE1, MyTemplate::type, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE(SCALARTYPE2, MyTemplate::type, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ - AT_PRIVATE_CASE_TYPE( \ - at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ - default: \ - AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ - } \ +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + [&] { \ + switch (TYPE) { \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE1, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE(SCALARTYPE2, decltype(::detail::ScalarTypeToCType::t), __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexFloat, std::complex, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE( \ + at::ScalarType::ComplexDouble, std::complex, __VA_ARGS__) \ + default: \ + AT_ERROR(#NAME, " not implemented for '", TYPE, "'"); \ + } \ }() diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp index e96fa1c..848108a 100644 --- a/aten/src/ATen/native/Copy.cpp +++ b/aten/src/ATen/native/Copy.cpp @@ -20,7 +20,7 @@ void _copy__cpu(at::Tensor& self, const at::Tensor& src) { template void _copy__cpu(at::Tensor& self, const at::Tensor& src) { AT_CHECK(self.numel() == src.numel(), "sizes do not match"); - AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.scalar_type(), "_copy__cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, src.scalar_type(), "_copy__cpu", [&]() { _copy__cpu(self, src); }); } @@ -42,8 +42,8 @@ Tensor& _s_copy__cpu(Tensor& self, const Tensor& src, bool non_blocking) { _s_copy_from(src, self, non_blocking); return self; } - AT_DISPATCH_ALL_TYPES_AND( - at::ScalarType::Half, self.scalar_type(), "_copy__cpu", [&]() { ::_copy__cpu(self, src); }); + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, + self.scalar_type(), "_copy__cpu", [&]() { ::_copy__cpu(self, src); }); return self; } diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index b5d6ec8..f4f0f42 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -141,7 +141,7 @@ Tensor& empty_out(Tensor& result, IntArrayRef size) { return target_type.copy(self, non_blocking); \ } -AT_FORALL_SCALAR_TYPES(DEFINE_CAST_OP) +AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CAST_OP) #undef DEFINE_CAST_OP diff --git a/aten/src/TH/generic/THTensorMath.cpp b/aten/src/TH/generic/THTensorMath.cpp index 231e9c1..2e14cd0 100644 --- a/aten/src/TH/generic/THTensorMath.cpp +++ b/aten/src/TH/generic/THTensorMath.cpp @@ -210,6 +210,26 @@ void THTensor_(cmul)(THTensor *r_, THTensor *t, THTensor *src) } } +scalar_t THTensor_(powOne)(scalar_t x, scalar_t y) { +#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_HALF) + return powf(x, y); +#elif defined(TH_REAL_IS_DOUBLE) + return pow(x, y); +#else + THArgCheck(y >= 0, 1, + "Integers to negative integer powers are not allowed"); + scalar_t result = 1; + while (y) { + if (y & 1) { + result *= x; + } + y /= 2; + x *= x; + } + return result; +#endif +} + void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value) { THTensor_(resizeAs)(r_, t); @@ -253,26 +273,6 @@ void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value) #endif } -scalar_t THTensor_(powOne)(scalar_t x, scalar_t y) { -#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_HALF) - return powf(x, y); -#elif defined(TH_REAL_IS_DOUBLE) - return pow(x, y); -#else - THArgCheck(y >= 0, 1, - "Integers to negative integer powers are not allowed"); - scalar_t result = 1; - while (y) { - if (y & 1) { - result *= x; - } - y /= 2; - x *= x; - } - return result; -#endif -} - void THTensor_(cpow)(THTensor *r_, THTensor *t, THTensor *src) { THTensor_(resizeAs)(r_, t); diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index 771d915..27083a0 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -87,11 +87,7 @@ TH_API void THTensor_(cmin)(THTensor *r, THTensor *t, THTensor *src); TH_API void THTensor_(cmaxValue)(THTensor *r, THTensor *t, scalar_t value); TH_API void THTensor_(cminValue)(THTensor *r, THTensor *t, scalar_t value); -TH_API void THTensor_(zerosLike)(THTensor *r_, THTensor *input); -TH_API void THTensor_(onesLike)(THTensor *r_, THTensor *input); TH_API void THTensor_(diag)(THTensor *r_, THTensor *t, int k); -TH_API void THTensor_(eye)(THTensor *r_, int64_t n, int64_t m); -TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n); TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder); TH_API void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted); @@ -129,7 +125,6 @@ TH_API void THTensor_(eqTensorT)(THTensor *r_, THTensor *ta, THTensor *tb); TH_API void THTensor_(pow)(THTensor *r_, THTensor *t, scalar_t value); TH_API void THTensor_(tpow)(THTensor *r_, scalar_t value, THTensor *t); -TH_API scalar_t THTensor_(powOne)(scalar_t x, scalar_t y); TH_API void THTensor_(abs)(THTensor *r_, THTensor *t); #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) @@ -166,7 +161,6 @@ TH_API void THTensor_(round)(THTensor *r_, THTensor *t); TH_API void THTensor_(trunc)(THTensor *r_, THTensor *t); TH_API void THTensor_(frac)(THTensor *r_, THTensor *t); -TH_API void THTensor_(mean)(THTensor *r_, THTensor *t, int dimension, int keepdim); TH_API void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim); TH_API void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim); TH_API void THTensor_(norm)(THTensor *r_, THTensor *t, scalar_t value, int dimension, int keepdim); diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index eb4e08a..94fe216 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -35,6 +35,15 @@ class C10_API Scalar { #undef DEFINE_IMPLICIT_CTOR +// Value* is both implicitly convertible to SymbolicVariable and bool which +// causes ambiguosity error. Specialized constructor for bool resolves this problem. +template ::value, bool>::type* = nullptr> +Scalar(T vv) +: tag(Tag::HAS_i) { + v.i = convert(vv); +} + #define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \ Scalar(type vv) : tag(Tag::HAS_##member) { \ v.member[0] = c10::convert(vv.real()); \ diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index ba55e9c..65ce4e8 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -53,6 +53,17 @@ _(at::Half,Half,d) \ _(float,Float,d) \ _(double,Double,d) +#define AT_FORALL_SCALAR_TYPES_AND_BOOL(_) \ +_(uint8_t,Byte,i) \ +_(int8_t,Char,i) \ +_(int16_t,Short,i) \ +_(int,Int,i) \ +_(int64_t,Long,i) \ +_(at::Half,Half,d) \ +_(float,Float,d) \ +_(double,Double,d) \ +_(bool,Bool,i) + #define AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(_) \ _(uint8_t,Byte,i) \ _(int8_t,Char,i) \ diff --git a/caffe2/contrib/aten/aten_op_template.h b/caffe2/contrib/aten/aten_op_template.h index 256b846..e2c0043 100644 --- a/caffe2/contrib/aten/aten_op_template.h +++ b/caffe2/contrib/aten/aten_op_template.h @@ -15,7 +15,7 @@ static std::unordered_map op_to_key = { namespace caffe2 { -using at::Half; // for AT_FORALL_SCALAR_TYPES +using at::Half; // for AT_FORALL_SCALAR_TYPES_AND_BOOL template class ATenOp : public Operator { @@ -47,7 +47,7 @@ private: case at::k##aten_name: \ return TypeMeta::Make(); switch(st) { - AT_FORALL_SCALAR_TYPES(DEFINE_CASE) + AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE) default: CAFFE_THROW("Unknown ATen Type"); } @@ -103,7 +103,7 @@ private: } } - // the AT_FORALL_SCALAR_TYPES macro just gives a 'i' or 'd' argument + // the AT_FORALL_SCALAR_TYPES_AND_BOOL macro just gives a 'i' or 'd' argument // for each type to specify if it is stored as a integer or a double. // We need this workaround here to extract the value in the scalar losslessly // because in some cases like 'sum' Torch promotes float to double @@ -123,7 +123,7 @@ private: auto value = extract_##native(scalar); \ assignToValue(dst, at::convert(value)); \ } break; - AT_FORALL_SCALAR_TYPES(DEFINE_CASE) + AT_FORALL_SCALAR_TYPES_AND_BOOL(DEFINE_CASE) #undef DEFINE_CASE default: CAFFE_THROW("Unknown ATen Type"); diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc index a2203c5..66ffcf5 100644 --- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc @@ -80,7 +80,7 @@ void cast_op_cpu( int64_t to) { switch (input.scalar_type()) { #define CASE(ctype,name,_2) case ScalarType:: name : return cast_op_cpu_impl(input, output, to); - AT_FORALL_SCALAR_TYPES(CASE) + AT_FORALL_SCALAR_TYPES_AND_BOOL(CASE) #undef CASE default: throw std::runtime_error(string() + "Unsupported scalar type " + toString(input.scalar_type())); } diff --git a/test/common_utils.py b/test/common_utils.py index a55381e..43131c0 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -428,11 +428,15 @@ class TestCase(expecttest.TestCase): else: b = b.to(a) - if x.dtype == torch.bool and y.dtype == torch.bool: - self.assertEqual(x.tolist(), y.tolist()) - elif x.dtype == torch.bool or y.dtype == torch.bool: + if (a.dtype == torch.bool) != (b.dtype == torch.bool): raise TypeError("Was expecting both tensors to be bool type.") else: + if a.dtype == torch.bool and b.dtype == torch.bool: + # we want to respect precision but as bool doesn't support substraction, + # boolean tensor has to be converted to int + a = a.to(torch.int) + b = b.to(torch.int) + diff = a - b if a.is_floating_point(): # check that NaNs are in the same locations diff --git a/test/test_torch.py b/test/test_torch.py index 6b6c2a2..5728dd8 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2905,16 +2905,35 @@ class _TestTorchMixin(object): # This is a temporary test for a boolean tensors on CPU. Once the CUDA part # will be done, these test cases will be moved down to test_tensor_factories_empty test - def test_tensor_factories_empty_bool(self): + def test_tensor_factories_bool(self): expectedShape = (1, 2) test = torch.empty(expectedShape, dtype=torch.bool) self.assertEqual(expectedShape, test.shape) - self.assertEqual(expectedShape, torch.empty_like(test).shape) + + test2 = torch.empty_like(test, dtype=torch.bool) + self.assertEqual(test.shape, test2.shape) test = torch.full(expectedShape, True, dtype=torch.bool) self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool)) + + test2 = torch.full_like(test, True, dtype=torch.bool) + self.assertEqual(test, test2) + + test = torch.zeros(expectedShape, dtype=torch.bool) + self.assertEqual(test, torch.tensor([[False, False]], dtype=torch.bool)) + + test2 = torch.zeros_like(test, dtype=torch.bool) + self.assertEqual(test, test2) + + test = torch.ones(expectedShape, dtype=torch.bool) + self.assertEqual(test, torch.tensor([[True, True]], dtype=torch.bool)) + + test2 = torch.ones_like(test, dtype=torch.bool) + self.assertEqual(test, test2) + + test = torch.randint(10, expectedShape, dtype=torch.bool) self.assertEqual(expectedShape, test.shape) - self.assertEqual(expectedShape, torch.full_like(test, True).shape) + self.assertEqual(torch.bool, test.dtype) def test_tensor_factories_empty(self): # ensure we can create empty tensors from each factory function diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h index ca0239d..c1bd36c 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -39,7 +39,7 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) { case at::kDouble: return PyFloat_FromDouble(*(double*)data); case at::kComplexFloat: return PyComplex_FromCComplex(*reinterpret_cast((std::complex*)data)); case at::kComplexDouble: return PyComplex_FromCComplex(*reinterpret_cast((std::complex*)data)); - case at::kBool: return PyBool_FromLong(*(uint8_t*)data); + case at::kBool: return PyBool_FromLong(*(bool*)data); default: throw std::runtime_error("invalid type"); } } -- 2.7.4