Added indexing for bool tensors and bool Indices (#18583)
authorIurii Zdebskyi <iuriiz@fb.com>
Wed, 3 Apr 2019 17:53:11 +0000 (10:53 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 19:47:26 +0000 (12:47 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18583
ghimport-source-id: 2b1941449827f4ab632fa0f5c8cf0791a6be0845

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18583 Added indexing for bool tensors and bool Indices**
* #18505 Added numpy conversion
* #18166 Bool Tensor for CUDA

-----------
This PR enables bool tensor indexing and indexing with bool indices. This is a part of Bool Tensor feature implementation work. The whole plan looks like this:
1. Storage Implementation [Done]
2. Tensor Creation.
    a) CPU [Done]
    b) CUDA [In review]
3. Tensor Conversions. [In review]
4. Tensor Indexing. [This PR]
5. Tensor Operations.
6. Back compatibility related changes.

TODO:
as a follow up, we should move nonzero method from TH to Aten to make code cleaner.

Change:
```
v = torch.tensor([True, False, True], dtype=torch.bool)
boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
v[boolIndices]
-> tensor([True], dtype=torch.bool)

v = torch.randn(5, 7, 3)
boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
v[boolIndices]
->
tensor([[[ 0.5885, -0.3322,  0.7388],
         [ 1.1182,  0.7808, -1.1492],
         [-0.7952,  0.5255, -0.0251],
         [ 0.7128,  0.8099,  1.2689],
         [-0.7018, -1.4733, -0.3732],
         [ 0.4503,  0.4986, -1.1605],
         [ 0.3348, -1.3767, -0.2976]],

        [[-2.0303, -0.4720, -0.1448],
         [-0.1914, -0.6821,  2.0061],
         [-1.0420, -0.1872, -0.3438],
         [ 1.7587, -0.4183, -0.7577],
         [ 1.0094, -0.1950, -0.2430],
         [ 0.1174,  0.3308, -0.5700],
         [ 0.1110, -0.2714,  1.3006]],

        [[-0.1946, -1.4747, -0.4650],
         [-1.0567,  1.0110, -0.2809],
         [ 0.3729, -0.5699,  0.0815],
         [-0.7733, -0.8316,  0.1674],
         [ 1.2000, -0.3745, -1.1679],
         [ 1.7105,  0.9851, -0.1907],
         [-1.1077,  0.2086, -0.0548]]])
```

Differential Revision: D14673403

fbshipit-source-id: 2b88ec2c7eb26a4f5ef64f8707fb68068d476fc9

12 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/Indexing.cpp
aten/src/ATen/native/cpu/IndexKernel.cpp
aten/src/ATen/native/cuda/IndexKernel.cu
aten/src/TH/THTensor.h
aten/src/TH/THTensorEvenMoreMath.cpp
aten/src/TH/generic/THTensorEvenMoreMath.cpp
aten/src/TH/generic/THTensorMath.h
aten/src/THC/THCTensorMath.cu
aten/src/THC/generic/THCTensorMath.cu
aten/src/THC/generic/THCTensorMath.h
test/test_indexing.py

index c438886..a1bf2b3 100644 (file)
 [[
   name: _th_nonzero
   cname: nonzero
+  cpu_bool: True
+  cuda_bool: True
   variants:
     - function
   return: argument 0
index c7d478c..34851f0 100644 (file)
@@ -5,7 +5,7 @@
 //  index(Tensor self, indices) -> Tensor
 //  index_put_(Tensor self, indices, value, accumulate=false)
 //
-// The index is a TensorList containg kLong or kByte tensors or nulls. Byte
+// The index is a TensorList containg kLong, kBool or kByte tensors or nulls. Byte
 // tensors (boolean masks) are expanded to long tensors via nonzero(). Null
 // tensors signify that the dimension is not indexed.
 //
@@ -79,19 +79,19 @@ static void checkIndexTensorTypes(TensorList indices) {
   for (auto& tensor : indices) {
     if (tensor.defined()) {
       auto scalarType = tensor.scalar_type();
-      if (scalarType != kLong && scalarType != kByte) {
-          AT_INDEX_ERROR("tensors used as indices must be long or byte tensors");
+      if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
+          AT_INDEX_ERROR("tensors used as indices must be long, byte or bool tensors");
       }
     }
   }
 }
 
-static std::vector<Tensor> expandByteTensors(const Tensor & self, TensorList indices) {
-  // Expands byte tensors (masks) into the equivalent indexing by LongTensors
+static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
+  // Expands ByteTensor (masks) or BoolTensor (masks) into the equivalent indexing by LongTensors
   std::vector<Tensor> result;
   for (auto & index : indices) {
-    if (index.scalar_type() == kByte) {
-      // The sizes of the ByteTensor mask must match the sizes of the
+    if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+      // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
       // corresponding dimensions in self
       for (int64_t j = 0; j < index.dim(); j++) {
         int64_t srcIdx = result.size() + j;
@@ -244,8 +244,8 @@ static Tensor computeLinearIndex(const Tensor & src, TensorList indices) {
 
 static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) {
   checkIndexTensorTypes(orig);
-  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
-  auto indices = expandByteTensors(self, orig);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
+  auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together
   indices = expand_outplace(indices);
   // add missing null Tensors so that it matches self.dim()
@@ -378,8 +378,8 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
 
 static AdvancedIndex make_info(Tensor self, TensorList orig) {
   checkIndexTensorTypes(orig);
-  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
-  auto indices = expandByteTensors(self, orig);
+  // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
+  auto indices = expandTensors(self, orig);
   // next broadcast all index tensors together
   try {
     indices = expand_outplace(indices);
index 698b40b..e11ce20 100644 (file)
@@ -92,7 +92,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }
 
 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cpu", [&] {
     cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
     });
@@ -101,7 +101,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde
 
 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
       // this needs to be thread-safe.
index d498179..337a15e 100644 (file)
@@ -81,7 +81,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArra
 }
 
 static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_cuda", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_kernel_impl<dtype>(iter, index_size, index_stride);
   });
@@ -90,7 +90,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayR
 
 static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "index_put", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_put_kernel_impl<dtype>(iter, index_size, index_stride);
   });
index 072f9ca..c73415d 100644 (file)
@@ -28,6 +28,9 @@
 #include <TH/generic/THTensorMath.h>
 #include <TH/THGenerateAllTypes.h>
 
+#include <TH/generic/THTensorMath.h>
+#include <TH/THGenerateBoolType.h>
+
 /* fill and zero*/
 #include <TH/generic/THTensorFill.h>
 #include <TH/THGenerateAllTypes.h>
index c5f3328..a0b9e19 100644 (file)
@@ -5,3 +5,6 @@
 
 #include <TH/generic/THTensorEvenMoreMath.cpp>
 #include <TH/THGenerateAllTypes.h>
+
+#include <TH/generic/THTensorEvenMoreMath.cpp>
+#include <TH/THGenerateBoolType.h>
index 3b0e86f..af10962 100644 (file)
@@ -4,6 +4,71 @@
 
 #include <TH/generic/THTensorApply.hpp>
 
+// Finds non-zero elements of a tensor and returns their subscripts
+void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
+{
+  ptrdiff_t numel = 0;
+  int64_t *subscript_data;
+  int64_t i = 0;
+#ifdef TH_REAL_IS_HALF
+#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
+#else
+#define IS_NONZERO(val) ((val)!=0)
+#endif
+
+  /* First Pass to determine size of subscripts */
+  TH_TENSOR_APPLY(scalar_t, tensor,
+                  if IS_NONZERO(*tensor_data) {
+                    ++numel;
+                  });
+#ifdef DEBUG
+  THAssert(numel <= LONG_MAX);
+#endif
+  THLongTensor_resize2d(subscript, numel, tensor->dim());
+  if (numel <= 0) {
+    return;
+  }
+  int64_t dimensions = tensor->dim();
+  // +1 faster than additional condition check inside loop
+  int64_t *sizes = new int64_t[dimensions+1];
+  int64_t *idx = new int64_t[dimensions+1];
+  int64_t *ii;
+  int64_t *ss;
+  std::fill(idx, idx+dimensions+1, 0);
+  for (i = 0; i < dimensions; ++i) {
+    sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
+  }
+  sizes[dimensions] = 0;
+  /* Second pass populates subscripts */
+  subscript_data = THLongTensor_data(subscript);
+  auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
+  subscript_strides[0] -= subscript_strides[1] * tensor->dim();
+  TH_TENSOR_APPLY(scalar_t, tensor,
+                  if IS_NONZERO(*tensor_data) {
+                    ii = idx + dimensions;
+                    for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
+                      --ii;
+                      *subscript_data = *ii;
+                      subscript_data += subscript_strides[1];
+                    }
+                    subscript_data += subscript_strides[0];
+                  }
+                  ii = idx;
+                  ss = sizes;
+                  ++(*ii);
+                  while (*ii == *ss) {
+                    *ii = 0;
+                    ++ii;
+                    ++ss;
+                    ++(*ii);
+                  }
+                );
+  delete [] sizes;
+  delete [] idx;
+}
+
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
+
 void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value)
 {
 #ifdef _OPENMP
@@ -91,69 +156,6 @@ void THTensor_(maskedSelect)(THTensor *tensor, THTensor *src, THByteTensor *mask
                    });
 }
 
-// Finds non-zero elements of a tensor and returns their subscripts
-void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor)
-{
-  ptrdiff_t numel = 0;
-  int64_t *subscript_data;
-  int64_t i = 0;
-#ifdef TH_REAL_IS_HALF
-#define IS_NONZERO(val) ((val.x & 0x7fff) != 0)
-#else
-#define IS_NONZERO(val) ((val)!=0)
-#endif
-
-  /* First Pass to determine size of subscripts */
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ++numel;
-                  });
-#ifdef DEBUG
-  THAssert(numel <= LONG_MAX);
-#endif
-  THLongTensor_resize2d(subscript, numel, tensor->dim());
-  if (numel <= 0) {
-    return;
-  }
-  int64_t dimensions = tensor->dim();
-  // +1 faster than additional condition check inside loop
-  int64_t *sizes = new int64_t[dimensions+1];
-  int64_t *idx = new int64_t[dimensions+1];
-  int64_t *ii;
-  int64_t *ss;
-  std::fill(idx, idx+dimensions+1, 0);
-  for (i = 0; i < dimensions; ++i) {
-    sizes[dimensions - i - 1] = THTensor_(size)(tensor, i); // reverse order important
-  }
-  sizes[dimensions] = 0;
-  /* Second pass populates subscripts */
-  subscript_data = THLongTensor_data(subscript);
-  auto subscript_strides = THTensor_stridesLegacyNoScalars(subscript);
-  subscript_strides[0] -= subscript_strides[1] * tensor->dim();
-  TH_TENSOR_APPLY(scalar_t, tensor,
-                  if IS_NONZERO(*tensor_data) {
-                    ii = idx + dimensions;
-                    for (int64_t dim = dimensions - 1; dim >= 0; dim--) {
-                      --ii;
-                      *subscript_data = *ii;
-                      subscript_data += subscript_strides[1];
-                    }
-                    subscript_data += subscript_strides[0];
-                  }
-                  ii = idx;
-                  ss = sizes;
-                  ++(*ii);
-                  while (*ii == *ss) {
-                    *ii = 0;
-                    ++ii;
-                    ++ss;
-                    ++(*ii);
-                  }
-                );
-  delete [] sizes;
-  delete [] idx;
-}
-
 void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index)
 {
   ptrdiff_t i, numel;
@@ -959,4 +961,6 @@ void THTensor_(bitand)(THTensor *r_, THTensor *t, scalar_t value)
 #endif
 }
 
+#endif
+
 #endif /* TH_GENERIC_FILE */
index 3ab999b..a0766cb 100644 (file)
@@ -2,12 +2,14 @@
 #define TH_GENERIC_FILE "TH/generic/THTensorMath.h"
 #else
 
+TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
+
+#if !defined(TH_REAL_IS_BOOL) /* non bool only part */
+
 TH_API void THTensor_(maskedFill)(THTensor *tensor, THByteTensor *mask, scalar_t value);
 TH_API void THTensor_(maskedCopy)(THTensor *tensor, THByteTensor *mask, THTensor* src);
 TH_API void THTensor_(maskedSelect)(THTensor *tensor, THTensor* src, THByteTensor *mask);
 
-TH_API void THTensor_(nonzero)(THLongTensor *subscript, THTensor *tensor);
-
 TH_API void THTensor_(indexSelect)(THTensor *tensor, THTensor *src, int dim, THLongTensor *index);
 TH_API void THTensor_(indexCopy)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
 TH_API void THTensor_(indexAdd)(THTensor *tensor, int dim, THLongTensor *index, THTensor *src);
@@ -177,3 +179,4 @@ TH_API void THTensor_(dirichlet_grad)(THTensor *self, THTensor *x, THTensor *alp
 #endif
 
 #endif
+#endif
index d44645d..433fe3a 100644 (file)
@@ -109,6 +109,15 @@ struct NonZeroOp
   }
 };
 
+template <>
+struct NonZeroOp<bool>
+{
+  NonZeroOp() {}
+  __host__ __device__ bool operator()(bool lhs) const {
+    return lhs != false;
+  }
+};
+
 #include <THC/generic/THCTensorMath.cu>
 #include <THC/THCGenerateAllTypes.h>
 
index 286bbe7..b6c322d 100644 (file)
@@ -242,8 +242,6 @@ void THCTensor_(catArray)(THCState *state, THCTensor *result,
   }
 }
 
-#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
-
 void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
                           THCTensor *self)
 {
@@ -318,6 +316,8 @@ void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor,
   THCudaCheck(cudaGetLastError());
 }
 
+#if !defined(THC_REAL_IS_BOOL) /* non bool only part */
+
 void THCTensor_(diag)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k){
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
   int nDimension = THCTensor_(nDimensionLegacyNoScalars)(state, src_);
index f6c2d8a..a565cec 100644 (file)
@@ -6,11 +6,11 @@ THC_API void THCTensor_(fill)(THCState *state, THCTensor *self, scalar_t value);
 THC_API void THCTensor_(zero)(THCState *state, THCTensor *self);
 THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THCTensor *tb, int dimension);
 THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
+THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 THC_API ptrdiff_t THCTensor_(numel)(THCState *state, THCTensor *t);
 
 #if !defined(THC_REAL_IS_BOOL) /* non bool only part */
 
-THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API accreal THCTensor_(trace)(THCState *state, THCTensor *self);
index fcc65bc..9badf87 100644 (file)
@@ -35,6 +35,32 @@ class TestIndexing(TestCase):
         self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
         self.assertEqual(v[1:].sum(), 0)
 
+    def test_bool_indices(self):
+        v = torch.randn(5, 7, 3)
+        boolIndices = torch.tensor([True, False, True, True, False], dtype=torch.bool)
+        self.assertEqual(v[boolIndices].shape, (3, 7, 3))
+        self.assertEqual(v[boolIndices], torch.stack([v[0], v[2], v[3]]))
+
+        v = torch.tensor([True, False, True], dtype=torch.bool)
+        boolIndices = torch.tensor([True, False, False], dtype=torch.bool)
+        uint8Indices = torch.tensor([1, 0, 0], dtype=torch.uint8)
+        self.assertEqual(v[boolIndices].shape, v[uint8Indices].shape)
+        self.assertEqual(v[boolIndices], v[uint8Indices])
+        self.assertEqual(v[boolIndices], tensor([True], dtype=torch.bool))
+
+    def test_bool_indices_accumulate(self):
+        mask = torch.zeros(size=(10, ), dtype=torch.bool)
+        y = torch.ones(size=(10, 10))
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10)))
+
+    def test_multiple_bool_indices(self):
+        v = torch.randn(5, 7, 3)
+        # note: these broadcast together and are transposed to the first dim
+        mask1 = torch.tensor([1, 0, 1, 1, 0], dtype=torch.bool)
+        mask2 = torch.tensor([1, 1, 1], dtype=torch.bool)
+        self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
+
     def test_byte_mask(self):
         v = torch.randn(5, 7, 3)
         mask = torch.ByteTensor([1, 0, 1, 1, 0])