From: Iurii Zdebskyi
Date: Wed, 3 Apr 2019 17:53:11 +0000 (-0700)
Subject: Added indexing for bool tensors and bool Indices (#18583)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~447
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5950c1e8c419cb3f51c0222def269ad94a564ff4;p=platform%2Fupstream%2Fpytorch.git

Added indexing for bool tensors and bool Indices (#18583)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18583
ghimport-source-id: 2b1941449827f4ab632fa0f5c8cf0791a6be0845

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18583 Added indexing for bool tensors and bool Indices**
* #18505 Added numpy conversion
* #18166 Bool Tensor for CUDA

-----------

This PR enables bool tensor indexing and indexing with bool indices. It is part of the Bool Tensor feature implementation work. The whole plan looks like this:
1. Storage Implementation [Done]
2. Tensor Creation.
   a) CPU [Done]
   b) CUDA [In review]
3. Tensor Conversions. [In review]
4. Tensor Indexing. [This PR]
5. Tensor Operations.
6. Backward compatibility related changes.

TODO: as a follow-up, we should move the nonzero method from TH to ATen to make the code cleaner.

Change:
```
v = torch.tensor([True, False, True], dtype=torch.bool)
boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
v[boolIndices]
-> tensor([True], dtype=torch.bool)

v = torch.randn(5, 7, 3)
boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
v[boolIndices]
-> tensor([[[ 0.5885, -0.3322,  0.7388],
            [ 1.1182,  0.7808, -1.1492],
            [-0.7952,  0.5255, -0.0251],
            [ 0.7128,  0.8099,  1.2689],
            [-0.7018, -1.4733, -0.3732],
            [ 0.4503,  0.4986, -1.1605],
            [ 0.3348, -1.3767, -0.2976]],

           [[-2.0303, -0.4720, -0.1448],
            [-0.1914, -0.6821,  2.0061],
            [-1.0420, -0.1872, -0.3438],
            [ 1.7587, -0.4183, -0.7577],
            [ 1.0094, -0.1950, -0.2430],
            [ 0.1174,  0.3308, -0.5700],
            [ 0.1110, -0.2714,  1.3006]],

           [[-0.1946, -1.4747, -0.4650],
            [-1.0567,  1.0110, -0.2809],
            [ 0.3729, -0.5699,  0.0815],
            [-0.7733, -0.8316,  0.1674],
            [ 1.2000, -0.3745, -1.1679],
            [ 1.7105,  0.9851, -0.1907],
            [-1.1077,  0.2086, -0.0548]]])
```

Differential Revision: D14673403

fbshipit-source-id: 2b88ec2c7eb26a4f5ef64f8707fb68068d476fc9
---
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index c438886..a1bf2b3 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -111,6 +111,8 @@
 [[
   name: _th_nonzero
   cname: nonzero
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
diff --git a/aten/src/ATen/native/Indexing.cpp b/aten/src/ATen/native/Indexing.cpp
index c7d478c..34851f0 100644
--- a/aten/src/ATen/native/Indexing.cpp
+++ b/aten/src/ATen/native/Indexing.cpp
@@ -5,7 +5,7 @@
 // index(Tensor self, indices) -> Tensor
 // index_put_(Tensor self, indices, value, accumulate=false)
 //
-// The index is a TensorList containg kLong or kByte tensors or nulls. Byte
+// The index is a TensorList containg kLong, kBool or kByte tensors or nulls. Byte
 // tensors (boolean masks) are expanded to long tensors via nonzero(). Null
 // tensors signify that the dimension is not indexed.
 //
@@ -79,19 +79,19 @@ static void checkIndexTensorTypes(TensorList indices) {
   for (auto& tensor : indices) {
     if (tensor.defined()) {
       auto scalarType = tensor.scalar_type();
-      if (scalarType != kLong && scalarType != kByte) {
-        AT_INDEX_ERROR("tensors used as indices must be long or byte tensors");
+      if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
+        AT_INDEX_ERROR("tensors used as indices must be long, byte or bool tensors");
       }
     }
   }
 }
 
-static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList indices) {
-  // Expands byte tensors (masks) into the equivalent indexing by LongTensors
+static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
+  // Expands ByteTensor (masks) or BoolTensor (masks) into the equivalent indexing by LongTensors
   std::vector<Tensor> result;
   for (auto & index : indices) {
-    if (index.scalar_type() == kByte) {
-      // The sizes of the ByteTensor mask must match the sizes of the
+    if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+      // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
       // corresponding dimensions in self
       for (int64_t j = 0; j < index.dim(); j++) {
         int64_t srcIdx = result.size() + j;
@@ -244,8 +244,8 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
 
 static std::tuple makeLinearIndex(Tensor self, TensorList orig) {
   checkIndexTensorTypes(orig);
-  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
-  auto indices = expandByteTensors(self, orig);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
+  auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together
   indices = expand_outplace(indices);
   // add missing null Tensors so that it matches self.dim()
@@ -378,8 +378,8 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
 
 static AdvancedIndex make_info(Tensor self, TensorList orig) {
   checkIndexTensorTypes(orig);
-  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
-  auto indices = expandByteTensors(self, orig);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
+  auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together
   try {
     indices = expand_outplace(indices);
diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp
index 698b40b..e11ce20 100644
--- a/aten/src/ATen/native/cpu/IndexKernel.cpp
+++ b/aten/src/ATen/native/cpu/IndexKernel.cpp
@@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }
 
 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cpu", [&] {
     cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
     });
@@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
 
 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
       // this needs to be thread-safe.
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
index d498179..337a15e 100644
--- a/aten/src/ATen/native/cuda/IndexKernel.cu
+++ b/aten/src/ATen/native/cuda/IndexKernel.cu
@@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
 }
 
 static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cuda", [&] {
     using dtype = OpaqueType;
     index_kernel_impl(iter, index_size, index_stride);
   });
@@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR
 
 static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
     using dtype = OpaqueType;
     index_put_kernel_impl(iter, index_size, index_stride);
   });
diff --git a/aten/src/TH/THTensor.h b/aten/src/TH/THTensor.h
index 072f9ca..c73415d 100644
--- a/aten/src/TH/THTensor.h
+++ b/aten/src/TH/THTensor.h
@@ -28,6 +28,9 @@
 #include
 #include
 
+#include
+#include
+
 /* fill and zero*/
 #include
 #include
diff --git a/aten/src/TH/THTensorEvenMoreMath.cpp b/aten/src/TH/THTensorEvenMoreMath.cpp
index c5f3328..a0b9e19 100644
--- a/aten/src/TH/THTensorEvenMoreMath.cpp
+++ b/aten/src/TH/THTensorEvenMoreMath.cpp
@@ -5,3 +5,6 @@
 #include
 
 #include
+
+#include
+#include
diff --git a/aten/src/TH/generic/THTensorEvenMoreMath.cpp b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
index 3b0e86f..af10962 100644
--- a/aten/src/TH/generic/THTensorEvenMoreMath.cpp
+++ b/aten/src/TH/generic/THTensorEvenMoreMath.cpp
@@ -4,6 +4,71 @@
 
 #include
 
+// Finds non-zero elements of a tensor and returns their subscripts
+void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
+{
+  ptrdiff_t numel = 0;
+  int64_t *subscript_data;
+  int64_t i = 0;
+#ifdef TH_REAL_IS_HALF
+#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
+#else
+#define IS_NONZERO(val) ((val)!=0)
+#endif
+
+  /* First Pass to determine size of subscripts */
+  TH_TENSOR_APPLY(scalar_t, tensor,
+                  if IS_NONZERO(*tensor_data) {
+                    ++numel;
+                  });
+#ifdef DEBUG
+  THAssert(numel <= LONG_MAX);
+#endif
+  THLongTensor_resize2d(subscript, numel, tensor->dim());
+  if (numel <= 0) {
+    return;
+  }
+  int64_t dimensions = tensor->dim();
+  // +1 faster than additional condition check inside loop
+  int64_t *sizes = new int64_t[dimensions+1];
+  int64_t *idx = new int64_t[dimensions+1];
+  int64_t *ii;
+  int64_t *ss;
+  std::fill(idx, idx+dimensions+1, 0);
+  for (i = 0; i < dimensions; ++i) {
+    sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
+  }
+  sizes[dimensions] = 0;
+  /* Second pass populates subscripts */
+  subscript_data = THLongTensor_data(subscript);
+  auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
+  subscript_strides[0] -= subscript_strides[1] * tensor->dim();
+  TH_TENSOR_APPLY(scalar_t, tensor,
+                  if IS_NONZERO(*tensor_data) {
+                    ii = idx + dimensions;
+                    for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
+                      --ii;
+                      *subscript_data = *ii;
+                      subscript_data += subscript_strides[1];
+                    }
+                    subscript_data += subscript_strides[0];
+                  }
+                  ii = idx;
+                  ss = sizes;
+                  ++(*ii);
+                  while (*ii == *ss) {
+                    *ii = 0;
+                    ++ii;
+                    ++ss;
+                    ++(*ii);
+                  }
+                );
+  delete [] sizes;
+  delete [] idx;
+}
+
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
+
 void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
 {
 #ifdef _OPENMP
@@ -91,69 +156,6 @@ void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask
   });
 }
 
-// Finds non-zero elements of a tensor and returns their subscripts
-void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
-{
-  ptrdiff_t numel = 0;
-  int64_t *subscript_data;
-  int64_t i = 0;
-#ifdef TH_REAL_IS_HALF
-#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
-#else
-#define IS_NONZERO(val) ((val)!=0)
-#endif
-
-  /* First Pass to determine size of subscripts */
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ++numel;
-                  });
-#ifdef DEBUG
-  THAssert(numel <= LONG_MAX);
-#endif
-  THLongTensor_resize2d(subscript, numel, tensor->dim());
-  if (numel <= 0) {
-    return;
-  }
-  int64_t dimensions = tensor->dim();
-  // +1 faster than additional condition check inside loop
-  int64_t *sizes = new int64_t[dimensions+1];
-  int64_t *idx = new int64_t[dimensions+1];
-  int64_t *ii;
-  int64_t *ss;
-  std::fill(idx, idx+dimensions+1, 0);
-  for (i = 0; i < dimensions; ++i) {
-    sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
-  }
-  sizes[dimensions] = 0;
-  /* Second pass populates subscripts */
-  subscript_data = THLongTensor_data(subscript);
-  auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
-  subscript_strides[0] -= subscript_strides[1] * tensor->dim();
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ii = idx + dimensions;
-                    for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
-                      --ii;
-                      *subscript_data = *ii;
-                      subscript_data += subscript_strides[1];
-                    }
-                    subscript_data += subscript_strides[0];
-                  }
-                  ii = idx;
-                  ss = sizes;
-                  ++(*ii);
-                  while (*ii == *ss) {
-                    *ii = 0;
-                    ++ii;
-                    ++ss;
-                    ++(*ii);
-                  }
-                );
-  delete [] sizes;
-  delete [] idx;
-}
-
 void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
 {
   ptrdiff_t i, numel;
@@ -959,4 +961,6 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
 #endif
 }
 
+#endif
+
 #endif /* TH_GENERIC_FILE */
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 3ab999b..a0766cb 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -2,12 +2,14 @@
 #define TH_GENERIC_FILE "TH/generic/THTensorMath.h"
 #else
 
+TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
+
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
+
 TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value);
 TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
 TH_API void THTensor_(maskedSelect)(THTensor *tensor, THTensor* src, THByteTensor *mask);
-TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
-
 TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
 TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
 TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
@@ -177,3 +179,4 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp
 #endif
 
 #endif
+#endif
diff --git a/aten/src/THC/THCTensorMath.cu b/aten/src/THC/THCTensorMath.cu
index d44645d..433fe3a 100644
--- a/aten/src/THC/THCTensorMath.cu
+++ b/aten/src/THC/THCTensorMath.cu
@@ -109,6 +109,15 @@ struct NonZeroOp
   }
 };
 
+template <>
+struct NonZeroOp<bool>
+{
+  NonZeroOp() {}
+  __host__ __device__ bool operator()(bool lhs) const {
+    return lhs != false;
+  }
+};
+
 #include
 #include
diff --git a/aten/src/THC/generic/THCTensorMath.cu b/aten/src/THC/generic/THCTensorMath.cu
index 286bbe7..b6c322d 100644
--- a/aten/src/THC/generic/THCTensorMath.cu
+++ b/aten/src/THC/generic/THCTensorMath.cu
@@ -242,8 +242,6 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
   }
 }
 
-#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
-
 void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
                          THCTensor *self)
 {
@@ -318,6 +316,8 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
   THCudaCheck(cudaGetLastError());
 }
 
+#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
+
 void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k){
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
   int nDimension = THCTensor_(nDimensionLegacyNoScalars)(state, src_);
diff --git a/aten/src/THC/generic/THCTensorMath.h b/aten/src/THC/generic/THCTensorMath.h
index f6c2d8a..a565cec 100644
--- a/aten/src/THC/generic/THCTensorMath.h
+++ b/aten/src/THC/generic/THCTensorMath.h
@@ -6,11 +6,11 @@
 THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value);
 THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
 THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
 THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
+THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
 
 #if !defined(THC_REAL_IS_BOOL) /* non bool only part */
-THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self);
diff --git a/test/test_indexing.py b/test/test_indexing.py
index fcc65bc..9badf87 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -35,6 +35,32 @@ class TestIndexing(TestCase):
         self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
         self.assertEqual(v[1:].sum(), 0)
 
+    def test_bool_indices(self):
+        v = torch.randn(5, 7, 3)
+        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
+        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
+        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
+
+        v = torch.tensor([True, False, True], dtype=torch.bool)
+        boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
+        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8)
+        self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
+        self.assertEqual(v[boolIndices], v[uint8Indices])
+        self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
+
+    def test_bool_indices_accumulate(self):
+        mask = torch.zeros(size=(10, ), dtype=torch.bool)
+        y = torch.ones(size=(10, 10))
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10)))
+
+    def test_multiple_bool_indices(self):
+        v = torch.randn(5, 7, 3)
+        # note: these broadcast together and are transposed to the first dim
+        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool)
+        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool)
+        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
+
     def test_byte_mask(self):
         v = torch.randn(5, 7, 3)
         mask = torch.ByteTensor([1, 0, 1, 1, 0])
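
For context when reading the tests above, here is a short, hypothetical Python sketch of the behaviour they exercise; it assumes only the `torch` calls already used in the tests and in the commit message. A bool mask used as an index selects the positions of the mask's True entries, the same result as indexing with the LongTensor returned by `nonzero()`, which is the expansion `expandTensors` performs internally.
```
import torch

# Bool-mask indexing selects the True positions along the leading dimension,
# equivalent to indexing with the long indices produced by nonzero().
v = torch.randn(5, 7, 3)
mask = torch.tensor([True, False, True, True, False], dtype=torch.bool)

via_mask = v[mask]                       # bool indexing enabled by this change
via_long = v[mask.nonzero().squeeze(1)]  # equivalent long-index form

assert via_mask.shape == (3, 7, 3)
assert torch.equal(via_mask, via_long)
```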