Bool tensor. Part 0: Boolean storage implementation (#16810)
authorIurii Zdebskyi <iuriiz@fb.com>
Tue, 19 Feb 2019 16:17:49 +0000 (08:17 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 19 Feb 2019 16:22:13 +0000 (08:22 -0800)
Summary:
This is the first commit from a series of planned changes in order to add boolean tensors to PyTorch. The whole plan looks like this:

0. Storage Implementation (this change)
1. Tensor Creation.
2. Tensor Conversions.
3. Tensor Indexing.
4. Tensor Operations.
5. Back compatibility related changes.

This feature was requested by the community:
https://github.com/pytorch/pytorch/issues/4764
https://github.com/pytorch/pytorch/issues/4219
https://github.com/pytorch/pytorch/issues/4288

**Change**:
Added boolean type to the Storage class for CPU and CUDA backends.

**Tested via**:
1. unit tests
2. running this:
-> import torch
-> torch.BoolStorage
<class 'torch.BoolStorage'>
-> torch.cuda.BoolStorage
<class 'torch.cuda.BoolStorage'>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16810

Reviewed By: gchanan

Differential Revision: D14087246

Pulled By: izdeby

fbshipit-source-id: 042642ced1cb0fd1bb6bff05f9ca871a5c54ee5e

61 files changed:
aten/src/ATen/DLConvertor.cpp
aten/src/ATen/core/Type.h
aten/src/ATen/gen.py
aten/src/TH/CMakeLists.txt
aten/src/TH/THFile.h
aten/src/TH/THGenerateBoolType.h [new file with mode: 0644]
aten/src/TH/THStorageFunctions.cpp
aten/src/TH/THStorageFunctions.h
aten/src/TH/THTensor.cpp
aten/src/TH/THTensor.h
aten/src/TH/THTensor.hpp
aten/src/TH/generic/THStorage.h
aten/src/TH/generic/THStorageCopy.cpp
aten/src/TH/generic/THStorageCopy.h
aten/src/TH/generic/THTensor.h
aten/src/THC/CMakeLists.txt
aten/src/THC/THCGenerateBoolType.h [new file with mode: 0644]
aten/src/THC/THCStorage.cpp
aten/src/THC/THCStorage.cu
aten/src/THC/THCStorage.h
aten/src/THC/THCStorageCopy.cpp
aten/src/THC/THCStorageCopy.cu
aten/src/THC/THCStorageCopy.h
aten/src/THC/THCTensor.cpp
aten/src/THC/THCTensor.cu
aten/src/THC/THCTensor.h
aten/src/THC/THCTensor.hpp
aten/src/THC/THCTensorCopy.cu
aten/src/THC/THCTensorCopy.h
aten/src/THC/generic/THCStorage.h
aten/src/THC/generic/THCStorageCopy.cpp
aten/src/THC/generic/THCStorageCopy.cu
aten/src/THC/generic/THCStorageCopy.h
aten/src/THC/generic/THCTensor.h
aten/src/THC/generic/THCTensorCopy.cu
c10/core/ScalarType.h
test/test_torch.py
torch/__init__.py
torch/_storage_docs.py
torch/csrc/DynamicTypes.cpp
torch/csrc/Module.cpp
torch/csrc/Storage.cpp
torch/csrc/Storage.h
torch/csrc/byte_order.cpp
torch/csrc/byte_order.h
torch/csrc/cuda/Module.cpp
torch/csrc/cuda/Storage.cpp
torch/csrc/cuda/Storage.h
torch/csrc/cuda/serialization.cpp
torch/csrc/cuda/serialization.h
torch/csrc/cuda/utils.cpp
torch/csrc/cuda/utils.h
torch/csrc/generic/Storage.cpp
torch/csrc/generic/StorageMethods.cpp
torch/csrc/serialization.cpp
torch/csrc/serialization.h
torch/csrc/utils.cpp
torch/csrc/utils.h
torch/csrc/utils/tensor_dtypes.cpp
torch/cuda/__init__.py
torch/storage.py

index 721516f..8c1bab7 100644 (file)
@@ -37,6 +37,9 @@ static DLDataType getDLDataType(const Type& type) {
     case ScalarType::Half:
       dtype.code = DLDataTypeCode::kDLFloat;
       break;
+    case ScalarType::Bool:
+      dtype.code = DLDataTypeCode::kDLUInt;
+      break;
     case ScalarType::ComplexHalf:
       throw std::logic_error("ComplexHalf is not supported by dlpack");
     case ScalarType::ComplexFloat:
index b90dd26..1b300f1 100644 (file)
@@ -44,6 +44,7 @@ struct Generator;
 static inline void noop_deleter(void*) {}
 
 enum class TypeID {
+  CPUBool,
   CPUByte,
   CPUChar,
   CPUDouble,
@@ -52,6 +53,7 @@ enum class TypeID {
   CPULong,
   CPUShort,
   CPUHalf,
+  SparseCPUBool,
   SparseCPUByte,
   SparseCPUChar,
   SparseCPUDouble,
@@ -59,6 +61,7 @@ enum class TypeID {
   SparseCPUInt,
   SparseCPULong,
   SparseCPUShort,
+  CUDABool,
   CUDAByte,
   CUDAChar,
   CUDADouble,
@@ -67,6 +70,7 @@ enum class TypeID {
   CUDALong,
   CUDAShort,
   CUDAHalf,
+  SparseCUDABool,
   SparseCUDAByte,
   SparseCUDAChar,
   SparseCUDADouble,
index 9f5fad1..e153041 100644 (file)
@@ -178,6 +178,7 @@ extension_backends = ['MSNPU', 'XLA']
 
 # scalar_name, c_type, accreal, th_scalar_type, is_floating_type
 scalar_types = [
+    ('Bool', 'uint8_t', 'BoolAccrealNotDefined', 'uint8_t', False),
     ('Byte', 'uint8_t', 'Long', 'uint8_t', False),
     ('Char', 'int8_t', 'Long', 'int8_t', False),
     ('Double', 'double', 'Double', 'double', True),
index 463792b..b4d72ac 100644 (file)
@@ -64,6 +64,7 @@ INSTALL(FILES
   THFilePrivate.h
   ${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h
   THGenerateAllTypes.h
+  THGenerateBoolType.h
   THGenerateDoubleType.h
   THGenerateFloatType.h
   THGenerateHalfType.h
index bb0f19e..144cec4 100644 (file)
@@ -46,6 +46,7 @@ TH_API size_t THFile_readInt(THFile *self, THIntStorage *storage);
 TH_API size_t THFile_readLong(THFile *self, THLongStorage *storage);
 TH_API size_t THFile_readFloat(THFile *self, THFloatStorage *storage);
 TH_API size_t THFile_readDouble(THFile *self, THDoubleStorage *storage);
+TH_API size_t THFile_readBool(THFile *self, THBoolStorage *storage);
 
 TH_API size_t THFile_writeByte(THFile *self, THByteStorage *storage);
 TH_API size_t THFile_writeChar(THFile *self, THCharStorage *storage);
@@ -54,6 +55,7 @@ TH_API size_t THFile_writeInt(THFile *self, THIntStorage *storage);
 TH_API size_t THFile_writeLong(THFile *self, THLongStorage *storage);
 TH_API size_t THFile_writeFloat(THFile *self, THFloatStorage *storage);
 TH_API size_t THFile_writeDouble(THFile *self, THDoubleStorage *storage);
+TH_API size_t THFile_writeBool(THFile *self, THBoolStorage *storage);
 
 /* raw */
 TH_API size_t THFile_readByteRaw(THFile *self, uint8_t *data, size_t n);
diff --git a/aten/src/TH/THGenerateBoolType.h b/aten/src/TH/THGenerateBoolType.h
new file mode 100644 (file)
index 0000000..e18bd60
--- /dev/null
@@ -0,0 +1,22 @@
+#ifndef TH_GENERIC_FILE
+#error "You must define TH_GENERIC_FILE before including THGenerateBoolType.h"
+#endif
+
+// TODO: define accreal type once the correct value is known.
+#define scalar_t bool
+#define ureal bool
+#define Real Bool
+#define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val)
+#define TH_REAL_IS_BOOL
+#line 1 TH_GENERIC_FILE
+#include TH_GENERIC_FILE
+#undef scalar_t
+#undef ureal
+#undef Real
+#undef TH_REAL_IS_BOOL
+#undef TH_CONVERT_REAL_TO_ACCREAL
+#undef TH_CONVERT_ACCREAL_TO_REAL
+
+#ifndef THGenerateManyTypes
+#undef TH_GENERIC_FILE
+#endif
index 5fac244..2d99613 100644 (file)
@@ -9,12 +9,18 @@
 #include <TH/generic/THStorage.cpp>
 #include <TH/THGenerateHalfType.h>
 
+#include <TH/generic/THStorage.cpp>
+#include <TH/THGenerateBoolType.h>
+
 #include <TH/generic/THStorageCopy.cpp>
 #include <TH/THGenerateAllTypes.h>
 
 #include <TH/generic/THStorageCopy.cpp>
 #include <TH/THGenerateHalfType.h>
 
+#include <TH/generic/THStorageCopy.cpp>
+#include <TH/THGenerateBoolType.h>
+
 THStorage* THStorage_new(caffe2::TypeMeta data_type) {
   THStorage* storage = c10::make_intrusive<at::StorageImpl>(
       data_type,
index 40ca8cb..4a2c324 100644 (file)
 #include <TH/generic/THStorage.h>
 #include <TH/THGenerateHalfType.h>
 
+#include <TH/generic/THStorage.h>
+#include <TH/THGenerateBoolType.h>
+
 #include <TH/generic/THStorageCopy.h>
 #include <TH/THGenerateAllTypes.h>
 
 #include <TH/generic/THStorageCopy.h>
 #include <TH/THGenerateHalfType.h>
 
+#include <TH/generic/THStorageCopy.h>
+#include <TH/THGenerateBoolType.h>
+
 // This exists to have a data-type independent way of freeing (necessary for THPPointer).
 TH_API void THStorage_free(THStorage *storage);
index 2a249e8..bef67f9 100644 (file)
@@ -6,6 +6,9 @@
 #include <TH/generic/THTensor.cpp>
 #include <TH/THGenerateHalfType.h>
 
+#include <TH/generic/THTensor.cpp>
+#include <TH/THGenerateBoolType.h>
+
 #include <ATen/native/Resize.h>
 
 #include <numeric>
index 9dd6162..18b14cb 100644 (file)
@@ -13,6 +13,9 @@
 #include <TH/generic/THTensor.h>
 #include <TH/THGenerateHalfType.h>
 
+#include <TH/generic/THTensor.h>
+#include <TH/THGenerateBoolType.h>
+
 /* random numbers */
 #include <TH/THRandom.h>
 #include <TH/generic/THTensorRandom.h>
index 5e255a1..2a6a55d 100644 (file)
@@ -127,3 +127,6 @@ TH_CPP_API c10::optional<std::vector<int64_t>> THTensor_compute_stride(
 
 #include <TH/generic/THTensor.hpp>
 #include <TH/THGenerateHalfType.h>
+
+#include <TH/generic/THTensor.hpp>
+#include <TH/THGenerateBoolType.h>
index bb33bc8..2e432c1 100644 (file)
@@ -33,6 +33,7 @@
 #define THShortStorage THStorage
 #define THIntStorage THStorage
 #define THLongStorage THStorage
+#define THBoolStorage THStorage
 
 TH_API scalar_t* THStorage_(data)(const THStorage*);
 TH_API ptrdiff_t THStorage_(size)(const THStorage*);
index cdf8403..0ce5035 100644 (file)
@@ -37,5 +37,6 @@ IMPLEMENT_THStorage_COPY(Long)
 IMPLEMENT_THStorage_COPY(Float)
 IMPLEMENT_THStorage_COPY(Double)
 IMPLEMENT_THStorage_COPY(Half)
+IMPLEMENT_THStorage_COPY(Bool)
 
 #endif
index 8d86a23..0301fc6 100644 (file)
@@ -14,5 +14,6 @@ TH_API void THStorage_(copyLong)(THStorage *storage, struct THLongStorage *src);
 TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src);
 TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src);
 TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
+TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
 
 #endif
index 8d0c318..854757b 100644 (file)
@@ -18,6 +18,7 @@
 #define THShortTensor THTensor
 #define THIntTensor THTensor
 #define THLongTensor THTensor
+#define THBoolTensor THTensor
 
 /**** access methods ****/
 TH_API THStorage* THTensor_(storage)(const THTensor *self);
index fddfe7d..16fb799 100644 (file)
@@ -78,6 +78,7 @@ INSTALL(FILES
           THCDeviceTensorUtils.cuh
           THCDeviceTensorUtils-inl.cuh
           THCGenerateAllTypes.h
+          THCGenerateBoolType.h
           THCGenerateByteType.h
           THCGenerateCharType.h
           THCGenerateShortType.h
diff --git a/aten/src/THC/THCGenerateBoolType.h b/aten/src/THC/THCGenerateBoolType.h
new file mode 100644 (file)
index 0000000..b010649
--- /dev/null
@@ -0,0 +1,21 @@
+#ifndef THC_GENERIC_FILE
+#error "You must define THC_GENERIC_FILE before including THCGenerateBoolType.h"
+#endif
+
+// TODO: define accreal type once the correct value is known.
+#define scalar_t bool
+#define ureal bool
+#define Real Bool
+#define CReal CudaBool
+#define THC_REAL_IS_BOOL
+#line 1 THC_GENERIC_FILE
+#include THC_GENERIC_FILE
+#undef scalar_t
+#undef ureal
+#undef Real
+#undef CReal
+#undef THC_REAL_IS_BOOL
+
+#ifndef THCGenerateBoolType
+#undef THC_GENERIC_FILE
+#endif
index 6f93824..af71179 100644 (file)
@@ -8,6 +8,9 @@
 #include <THC/generic/THCStorage.cpp>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCStorage.cpp>
+#include <THC/THCGenerateBoolType.h>
+
 #include <c10/util/intrusive_ptr.h>
 
 void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size)
index 7c0fca2..01d5462 100644 (file)
@@ -11,3 +11,6 @@
 
 #include <THC/generic/THCStorage.cu>
 #include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCStorage.cu>
+#include <THC/THCGenerateBoolType.h>
index 0a1515e..19216ed 100644 (file)
@@ -9,4 +9,7 @@
 #include <THC/generic/THCStorage.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCStorage.h>
+#include <THC/THCGenerateBoolType.h>
+
 #endif
index c25ea12..2c15088 100644 (file)
@@ -5,3 +5,6 @@
 
 #include <THC/generic/THCStorageCopy.cpp>
 #include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCStorageCopy.cpp>
+#include <THC/THCGenerateBoolType.h>
index 2a695cf..9252e72 100644 (file)
@@ -8,3 +8,6 @@
 
 #include <THC/generic/THCStorageCopy.cu>
 #include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCStorageCopy.cu>
+#include <THC/THCGenerateBoolType.h>
index b153a92..db97194 100644 (file)
@@ -8,4 +8,7 @@
 #include <THC/generic/THCStorageCopy.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCStorageCopy.h>
+#include <THC/THCGenerateBoolType.h>
+
 #endif
index 47c30f2..8065d60 100644 (file)
@@ -7,6 +7,9 @@
 #include <THC/generic/THCTensor.cpp>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensor.cpp>
+#include <THC/THCGenerateBoolType.h>
+
 #include <THC/THCTensorInfo.cuh>
 
 #include <ATen/native/cuda/Resize.cuh>
@@ -61,6 +64,8 @@ THCTensor *THCTensor_new(THCState *state, caffe2::TypeMeta type_meta) {
       return THCudaTensor_new(state);
     case at::ScalarType::Double:
       return THCudaDoubleTensor_new(state);
+    case at::ScalarType::Bool:
+      return THCudaBoolTensor_new(state);
     default:
       AT_ERROR("unexpected ScalarType: ", toString(scalar_type));
   }
index e8f253d..cc25d14 100644 (file)
@@ -3,3 +3,6 @@
 
 #include <THC/generic/THCTensor.cu>
 #include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCTensor.cu>
+#include <THC/THCGenerateBoolType.h>
index c113c35..9670eb3 100644 (file)
@@ -17,4 +17,6 @@ typedef struct THC_CLASS THCDescBuff
 #include <THC/generic/THCTensor.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensor.h>
+#include <THC/THCGenerateBoolType.h>
 #endif
index eaa3295..3162506 100644 (file)
@@ -56,3 +56,6 @@ THC_API bool THCTensor_maybeOverlappingIndices(THCState* state, const THCTensor*
 
 #include <THC/generic/THCTensor.hpp>
 #include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCTensor.hpp>
+#include <THC/THCGenerateBoolType.h>
index 517d6fa..571d0e1 100644 (file)
@@ -5,16 +5,26 @@
 #include <type_traits>
 
 // Copy operator for the pointwise apply kernel
-template <typename TypeDst, typename TypeSrc>
+template <typename T>
 struct CopyOp {
-  __device__ __forceinline__ void operator()(TypeDst* dst, TypeSrc* src) {
+  __device__ __forceinline__ void operator()(T* dst, T* src) {
 #if __CUDA_ARCH__ >= 350
-    *dst = ScalarConvert<TypeSrc, TypeDst>::to(__ldg(src));
+    *dst = ScalarConvert<T, T>::to(__ldg(src));
 #else
-    *dst = ScalarConvert<TypeSrc, TypeDst>::to(*src);
+    *dst = ScalarConvert<T, T>::to(*src);
 #endif
   }
 };
 
+template <>
+struct CopyOp <bool> {
+  __device__ __forceinline__ void operator()(bool* dst, bool* src) {
+      *dst = ScalarConvert<bool, bool>::to(*src);
+  }
+};
+
 #include <THC/generic/THCTensorCopy.cu>
 #include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCTensorCopy.cu>
+#include <THC/THCGenerateBoolType.h>
index 55b7c51..ec8ede7 100644 (file)
@@ -9,4 +9,7 @@
 #include <THC/generic/THCTensorCopy.h>
 #include <THC/THCGenerateAllTypes.h>
 
+#include <THC/generic/THCTensorCopy.h>
+#include <THC/THCGenerateBoolType.h>
+
 #endif
index 7efb5fe..5fdf41d 100644 (file)
@@ -14,6 +14,7 @@
 #define THCudaShortStorage  THCStorage
 #define THCudaIntStorage    THCStorage
 #define THCudaLongStorage   THCStorage
+#define THCudaBoolStorage   THCStorage
 
 THC_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*);
 THC_API ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage*);
index cf68421..c132def 100644 (file)
@@ -33,6 +33,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPY(Long)
 TH_CUDA_STORAGE_IMPLEMENT_COPY(Float)
 TH_CUDA_STORAGE_IMPLEMENT_COPY(Half)
 TH_CUDA_STORAGE_IMPLEMENT_COPY(Double)
+TH_CUDA_STORAGE_IMPLEMENT_COPY(Bool)
 
 void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src)
 {
@@ -65,6 +66,7 @@ TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Long)
 TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Float)
 TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Half)
 TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Double)
+TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Bool)
 
 #undef TH_CUDA_STORAGE_IMPLEMENT_COPY
 #undef TH_CUDA_STORAGE_IMPLEMENT_COPYTO
index 01a1a6d..d372563 100644 (file)
@@ -28,6 +28,7 @@ THC_CUDA_STORAGE_IMPLEMENT_COPY(Long,Long)
 THC_CUDA_STORAGE_IMPLEMENT_COPY(Float,)  // i.e. float
 THC_CUDA_STORAGE_IMPLEMENT_COPY(Double,Double)
 THC_CUDA_STORAGE_IMPLEMENT_COPY(Half,Half)
+THC_CUDA_STORAGE_IMPLEMENT_COPY(Bool,Bool)
 
 #undef THC_CUDA_STORAGE_IMPLEMENT_COPY
 
index dc8f4c9..2375e18 100644 (file)
@@ -14,6 +14,7 @@ THC_API void THCStorage_(copyLong)(THCState *state, THCStorage *storage, struct
 THC_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct THFloatStorage *src);
 THC_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src);
 THC_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src);
+THC_API void THCStorage_(copyBool)(THCState *state, THCStorage *storage, struct THBoolStorage *src);
 
 THC_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src);
 THC_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src);
@@ -23,6 +24,7 @@ THC_API void THCStorage_(copyCudaLong)(THCState *state, THCStorage *storage, str
 THC_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, struct THCudaStorage *src);
 THC_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src);
 THC_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src);
+THC_API void THCStorage_(copyCudaBool)(THCState *state, THCStorage *storage, struct THCudaBoolStorage *src);
 
 THC_API void TH_CONCAT_2(THByteStorage_copyCuda  , Real)(THCState *state, THByteStorage *self, struct THCStorage *src);
 THC_API void TH_CONCAT_2(THCharStorage_copyCuda  , Real)(THCState *state, THCharStorage *self, struct THCStorage *src);
@@ -32,6 +34,7 @@ THC_API void TH_CONCAT_2(THLongStorage_copyCuda  , Real)(THCState *state, THLong
 THC_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloatStorage *self, struct THCStorage *src);
 THC_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src);
 THC_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src);
+THC_API void TH_CONCAT_2(THBoolStorage_copyCuda, Real)(THCState *state, THBoolStorage *self, struct THCStorage *src);
 
 THC_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src);
 THC_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src);
index 646cffa..76d1dd9 100644 (file)
@@ -14,6 +14,7 @@
 #define THCudaShortTensor THCTensor
 #define THCudaIntTensor THCTensor
 #define THCudaLongTensor THCTensor
+#define THCudaBoolTensor THCTensor
 
 /**** access methods ****/
 THC_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self);
index c972b0d..3e837e4 100644 (file)
@@ -52,7 +52,7 @@ void THCTensor_copyIgnoringOverlaps<scalar_t>(THCState* state, THCTensor* dst, T
   // FIXME: really, overlapping writes should be illegal/an error in Torch
   THC_pointwiseApply2<scalar_t, scalar_t>(
     state, dst, src,
-    CopyOp<scalar_t, scalar_t>(),
+    CopyOp<scalar_t>(),
     ReadOnly, /* ignore overwrites */
     ReadOnly);
 }
index 2b6aba4..e959e2e 100644 (file)
@@ -24,7 +24,8 @@ _(float,Float,d)   /* 6 */ \
 _(double,Double,d) /* 7 */ \
 _(at::ComplexHalf,ComplexHalf,z)        /* 8 */ \
 _(std::complex<float>,ComplexFloat,z)   /* 9 */ \
-_(std::complex<double>,ComplexDouble,z) /* 10 */
+_(std::complex<double>,ComplexDouble,z) /* 10 */ \
+_(bool,Bool,i) /* 11 */
 
 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.  But
@@ -185,6 +186,7 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   constexpr auto f2 = ScalarType::Half;
   constexpr auto f4 = ScalarType::Float;
   constexpr auto f8 = ScalarType::Double;
+  constexpr auto b1 = ScalarType::Bool;
   constexpr auto ud = ScalarType::Undefined;
   if (a == ud || b == ud) {
     return ScalarType::Undefined;
@@ -195,15 +197,16 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   static constexpr ScalarType _promoteTypesLookup
       [static_cast<int>(ScalarType::NumOptions)]
       [static_cast<int>(ScalarType::NumOptions)] = {
-            /* u1  i1  i2  i4  i8  f2  f4  f8 */
-    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8 },
-    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8 },
-    /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8 },
-    /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8 },
-    /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8 },
-    /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8 },
-    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8 },
-    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8 },
+            /* u1  i1  i2  i4  i8  f2  f4  f8  b1 */
+    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
+    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
+    /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
+    /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
+    /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
+    /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
+    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
+    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
+    /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
index a5aadf8..28494ac 100644 (file)
@@ -7800,6 +7800,7 @@ class _TestTorchMixin(object):
         long = torch.LongStorage().element_size()
         float = torch.FloatStorage().element_size()
         double = torch.DoubleStorage().element_size()
+        bool = torch.BoolStorage().element_size()
 
         self.assertEqual(byte, torch.ByteTensor().element_size())
         self.assertEqual(char, torch.CharTensor().element_size())
@@ -7816,10 +7817,12 @@ class _TestTorchMixin(object):
         self.assertGreater(long, 0)
         self.assertGreater(float, 0)
         self.assertGreater(double, 0)
+        self.assertGreater(bool, 0)
 
         # These tests are portable, not necessarily strict for your system.
         self.assertEqual(byte, 1)
         self.assertEqual(char, 1)
+        self.assertEqual(bool, 1)
         self.assertGreaterEqual(short, 2)
         self.assertGreaterEqual(int, 2)
         self.assertGreaterEqual(int, short)
@@ -8887,6 +8890,75 @@ class _TestTorchMixin(object):
         self.assertEqual(floats.size(), 1)
         self.assertEqual(floats[0], 2.25)
 
+        f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
+        bools = torch.BoolStorage.from_buffer(f, 'big')
+        self.assertEqual(bools.size(), 8)
+        self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True])
+        self.assertEqual(bools.type(), 'torch.BoolStorage')
+
+        f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
+        bools = torch.BoolStorage.from_buffer(f, 'big')
+        self.assertEqual(bools.size(), 19)
+
+        f = bytearray(b'\0x4A')
+        bools = torch.BoolStorage.from_buffer(f, 'big')
+        self.assertEqual(bools.size(), 4)
+        self.assertEqual(bools.tolist(), [False, True, True, True])
+
+    def test_storage_casts(self):
+        storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
+        self.assertEqual(storage.size(), 6)
+        self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertEqual(storage.type(), 'torch.IntStorage')
+
+        floatStorage = storage.float()
+        self.assertEqual(floatStorage.size(), 6)
+        self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
+        self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+        halfStorage = storage.half()
+        self.assertEqual(halfStorage.size(), 6)
+        self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
+        self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+        longStorage = storage.long()
+        self.assertEqual(longStorage.size(), 6)
+        self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertEqual(longStorage.type(), 'torch.LongStorage')
+        self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+        shortStorage = storage.short()
+        self.assertEqual(shortStorage.size(), 6)
+        self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
+        self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+        doubleStorage = storage.double()
+        self.assertEqual(doubleStorage.size(), 6)
+        self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
+        self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
+        self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+        charStorage = storage.char()
+        self.assertEqual(charStorage.size(), 6)
+        self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
+        self.assertEqual(charStorage.type(), 'torch.CharStorage')
+        self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+        byteStorage = storage.byte()
+        self.assertEqual(byteStorage.size(), 6)
+        self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
+        self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
+        self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
+
+        boolStorage = storage.bool()
+        self.assertEqual(boolStorage.size(), 6)
+        self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
+        self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
+        self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
+
     @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
     def test_from_file(self):
         size = 10000
@@ -8928,7 +9000,10 @@ class _TestTorchMixin(object):
         for t in torch._storage_classes:
             if t.is_cuda and not torch.cuda.is_available():
                 continue
-            obj = t(100).fill_(1)
+            if t == torch.BoolStorage or t == torch.cuda.BoolStorage:
+                obj = t(100).fill_(True)
+            else:
+                obj = t(100).fill_(1)
             obj.__repr__()
             str(obj)
 
index 321670b..3c456dd 100644 (file)
@@ -224,9 +224,12 @@ class ByteStorage(_C.ByteStorageBase, _StorageBase):
     pass
 
 
+class BoolStorage(_C.BoolStorageBase, _StorageBase):
+    pass
+
 _storage_classes = {
     DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
-    CharStorage, ByteStorage, HalfStorage
+    CharStorage, ByteStorage, HalfStorage, BoolStorage
 }
 
 # The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
@@ -275,6 +278,7 @@ del IntStorageBase
 del ShortStorageBase
 del CharStorageBase
 del ByteStorageBase
+del BoolStorageBase
 
 ################################################################################
 # Import most common subpackages
index bd82962..fd54032 100644 (file)
@@ -12,6 +12,7 @@ storage_classes = [
     'ShortStorageBase',
     'CharStorageBase',
     'ByteStorageBase',
+    'BoolStorageBase',
 ]
 
 
index 087f84d..ba217c9 100644 (file)
@@ -33,6 +33,7 @@ const std::unordered_map<std::string, at::ScalarType> attype_names = {
   {"Short", at::kShort},
   {"Int", at::kInt},
   {"Long", at::kLong},
+  {"Bool", at::kBool},
 };
 
 std::unordered_map<at::Type*, PyTypeObject*> attype_to_py_storage_type;
index cf193bb..0ddf9ac 100644 (file)
@@ -112,6 +112,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
   THPShortStorage_postInit(module);
   THPCharStorage_postInit(module);
   THPByteStorage_postInit(module);
+  THPBoolStorage_postInit(module);
   THPAutograd_initFunctions();
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
@@ -460,6 +461,7 @@ bool THCPIntStorage_init(PyObject *module);
 bool THCPShortStorage_init(PyObject *module);
 bool THCPCharStorage_init(PyObject *module);
 bool THCPByteStorage_init(PyObject *module);
+bool THCPBoolStorage_init(PyObject *module);
 
 void THCPStream_init(PyObject *module);
 void THCPEvent_init(PyObject *module);
@@ -490,6 +492,7 @@ bool THDPIntStorage_init(PyObject *module);
 bool THDPShortStorage_init(PyObject *module);
 bool THDPCharStorage_init(PyObject *module);
 bool THDPByteStorage_init(PyObject *module);
+bool THDPBoolStorage_init(PyObject *module);
 
 static std::vector<PyMethodDef> methods;
 
@@ -593,6 +596,7 @@ PyObject* initModule() {
   ASSERT_TRUE(THPShortStorage_init(module));
   ASSERT_TRUE(THPCharStorage_init(module));
   ASSERT_TRUE(THPByteStorage_init(module));
+  ASSERT_TRUE(THPBoolStorage_init(module));
 
 #ifdef USE_CUDA
   // This will only initialise base classes and attach them to library namespace
@@ -607,6 +611,7 @@ PyObject* initModule() {
   ASSERT_TRUE(THCPShortStorage_init(module));
   ASSERT_TRUE(THCPCharStorage_init(module));
   ASSERT_TRUE(THCPByteStorage_init(module));
+  ASSERT_TRUE(THCPBoolStorage_init(module));
 
   THCPStream_init(module);
   THCPEvent_init(module);
index 88afa0b..1d3d4c5 100644 (file)
@@ -23,6 +23,9 @@
 #include <torch/csrc/generic/Storage.cpp>
 #include <TH/THGenerateHalfType.h>
 
+#include <torch/csrc/generic/Storage.cpp>
+#include <TH/THGenerateBoolType.h>
+
 template<>
 void THPPointer<THStorage>::free() {
   if (ptr) {
index a8f7849..efd841f 100644 (file)
@@ -21,6 +21,8 @@
     PyObject_IsInstance(obj, THPCharStorageClass)
 #define THPByteStorage_Check(obj) \
     PyObject_IsInstance(obj, THPByteStorageClass)
+#define THPBoolStorage_Check(obj) \
+    PyObject_IsInstance(obj, THPBoolStorageClass)
 
 #define THPDoubleStorage_CData(obj)  (obj)->cdata
 #define THPFloatStorage_CData(obj)   (obj)->cdata
@@ -30,6 +32,7 @@
 #define THPShortStorage_CData(obj)   (obj)->cdata
 #define THPCharStorage_CData(obj)    (obj)->cdata
 #define THPByteStorage_CData(obj)    (obj)->cdata
+#define THPBoolStorage_CData(obj)    (obj)->cdata
 
 #ifdef _THP_CORE
 #define THPStorageType TH_CONCAT_3(THP,Real,StorageType)
@@ -42,4 +45,7 @@
 #include <torch/csrc/generic/Storage.h>
 #include <TH/THGenerateHalfType.h>
 
+#include <torch/csrc/generic/Storage.h>
+#include <TH/THGenerateBoolType.h>
+
 #endif
index 01f671f..0356791 100644 (file)
@@ -57,7 +57,7 @@ static inline void swapBytes64(void *ptr)
   uint64_t Byte5 = output & 0x0000FF0000000000;
   uint64_t Byte6 = output & 0x00FF000000000000;
   uint64_t Byte7 = output & 0xFF00000000000000;
-  output = (Byte0 << (7*8)) | (Byte1 << (5*8)) | (Byte2 << (3*8)) | (Byte3 << (1*8)) | 
+  output = (Byte0 << (7*8)) | (Byte1 << (5*8)) | (Byte2 << (3*8)) | (Byte3 << (1*8)) |
            (Byte7 >> (7*8)) | (Byte6 >> (5*8)) | (Byte5 >> (3*8)) | (Byte4 >> (1*8));
 #endif
   memcpy(ptr, &output, sizeof(uint64_t));
@@ -140,6 +140,13 @@ void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, s
   }
 }
 
+void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len)
+{
+  for (size_t i = 0; i < len; i++) {
+    dst[i] = (int)src[i] != 0 ? true : false;
+  }
+}
+
 void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len)
 {
   for (size_t i = 0; i < len; i++) {
index 0d0b80b..c9bb5a4 100644 (file)
@@ -18,6 +18,7 @@ void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order,
 void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, size_t len);
 void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len);
 void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len);
+void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len);
 
 void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len);
 void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len);
index 8d1571a..7571f11 100644 (file)
@@ -375,6 +375,7 @@ static PyObject * THCPModule_initExtension(PyObject *self)
   THCPShortStorage_postInit(m);
   THCPCharStorage_postInit(m);
   THCPByteStorage_postInit(m);
+  THCPBoolStorage_postInit(m);
 
   bool has_magma = at::hasMAGMA();
   if (has_magma) {
index 9ac5435..6a103a7 100644 (file)
@@ -15,3 +15,6 @@
 
 #define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
 #include <THC/THCGenerateAllTypes.h>
+
+#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
+#include <THC/THCGenerateBoolType.h>
index f3af761..19e62f1 100644 (file)
@@ -21,6 +21,8 @@
     PyObject_IsInstance(obj, THCPCharStorageClass)
 #define THCPByteStorage_Check(obj) \
     PyObject_IsInstance(obj, THCPByteStorageClass)
+#define THCPBoolStorage_Check(obj) \
+    PyObject_IsInstance(obj, THCPBoolStorageClass)
 
 #define THCPDoubleStorage_CData(obj)  (obj)->cdata
 #define THCPFloatStorage_CData(obj)   (obj)->cdata
@@ -29,6 +31,7 @@
 #define THCPShortStorage_CData(obj)   (obj)->cdata
 #define THCPCharStorage_CData(obj)    (obj)->cdata
 #define THCPByteStorage_CData(obj)    (obj)->cdata
+#define THCPBoolStorage_CData(obj)    (obj)->cdata
 
 #ifdef _THP_CORE
 #define THCPStorageType TH_CONCAT_3(THCP,Real,StorageType)
@@ -40,4 +43,7 @@
 #define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
 #include <THC/THCGenerateAllTypes.h>
 
+#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
+#include <THC/THCGenerateBoolType.h>
+
 #endif
index 0d1f86e..e83ea6e 100644 (file)
@@ -10,3 +10,5 @@
 #define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
 #include <THC/THCGenerateAllTypes.h>
 
+#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
+#include <THC/THCGenerateBoolType.h>
index 0779ac9..3e3eb2d 100644 (file)
@@ -6,4 +6,7 @@
 #define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
 #include <THC/THCGenerateAllTypes.h>
 
+#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
+#include <THC/THCGenerateBoolType.h>
+
 #endif
index bf3ec82..8e29803 100644 (file)
@@ -8,6 +8,9 @@
 #define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
 #include <THC/THCGenerateAllTypes.h>
 
+#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
+#include <THC/THCGenerateBoolType.h>
+
 #ifdef USE_CUDA
 // NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
 // whatever the current stream of the device the input is associated with was.
index 6dd50ed..209b453 100644 (file)
@@ -16,4 +16,6 @@
 #define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
 #include <THC/THCGenerateAllTypes.h>
 
+#define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
+#include <THC/THCGenerateBoolType.h>
 #endif
index e8fd9bb..bdd872e 100644 (file)
@@ -299,6 +299,7 @@ void THPStorage_(initCopyMethods)()
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
+  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBoolStorageType, h, &THWStorage_(copyBool));
 #ifdef THC_GENERIC_FILE
   // copy from GPU types
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
@@ -309,6 +310,7 @@ void THPStorage_(initCopyMethods)()
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
+  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, h, &THWStorage_(copyCudaBool));
   // add CPU <- GPU copies to base type
   /// #define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
   #define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
@@ -322,6 +324,7 @@ void THPStorage_(initCopyMethods)()
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
   THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
+  THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, b, &THCpuStorage_(copyCudaBool));
   #undef THCpuStorage
   #undef THCpuStorage_
 #endif
index fb9247e..f20bbaa 100644 (file)
@@ -96,7 +96,7 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
   PyObject *obj = nullptr;
   const char* byte_order_str = nullptr;
   Py_ssize_t count = -1, offset = 0;
-  Py_buffer buffer;
+  Py_buffer buffer = {};
   static char *kwlist[] = {"buffer", "byte_order", "count", "offset", nullptr};
   const char* argtypes;
 #if defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR)
@@ -160,6 +160,11 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
 
 #if defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR)
   memcpy(THWStorage_(data)(storage), src + offset, count);
+#elif defined(TH_REAL_IS_BOOL)
+  // Because of ASAN checks, that are failing in the THStorage.cpp whenever
+  // we are trying to get a value which is not 0 or 1, we have to manually
+  // convert original values to boolean ones.
+  THP_decodeBoolBuffer(THWStorage_(data)(storage), src + offset, byte_order, count);
 #elif defined(TH_REAL_IS_SHORT)
   THP_decodeInt16Buffer(THWStorage_(data)(storage), src + offset, byte_order, count);
 #elif defined(TH_REAL_IS_INT)
index 1b045df..7308df3 100644 (file)
@@ -182,3 +182,6 @@ void doWrite(io fildes, void* raw_buf, size_t nbytes) {
 
 #include <torch/csrc/generic/serialization.cpp>
 #include <TH/THGenerateHalfType.h>
+
+#include <torch/csrc/generic/serialization.cpp>
+#include <TH/THGenerateBoolType.h>
index 4602eb6..d9f0a8c 100644 (file)
@@ -7,6 +7,9 @@
 #include <torch/csrc/generic/serialization.h>
 #include <TH/THGenerateHalfType.h>
 
+#include <torch/csrc/generic/serialization.h>
+#include <TH/THGenerateBoolType.h>
+
 template <class io>
 void doRead(io fildes, void* buf, size_t nbytes);
 
index 05a7de5..78cd96f 100644 (file)
@@ -17,6 +17,9 @@
 #include <torch/csrc/generic/utils.cpp>
 #include <TH/THGenerateHalfType.h>
 
+#include <torch/csrc/generic/utils.cpp>
+#include <TH/THGenerateBoolType.h>
+
 int THPUtils_getCallable(PyObject *arg, PyObject **result) {
   if (!PyCallable_Check(arg))
     return 0;
index 7088517..0297b7d 100644 (file)
     (throw std::runtime_error("Could not parse real"), 0))
 #endif
 
+#define THPUtils_unpackReal_BOOL(object)                                       \
+    (PyBool_Check(object) ? object :                                           \
+    (throw std::runtime_error("Could not parse real"), Py_False))
+
+#define THPUtils_checkReal_BOOL(object)                                        \
+    PyBool_Check(object)
+
 #define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value)
 // TODO: handle int overflows for py2
 #define THPUtils_newReal_INT(value) PyInt_FromLong(value)
 
+#define THPUtils_newReal_BOOL(value) PyBool_FromLong(value)
+
 #define THPDoubleUtils_checkReal(object)      THPUtils_checkReal_FLOAT(object)
 #define THPDoubleUtils_unpackReal(object)     (double)THPUtils_unpackReal_FLOAT(object)
 #define THPDoubleUtils_newReal(value)         THPUtils_newReal_FLOAT(value)
 #define THPHalfUtils_newReal(value)           PyFloat_FromDouble(value)
 #define THPHalfUtils_newAccreal(value)        THPUtils_newReal_FLOAT(value)
 
+#define THPBoolUtils_checkReal(object)        THPUtils_checkReal_BOOL(object)
+#define THPBoolUtils_unpackReal(object)       THPUtils_unpackReal_BOOL(object)
+#define THPBoolUtils_newReal(value)           THPUtils_newReal_BOOL(value)
+#define THPBoolUtils_checkAccreal(object)     THPUtils_checkReal_BOOL(object)
+#define THPBoolUtils_unpackAccreal(object)    (int64_t)THPUtils_unpackReal_BOOL(object)
+#define THPBoolUtils_newAccreal(value)        THPUtils_newReal_BOOL(value)
 #define THPLongUtils_checkReal(object)        THPUtils_checkReal_INT(object)
 #define THPLongUtils_unpackReal(object)       (int64_t)THPUtils_unpackReal_INT(object)
 #define THPLongUtils_newReal(value)           THPUtils_newReal_INT(value)
@@ -123,6 +138,9 @@ struct THPUtils_typeTraits {};
 #include <torch/csrc/generic/utils.h>
 #include <TH/THGenerateHalfType.h>
 
+#include <torch/csrc/generic/utils.h>
+#include <TH/THGenerateBoolType.h>
+
 THLongStoragePtr THPUtils_unpackSize(PyObject *arg);
 bool THPUtils_tryUnpackLongs(PyObject *arg, THLongStoragePtr& result);
 std::vector<int64_t> THPUtils_unpackLongs(PyObject *arg);
index a59f691..de199d8 100644 (file)
@@ -36,6 +36,8 @@ static std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarTy
       return std::make_pair("complex64", "");
     case at::ScalarType::ComplexDouble:
       return std::make_pair("complex128", "");
+    case at::ScalarType::Bool:
+      return std::make_pair("bool", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
index 13db6f2..63bc75e 100644 (file)
@@ -551,7 +551,7 @@ def _dummy_type(name):
 
 if not hasattr(torch._C, 'CudaDoubleStorageBase'):
     # Define dummy base classes
-    for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half']:
+    for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool']:
         storage_name = 'Cuda{0}StorageBase'.format(t)
         tensor_name = 'Cuda{0}TensorBase'.format(t)
 
@@ -613,6 +613,9 @@ class HalfStorage(_CudaBase, torch._C.CudaHalfStorageBase, _StorageBase):
     pass
 
 
+class BoolStorage(_CudaBase, torch._C.CudaBoolStorageBase, _StorageBase):
+    pass
+
 torch._storage_classes.add(DoubleStorage)
 torch._storage_classes.add(FloatStorage)
 torch._storage_classes.add(LongStorage)
@@ -621,6 +624,7 @@ torch._storage_classes.add(ShortStorage)
 torch._storage_classes.add(CharStorage)
 torch._storage_classes.add(ByteStorage)
 torch._storage_classes.add(HalfStorage)
+torch._storage_classes.add(BoolStorage)
 
 from . import sparse
 from . import profiler
index 22f64a8..68caff8 100644 (file)
@@ -83,6 +83,10 @@ class _StorageBase(object):
         """Casts this storage to byte type"""
         return self.type(type(self).__module__ + '.ByteStorage')
 
+    def bool(self):
+        """Casts this storage to bool type"""
+        return self.type(type(self).__module__ + '.BoolStorage')
+
     def pin_memory(self):
         """Copies the storage to pinned memory, if it's not already pinned."""
         if self.is_cuda: