case ScalarType::Half:
dtype.code = DLDataTypeCode::kDLFloat;
break;
+ case ScalarType::Bool:
+ dtype.code = DLDataTypeCode::kDLUInt;
+ break;
case ScalarType::ComplexHalf:
throw std::logic_error("ComplexHalf is not supported by dlpack");
case ScalarType::ComplexFloat:
static inline void noop_deleter(void*) {}
enum class TypeID {
+ CPUBool,
CPUByte,
CPUChar,
CPUDouble,
CPULong,
CPUShort,
CPUHalf,
+ SparseCPUBool,
SparseCPUByte,
SparseCPUChar,
SparseCPUDouble,
SparseCPUInt,
SparseCPULong,
SparseCPUShort,
+ CUDABool,
CUDAByte,
CUDAChar,
CUDADouble,
CUDALong,
CUDAShort,
CUDAHalf,
+ SparseCUDABool,
SparseCUDAByte,
SparseCUDAChar,
SparseCUDADouble,
# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
scalar_types = [
+ ('Bool', 'uint8_t', 'BoolAccrealNotDefined', 'uint8_t', False),
('Byte', 'uint8_t', 'Long', 'uint8_t', False),
('Char', 'int8_t', 'Long', 'int8_t', False),
('Double', 'double', 'Double', 'double', True),
THFilePrivate.h
${CMAKE_CURRENT_BINARY_DIR}/THGeneral.h
THGenerateAllTypes.h
+ THGenerateBoolType.h
THGenerateDoubleType.h
THGenerateFloatType.h
THGenerateHalfType.h
TH_API size_t THFile_readLong(THFile *self, THLongStorage *storage);
TH_API size_t THFile_readFloat(THFile *self, THFloatStorage *storage);
TH_API size_t THFile_readDouble(THFile *self, THDoubleStorage *storage);
+TH_API size_t THFile_readBool(THFile *self, THBoolStorage *storage);
TH_API size_t THFile_writeByte(THFile *self, THByteStorage *storage);
TH_API size_t THFile_writeChar(THFile *self, THCharStorage *storage);
TH_API size_t THFile_writeLong(THFile *self, THLongStorage *storage);
TH_API size_t THFile_writeFloat(THFile *self, THFloatStorage *storage);
TH_API size_t THFile_writeDouble(THFile *self, THDoubleStorage *storage);
+TH_API size_t THFile_writeBool(THFile *self, THBoolStorage *storage);
/* raw */
TH_API size_t THFile_readByteRaw(THFile *self, uint8_t *data, size_t n);
--- /dev/null
+#ifndef TH_GENERIC_FILE
+#error "You must define TH_GENERIC_FILE before including THGenerateBoolType.h"
+#endif
+
+// TODO: define accreal type once the correct value is known.
+#define scalar_t bool
+#define ureal bool
+#define Real Bool
+#define TH_CONVERT_ACCREAL_TO_REAL(_val) (scalar_t)(_val)
+#define TH_REAL_IS_BOOL
+#line 1 TH_GENERIC_FILE
+#include TH_GENERIC_FILE
+#undef scalar_t
+#undef ureal
+#undef Real
+#undef TH_REAL_IS_BOOL
+#undef TH_CONVERT_REAL_TO_ACCREAL
+#undef TH_CONVERT_ACCREAL_TO_REAL
+
+#ifndef THGenerateManyTypes
+#undef TH_GENERIC_FILE
+#endif
#include <TH/generic/THStorage.cpp>
#include <TH/THGenerateHalfType.h>
+#include <TH/generic/THStorage.cpp>
+#include <TH/THGenerateBoolType.h>
+
#include <TH/generic/THStorageCopy.cpp>
#include <TH/THGenerateAllTypes.h>
#include <TH/generic/THStorageCopy.cpp>
#include <TH/THGenerateHalfType.h>
+#include <TH/generic/THStorageCopy.cpp>
+#include <TH/THGenerateBoolType.h>
+
THStorage* THStorage_new(caffe2::TypeMeta data_type) {
THStorage* storage = c10::make_intrusive<at::StorageImpl>(
data_type,
#include <TH/generic/THStorage.h>
#include <TH/THGenerateHalfType.h>
+#include <TH/generic/THStorage.h>
+#include <TH/THGenerateBoolType.h>
+
#include <TH/generic/THStorageCopy.h>
#include <TH/THGenerateAllTypes.h>
#include <TH/generic/THStorageCopy.h>
#include <TH/THGenerateHalfType.h>
+#include <TH/generic/THStorageCopy.h>
+#include <TH/THGenerateBoolType.h>
+
// This exists to have a data-type independent way of freeing (necessary for THPPointer).
TH_API void THStorage_free(THStorage *storage);
#include <TH/generic/THTensor.cpp>
#include <TH/THGenerateHalfType.h>
+#include <TH/generic/THTensor.cpp>
+#include <TH/THGenerateBoolType.h>
+
#include <ATen/native/Resize.h>
#include <numeric>
#include <TH/generic/THTensor.h>
#include <TH/THGenerateHalfType.h>
+#include <TH/generic/THTensor.h>
+#include <TH/THGenerateBoolType.h>
+
/* random numbers */
#include <TH/THRandom.h>
#include <TH/generic/THTensorRandom.h>
#include <TH/generic/THTensor.hpp>
#include <TH/THGenerateHalfType.h>
+
+#include <TH/generic/THTensor.hpp>
+#include <TH/THGenerateBoolType.h>
#define THShortStorage THStorage
#define THIntStorage THStorage
#define THLongStorage THStorage
+#define THBoolStorage THStorage
TH_API scalar_t* THStorage_(data)(const THStorage*);
TH_API ptrdiff_t THStorage_(size)(const THStorage*);
IMPLEMENT_THStorage_COPY(Float)
IMPLEMENT_THStorage_COPY(Double)
IMPLEMENT_THStorage_COPY(Half)
+IMPLEMENT_THStorage_COPY(Bool)
#endif
TH_API void THStorage_(copyFloat)(THStorage *storage, struct THFloatStorage *src);
TH_API void THStorage_(copyDouble)(THStorage *storage, struct THDoubleStorage *src);
TH_API void THStorage_(copyHalf)(THStorage *storage, struct THHalfStorage *src);
+TH_API void THStorage_(copyBool)(THStorage *storage, struct THBoolStorage *src);
#endif
#define THShortTensor THTensor
#define THIntTensor THTensor
#define THLongTensor THTensor
+#define THBoolTensor THTensor
/**** access methods ****/
TH_API THStorage* THTensor_(storage)(const THTensor *self);
THCDeviceTensorUtils.cuh
THCDeviceTensorUtils-inl.cuh
THCGenerateAllTypes.h
+ THCGenerateBoolType.h
THCGenerateByteType.h
THCGenerateCharType.h
THCGenerateShortType.h
--- /dev/null
+#ifndef THC_GENERIC_FILE
+#error "You must define THC_GENERIC_FILE before including THCGenerateBoolType.h"
+#endif
+
+// TODO: define accreal type once the correct value is known.
+#define scalar_t bool
+#define ureal bool
+#define Real Bool
+#define CReal CudaBool
+#define THC_REAL_IS_BOOL
+#line 1 THC_GENERIC_FILE
+#include THC_GENERIC_FILE
+#undef scalar_t
+#undef ureal
+#undef Real
+#undef CReal
+#undef THC_REAL_IS_BOOL
+
+#ifndef THCGenerateBoolType
+#undef THC_GENERIC_FILE
+#endif
#include <THC/generic/THCStorage.cpp>
#include <THC/THCGenerateAllTypes.h>
+#include <THC/generic/THCStorage.cpp>
+#include <THC/THCGenerateBoolType.h>
+
#include <c10/util/intrusive_ptr.h>
void THCStorage_resize(THCState *state, THCStorage *self, ptrdiff_t size)
#include <THC/generic/THCStorage.cu>
#include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCStorage.cu>
+#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorage.h>
#include <THC/THCGenerateAllTypes.h>
+#include <THC/generic/THCStorage.h>
+#include <THC/THCGenerateBoolType.h>
+
#endif
#include <THC/generic/THCStorageCopy.cpp>
#include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCStorageCopy.cpp>
+#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorageCopy.cu>
#include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCStorageCopy.cu>
+#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCStorageCopy.h>
#include <THC/THCGenerateAllTypes.h>
+#include <THC/generic/THCStorageCopy.h>
+#include <THC/THCGenerateBoolType.h>
+
#endif
#include <THC/generic/THCTensor.cpp>
#include <THC/THCGenerateAllTypes.h>
+#include <THC/generic/THCTensor.cpp>
+#include <THC/THCGenerateBoolType.h>
+
#include <THC/THCTensorInfo.cuh>
#include <ATen/native/cuda/Resize.cuh>
return THCudaTensor_new(state);
case at::ScalarType::Double:
return THCudaDoubleTensor_new(state);
+ case at::ScalarType::Bool:
+ return THCudaBoolTensor_new(state);
default:
AT_ERROR("unexpected ScalarType: ", toString(scalar_type));
}
#include <THC/generic/THCTensor.cu>
#include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCTensor.cu>
+#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensor.h>
#include <THC/THCGenerateAllTypes.h>
+#include <THC/generic/THCTensor.h>
+#include <THC/THCGenerateBoolType.h>
#endif
#include <THC/generic/THCTensor.hpp>
#include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCTensor.hpp>
+#include <THC/THCGenerateBoolType.h>
#include <type_traits>
// Copy operator for the pointwise apply kernel
-template <typename TypeDst, typename TypeSrc>
+template <typename T>
struct CopyOp {
- __device__ __forceinline__ void operator()(TypeDst* dst, TypeSrc* src) {
+ __device__ __forceinline__ void operator()(T* dst, T* src) {
#if __CUDA_ARCH__ >= 350
- *dst = ScalarConvert<TypeSrc, TypeDst>::to(__ldg(src));
+ *dst = ScalarConvert<T, T>::to(__ldg(src));
#else
- *dst = ScalarConvert<TypeSrc, TypeDst>::to(*src);
+ *dst = ScalarConvert<T, T>::to(*src);
#endif
}
};
+template <>
+struct CopyOp <bool> {
+ __device__ __forceinline__ void operator()(bool* dst, bool* src) {
+ *dst = ScalarConvert<bool, bool>::to(*src);
+ }
+};
+
#include <THC/generic/THCTensorCopy.cu>
#include <THC/THCGenerateAllTypes.h>
+
+#include <THC/generic/THCTensorCopy.cu>
+#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensorCopy.h>
#include <THC/THCGenerateAllTypes.h>
+#include <THC/generic/THCTensorCopy.h>
+#include <THC/THCGenerateBoolType.h>
+
#endif
#define THCudaShortStorage THCStorage
#define THCudaIntStorage THCStorage
#define THCudaLongStorage THCStorage
+#define THCudaBoolStorage THCStorage
THC_API scalar_t* THCStorage_(data)(THCState *state, const THCStorage*);
THC_API ptrdiff_t THCStorage_(size)(THCState *state, const THCStorage*);
TH_CUDA_STORAGE_IMPLEMENT_COPY(Float)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Half)
TH_CUDA_STORAGE_IMPLEMENT_COPY(Double)
+TH_CUDA_STORAGE_IMPLEMENT_COPY(Bool)
void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *src)
{
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Float)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Half)
TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Double)
+TH_CUDA_STORAGE_IMPLEMENT_COPYTO(Bool)
#undef TH_CUDA_STORAGE_IMPLEMENT_COPY
#undef TH_CUDA_STORAGE_IMPLEMENT_COPYTO
THC_CUDA_STORAGE_IMPLEMENT_COPY(Float,) // i.e. float
THC_CUDA_STORAGE_IMPLEMENT_COPY(Double,Double)
THC_CUDA_STORAGE_IMPLEMENT_COPY(Half,Half)
+THC_CUDA_STORAGE_IMPLEMENT_COPY(Bool,Bool)
#undef THC_CUDA_STORAGE_IMPLEMENT_COPY
THC_API void THCStorage_(copyFloat)(THCState *state, THCStorage *storage, struct THFloatStorage *src);
THC_API void THCStorage_(copyDouble)(THCState *state, THCStorage *storage, struct THDoubleStorage *src);
THC_API void THCStorage_(copyHalf)(THCState *state, THCStorage *storage, struct THHalfStorage *src);
+THC_API void THCStorage_(copyBool)(THCState *state, THCStorage *storage, struct THBoolStorage *src);
THC_API void THCStorage_(copyCudaByte)(THCState *state, THCStorage *storage, struct THCudaByteStorage *src);
THC_API void THCStorage_(copyCudaChar)(THCState *state, THCStorage *storage, struct THCudaCharStorage *src);
THC_API void THCStorage_(copyCudaFloat)(THCState *state, THCStorage *storage, struct THCudaStorage *src);
THC_API void THCStorage_(copyCudaDouble)(THCState *state, THCStorage *storage, struct THCudaDoubleStorage *src);
THC_API void THCStorage_(copyCudaHalf)(THCState *state, THCStorage *storage, struct THCudaHalfStorage *src);
+THC_API void THCStorage_(copyCudaBool)(THCState *state, THCStorage *storage, struct THCudaBoolStorage *src);
THC_API void TH_CONCAT_2(THByteStorage_copyCuda , Real)(THCState *state, THByteStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THCharStorage_copyCuda , Real)(THCState *state, THCharStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THFloatStorage_copyCuda , Real)(THCState *state, THFloatStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THDoubleStorage_copyCuda, Real)(THCState *state, THDoubleStorage *self, struct THCStorage *src);
THC_API void TH_CONCAT_2(THHalfStorage_copyCuda, Real)(THCState *state, THHalfStorage *self, struct THCStorage *src);
+THC_API void TH_CONCAT_2(THBoolStorage_copyCuda, Real)(THCState *state, THBoolStorage *self, struct THCStorage *src);
THC_API void THStorage_(copyCuda)(THCState *state, THStorage *self, THCStorage *src);
THC_API void THCStorage_(copyCuda)(THCState *state, THCStorage *self, THCStorage *src);
#define THCudaShortTensor THCTensor
#define THCudaIntTensor THCTensor
#define THCudaLongTensor THCTensor
+#define THCudaBoolTensor THCTensor
/**** access methods ****/
THC_API THCStorage* THCTensor_(storage)(THCState *state, const THCTensor *self);
// FIXME: really, overlapping writes should be illegal/an error in Torch
THC_pointwiseApply2<scalar_t, scalar_t>(
state, dst, src,
- CopyOp<scalar_t, scalar_t>(),
+ CopyOp<scalar_t>(),
ReadOnly, /* ignore overwrites */
ReadOnly);
}
_(double,Double,d) /* 7 */ \
_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
_(std::complex<float>,ComplexFloat,z) /* 9 */ \
-_(std::complex<double>,ComplexDouble,z) /* 10 */
+_(std::complex<double>,ComplexDouble,z) /* 10 */ \
+_(bool,Bool,i) /* 11 */
// If you want to support ComplexHalf for real, replace occurrences
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
constexpr auto f2 = ScalarType::Half;
constexpr auto f4 = ScalarType::Float;
constexpr auto f8 = ScalarType::Double;
+ constexpr auto b1 = ScalarType::Bool;
constexpr auto ud = ScalarType::Undefined;
if (a == ud || b == ud) {
return ScalarType::Undefined;
static constexpr ScalarType _promoteTypesLookup
[static_cast<int>(ScalarType::NumOptions)]
[static_cast<int>(ScalarType::NumOptions)] = {
- /* u1 i1 i2 i4 i8 f2 f4 f8 */
- /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8 },
- /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8 },
- /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8 },
- /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8 },
- /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8 },
- /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8 },
- /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8 },
- /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8 },
+ /* u1 i1 i2 i4 i8 f2 f4 f8 b1 */
+ /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
+ /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
+ /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
+ /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
+ /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
+ /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
+ /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
+ /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
+ /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}
long = torch.LongStorage().element_size()
float = torch.FloatStorage().element_size()
double = torch.DoubleStorage().element_size()
+ bool = torch.BoolStorage().element_size()
self.assertEqual(byte, torch.ByteTensor().element_size())
self.assertEqual(char, torch.CharTensor().element_size())
self.assertGreater(long, 0)
self.assertGreater(float, 0)
self.assertGreater(double, 0)
+ self.assertGreater(bool, 0)
# These tests are portable, not necessarily strict for your system.
self.assertEqual(byte, 1)
self.assertEqual(char, 1)
+ self.assertEqual(bool, 1)
self.assertGreaterEqual(short, 2)
self.assertGreaterEqual(int, 2)
self.assertGreaterEqual(int, short)
self.assertEqual(floats.size(), 1)
self.assertEqual(floats[0], 2.25)
+ f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
+ bools = torch.BoolStorage.from_buffer(f, 'big')
+ self.assertEqual(bools.size(), 8)
+ self.assertEqual(bools.tolist(), [False, True, True, True, True, True, True, True])
+ self.assertEqual(bools.type(), 'torch.BoolStorage')
+
+ f = bytearray(b'\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9')
+ bools = torch.BoolStorage.from_buffer(f, 'big')
+ self.assertEqual(bools.size(), 19)
+
+ f = bytearray(b'\0x4A')
+ bools = torch.BoolStorage.from_buffer(f, 'big')
+ self.assertEqual(bools.size(), 4)
+ self.assertEqual(bools.tolist(), [False, True, True, True])
+
+ def test_storage_casts(self):
+ storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
+ self.assertEqual(storage.size(), 6)
+ self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertEqual(storage.type(), 'torch.IntStorage')
+
+ floatStorage = storage.float()
+ self.assertEqual(floatStorage.size(), 6)
+ self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
+ self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+ halfStorage = storage.half()
+ self.assertEqual(halfStorage.size(), 6)
+ self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
+ self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+ longStorage = storage.long()
+ self.assertEqual(longStorage.size(), 6)
+ self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertEqual(longStorage.type(), 'torch.LongStorage')
+ self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+ shortStorage = storage.short()
+ self.assertEqual(shortStorage.size(), 6)
+ self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
+ self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+ doubleStorage = storage.double()
+ self.assertEqual(doubleStorage.size(), 6)
+ self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
+ self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
+ self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+ charStorage = storage.char()
+ self.assertEqual(charStorage.size(), 6)
+ self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
+ self.assertEqual(charStorage.type(), 'torch.CharStorage')
+ self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+
+ byteStorage = storage.byte()
+ self.assertEqual(byteStorage.size(), 6)
+ self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
+ self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
+ self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
+
+ boolStorage = storage.bool()
+ self.assertEqual(boolStorage.size(), 6)
+ self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
+ self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
+ self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
+
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
def test_from_file(self):
size = 10000
for t in torch._storage_classes:
if t.is_cuda and not torch.cuda.is_available():
continue
- obj = t(100).fill_(1)
+ if t == torch.BoolStorage or t == torch.cuda.BoolStorage:
+ obj = t(100).fill_(True)
+ else:
+ obj = t(100).fill_(1)
obj.__repr__()
str(obj)
pass
+class BoolStorage(_C.BoolStorageBase, _StorageBase):
+ pass
+
_storage_classes = {
DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
- CharStorage, ByteStorage, HalfStorage
+ CharStorage, ByteStorage, HalfStorage, BoolStorage
}
# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
del ShortStorageBase
del CharStorageBase
del ByteStorageBase
+del BoolStorageBase
################################################################################
# Import most common subpackages
'ShortStorageBase',
'CharStorageBase',
'ByteStorageBase',
+ 'BoolStorageBase',
]
{"Short", at::kShort},
{"Int", at::kInt},
{"Long", at::kLong},
+ {"Bool", at::kBool},
};
std::unordered_map<at::Type*, PyTypeObject*> attype_to_py_storage_type;
THPShortStorage_postInit(module);
THPCharStorage_postInit(module);
THPByteStorage_postInit(module);
+ THPBoolStorage_postInit(module);
THPAutograd_initFunctions();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
bool THCPShortStorage_init(PyObject *module);
bool THCPCharStorage_init(PyObject *module);
bool THCPByteStorage_init(PyObject *module);
+bool THCPBoolStorage_init(PyObject *module);
void THCPStream_init(PyObject *module);
void THCPEvent_init(PyObject *module);
bool THDPShortStorage_init(PyObject *module);
bool THDPCharStorage_init(PyObject *module);
bool THDPByteStorage_init(PyObject *module);
+bool THDPBoolStorage_init(PyObject *module);
static std::vector<PyMethodDef> methods;
ASSERT_TRUE(THPShortStorage_init(module));
ASSERT_TRUE(THPCharStorage_init(module));
ASSERT_TRUE(THPByteStorage_init(module));
+ ASSERT_TRUE(THPBoolStorage_init(module));
#ifdef USE_CUDA
// This will only initialise base classes and attach them to library namespace
ASSERT_TRUE(THCPShortStorage_init(module));
ASSERT_TRUE(THCPCharStorage_init(module));
ASSERT_TRUE(THCPByteStorage_init(module));
+ ASSERT_TRUE(THCPBoolStorage_init(module));
THCPStream_init(module);
THCPEvent_init(module);
#include <torch/csrc/generic/Storage.cpp>
#include <TH/THGenerateHalfType.h>
+#include <torch/csrc/generic/Storage.cpp>
+#include <TH/THGenerateBoolType.h>
+
template<>
void THPPointer<THStorage>::free() {
if (ptr) {
PyObject_IsInstance(obj, THPCharStorageClass)
#define THPByteStorage_Check(obj) \
PyObject_IsInstance(obj, THPByteStorageClass)
+#define THPBoolStorage_Check(obj) \
+ PyObject_IsInstance(obj, THPBoolStorageClass)
#define THPDoubleStorage_CData(obj) (obj)->cdata
#define THPFloatStorage_CData(obj) (obj)->cdata
#define THPShortStorage_CData(obj) (obj)->cdata
#define THPCharStorage_CData(obj) (obj)->cdata
#define THPByteStorage_CData(obj) (obj)->cdata
+#define THPBoolStorage_CData(obj) (obj)->cdata
#ifdef _THP_CORE
#define THPStorageType TH_CONCAT_3(THP,Real,StorageType)
#include <torch/csrc/generic/Storage.h>
#include <TH/THGenerateHalfType.h>
+#include <torch/csrc/generic/Storage.h>
+#include <TH/THGenerateBoolType.h>
+
#endif
uint64_t Byte5 = output & 0x0000FF0000000000;
uint64_t Byte6 = output & 0x00FF000000000000;
uint64_t Byte7 = output & 0xFF00000000000000;
- output = (Byte0 << (7*8)) | (Byte1 << (5*8)) | (Byte2 << (3*8)) | (Byte3 << (1*8)) |
+ output = (Byte0 << (7*8)) | (Byte1 << (5*8)) | (Byte2 << (3*8)) | (Byte3 << (1*8)) |
(Byte7 >> (7*8)) | (Byte6 >> (5*8)) | (Byte5 >> (3*8)) | (Byte4 >> (1*8));
#endif
memcpy(ptr, &output, sizeof(uint64_t));
}
}
+void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len)
+{
+ for (size_t i = 0; i < len; i++) {
+ dst[i] = (int)src[i] != 0 ? true : false;
+ }
+}
+
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len)
{
for (size_t i = 0; i < len; i++) {
void THP_decodeHalfBuffer(THHalf* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len);
+void THP_decodeBoolBuffer(bool* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len);
void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len);
THCPShortStorage_postInit(m);
THCPCharStorage_postInit(m);
THCPByteStorage_postInit(m);
+ THCPBoolStorage_postInit(m);
bool has_magma = at::hasMAGMA();
if (has_magma) {
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#include <THC/THCGenerateAllTypes.h>
+
+#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
+#include <THC/THCGenerateBoolType.h>
PyObject_IsInstance(obj, THCPCharStorageClass)
#define THCPByteStorage_Check(obj) \
PyObject_IsInstance(obj, THCPByteStorageClass)
+#define THCPBoolStorage_Check(obj) \
+ PyObject_IsInstance(obj, THCPBoolStorageClass)
#define THCPDoubleStorage_CData(obj) (obj)->cdata
#define THCPFloatStorage_CData(obj) (obj)->cdata
#define THCPShortStorage_CData(obj) (obj)->cdata
#define THCPCharStorage_CData(obj) (obj)->cdata
#define THCPByteStorage_CData(obj) (obj)->cdata
+#define THCPBoolStorage_CData(obj) (obj)->cdata
#ifdef _THP_CORE
#define THCPStorageType TH_CONCAT_3(THCP,Real,StorageType)
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
#include <THC/THCGenerateAllTypes.h>
+#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
+#include <THC/THCGenerateBoolType.h>
+
#endif
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
#include <THC/THCGenerateAllTypes.h>
+#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
+#include <THC/THCGenerateBoolType.h>
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
#include <THC/THCGenerateAllTypes.h>
+#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
+#include <THC/THCGenerateBoolType.h>
+
#endif
#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
#include <THC/THCGenerateAllTypes.h>
+#define THC_GENERIC_FILE "torch/csrc/generic/utils.cpp"
+#include <THC/THCGenerateBoolType.h>
+
#ifdef USE_CUDA
// NB: It's a list of *optional* CUDAStream; when nullopt, that means to use
// whatever the current stream of the device the input is associated with was.
#define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
#include <THC/THCGenerateAllTypes.h>
+#define THC_GENERIC_FILE "torch/csrc/generic/utils.h"
+#include <THC/THCGenerateBoolType.h>
#endif
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPHalfStorageType, h, &THWStorage_(copyHalf));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPFloatStorageType, h, &THWStorage_(copyFloat));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPDoubleStorageType, h, &THWStorage_(copyDouble));
+ THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THPBoolStorageType, h, &THWStorage_(copyBool));
#ifdef THC_GENERIC_FILE
// copy from GPU types
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPByteStorageType, h, &THWStorage_(copyCudaByte));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, h, &THWStorage_(copyCudaFloat));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, h, &THWStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, h, &THWStorage_(copyCudaHalf));
+ THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, h, &THWStorage_(copyCudaBool));
// add CPU <- GPU copies to base type
/// #define THPCpuStorage TH_CONCAT_3(THP, Real, Storage)
#define THCpuStorage_(name) TH_CONCAT_4(TH, Real, Storage_, name)
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPFloatStorageType, b, &THCpuStorage_(copyCudaFloat));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPDoubleStorageType, b, &THCpuStorage_(copyCudaDouble));
THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPHalfStorageType, b, &THCpuStorage_(copyCudaHalf));
+ THPInsertStorageCopyFunction<THPStorage, THPStorage>(&THCPBoolStorageType, b, &THCpuStorage_(copyCudaBool));
#undef THCpuStorage
#undef THCpuStorage_
#endif
PyObject *obj = nullptr;
const char* byte_order_str = nullptr;
Py_ssize_t count = -1, offset = 0;
- Py_buffer buffer;
+ Py_buffer buffer = {};
static char *kwlist[] = {"buffer", "byte_order", "count", "offset", nullptr};
const char* argtypes;
#if defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR)
#if defined(TH_REAL_IS_BYTE) || defined(TH_REAL_IS_CHAR)
memcpy(THWStorage_(data)(storage), src + offset, count);
+#elif defined(TH_REAL_IS_BOOL)
+ // Because of ASAN checks, that are failing in the THStorage.cpp whenever
+ // we are trying to get a value which is not 0 or 1, we have to manually
+ // convert original values to boolean ones.
+ THP_decodeBoolBuffer(THWStorage_(data)(storage), src + offset, byte_order, count);
#elif defined(TH_REAL_IS_SHORT)
THP_decodeInt16Buffer(THWStorage_(data)(storage), src + offset, byte_order, count);
#elif defined(TH_REAL_IS_INT)
#include <torch/csrc/generic/serialization.cpp>
#include <TH/THGenerateHalfType.h>
+
+#include <torch/csrc/generic/serialization.cpp>
+#include <TH/THGenerateBoolType.h>
#include <torch/csrc/generic/serialization.h>
#include <TH/THGenerateHalfType.h>
+#include <torch/csrc/generic/serialization.h>
+#include <TH/THGenerateBoolType.h>
+
template <class io>
void doRead(io fildes, void* buf, size_t nbytes);
#include <torch/csrc/generic/utils.cpp>
#include <TH/THGenerateHalfType.h>
+#include <torch/csrc/generic/utils.cpp>
+#include <TH/THGenerateBoolType.h>
+
int THPUtils_getCallable(PyObject *arg, PyObject **result) {
if (!PyCallable_Check(arg))
return 0;
(throw std::runtime_error("Could not parse real"), 0))
#endif
+#define THPUtils_unpackReal_BOOL(object) \
+ (PyBool_Check(object) ? object : \
+ (throw std::runtime_error("Could not parse real"), Py_False))
+
+#define THPUtils_checkReal_BOOL(object) \
+ PyBool_Check(object)
+
#define THPUtils_newReal_FLOAT(value) PyFloat_FromDouble(value)
// TODO: handle int overflows for py2
#define THPUtils_newReal_INT(value) PyInt_FromLong(value)
+#define THPUtils_newReal_BOOL(value) PyBool_FromLong(value)
+
#define THPDoubleUtils_checkReal(object) THPUtils_checkReal_FLOAT(object)
#define THPDoubleUtils_unpackReal(object) (double)THPUtils_unpackReal_FLOAT(object)
#define THPDoubleUtils_newReal(value) THPUtils_newReal_FLOAT(value)
#define THPHalfUtils_newReal(value) PyFloat_FromDouble(value)
#define THPHalfUtils_newAccreal(value) THPUtils_newReal_FLOAT(value)
+#define THPBoolUtils_checkReal(object) THPUtils_checkReal_BOOL(object)
+#define THPBoolUtils_unpackReal(object) THPUtils_unpackReal_BOOL(object)
+#define THPBoolUtils_newReal(value) THPUtils_newReal_BOOL(value)
+#define THPBoolUtils_checkAccreal(object) THPUtils_checkReal_BOOL(object)
+#define THPBoolUtils_unpackAccreal(object) (int64_t)THPUtils_unpackReal_BOOL(object)
+#define THPBoolUtils_newAccreal(value) THPUtils_newReal_BOOL(value)
#define THPLongUtils_checkReal(object) THPUtils_checkReal_INT(object)
#define THPLongUtils_unpackReal(object) (int64_t)THPUtils_unpackReal_INT(object)
#define THPLongUtils_newReal(value) THPUtils_newReal_INT(value)
#include <torch/csrc/generic/utils.h>
#include <TH/THGenerateHalfType.h>
+#include <torch/csrc/generic/utils.h>
+#include <TH/THGenerateBoolType.h>
+
THLongStoragePtr THPUtils_unpackSize(PyObject *arg);
bool THPUtils_tryUnpackLongs(PyObject *arg, THLongStoragePtr& result);
std::vector<int64_t> THPUtils_unpackLongs(PyObject *arg);
return std::make_pair("complex64", "");
case at::ScalarType::ComplexDouble:
return std::make_pair("complex128", "");
+ case at::ScalarType::Bool:
+ return std::make_pair("bool", "");
default:
throw std::runtime_error("Unimplemented scalar type");
}
if not hasattr(torch._C, 'CudaDoubleStorageBase'):
# Define dummy base classes
- for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half']:
+ for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool']:
storage_name = 'Cuda{0}StorageBase'.format(t)
tensor_name = 'Cuda{0}TensorBase'.format(t)
pass
+class BoolStorage(_CudaBase, torch._C.CudaBoolStorageBase, _StorageBase):
+ pass
+
torch._storage_classes.add(DoubleStorage)
torch._storage_classes.add(FloatStorage)
torch._storage_classes.add(LongStorage)
torch._storage_classes.add(CharStorage)
torch._storage_classes.add(ByteStorage)
torch._storage_classes.add(HalfStorage)
+torch._storage_classes.add(BoolStorage)
from . import sparse
from . import profiler
"""Casts this storage to byte type"""
return self.type(type(self).__module__ + '.ByteStorage')
+ def bool(self):
+ """Casts this storage to bool type"""
+ return self.type(type(self).__module__ + '.BoolStorage')
+
def pin_memory(self):
"""Copies the storage to pinned memory, if it's not already pinned."""
if self.is_cuda: