allow numpy-like boolean-list indexing in pytorch (#14932)
authorrory <ruochunz@student.unimelb.edu.au>
Thu, 20 Dec 2018 23:18:39 +0000 (15:18 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 23:33:06 +0000 (15:33 -0800)
Summary:
Suggested fix to issue #6773, the fix allows numpy-like boolean-list indexing in pytorch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14932

Differential Revision: D13398795

Pulled By: ezyang

fbshipit-source-id: 67f8daf9829db2550ff76d2bde673be6dd2708cd

test/test_indexing.py
torch/csrc/autograd/python_variable_indexing.cpp
torch/csrc/utils/tensor_new.cpp
torch/csrc/utils/tensor_new.h

index 1cf9e25..b2bb1e6 100644 (file)
@@ -506,6 +506,19 @@ class NumpyTests(TestCase):
         self.assertEqual((1, 2, 3), a[True, True].shape)
         self.assertEqual((1, 2, 3), a[true, true].shape)
 
+    def test_boolean_list_indexing(self):
+        # Indexing a 2-dimensional array with
+        # boolean lists
+        a = tensor([[1, 2, 3],
+                    [4, 5, 6],
+                    [7, 8, 9]])
+        b = [True, False, False]
+        c = [True, True, False]
+        self.assertEqual(a[b], tensor([[1, 2, 3]]))
+        self.assertEqual(a[b, b], tensor([1]))
+        self.assertEqual(a[c], tensor([[1, 2, 3], [4, 5, 6]]))
+        self.assertEqual(a[c, c], tensor([1, 5]))
+
     def test_everything_returns_views(self):
         # Before `...` would return a itself.
         a = tensor([5])
index 0ab7f8b..e83764c 100644 (file)
@@ -107,7 +107,7 @@ static Variable applySelect(const Variable& self, int64_t dim, int64_t index) {
 
 static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
   auto& idx_type = type.toScalarType(kLong);
-  return torch::utils::legacy_new_from_data(idx_type, c10::nullopt, seq);
+  return torch::utils::indexing_tensor_from_data(idx_type, c10::nullopt, seq);
 }
 
 static Variable valueToTensor(const at::Type & type, PyObject* value) {
index 6721ab7..186a7f6 100644 (file)
@@ -249,7 +249,7 @@ Tensor legacy_new_from_sequence(
   if (!PySequence_Check(data)) {
     throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name);
   }
-  return legacy_new_from_data(type, std::move(device), data);
+  return internal_new_from_data(type, std::move(device), data, false, false, false);
 }
 
 void check_legacy_ctor_device(const Type& type, c10::optional<Device> device) {
@@ -449,11 +449,19 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
   throw std::runtime_error("new(): invalid arguments");
 }
 
-Tensor legacy_new_from_data(
+Tensor indexing_tensor_from_data(
     const Type& type,
     c10::optional<Device> device,
     PyObject* data) {
-  return internal_new_from_data(type, std::move(device), data, false, false, false);
+  // Specific to tensor indexing, converts an indexing list to an
+  // indexing tensor (type Byte or Long)
+  ScalarType scalar_type = infer_scalar_type(data);
+  if (scalar_type == ScalarType::Byte) {
+    auto& idx_type = type.toScalarType(scalar_type);
+    return internal_new_from_data(idx_type, std::move(device), data, false, false, false);
+  } else {
+    return internal_new_from_data(type, std::move(device), data, false, false, false);
+  }
 }
 
 Tensor sparse_coo_tensor_ctor(const Type& default_type, PyObject* args, PyObject* kwargs) {
index f445431..74b1809 100644 (file)
@@ -8,7 +8,7 @@ namespace torch { namespace utils {
 
 at::Tensor legacy_tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
 at::Tensor legacy_tensor_new(const at::Type& type, PyObject* args, PyObject* kwargs);
-at::Tensor legacy_new_from_data(
+at::Tensor indexing_tensor_from_data(
     const at::Type& type,
     c10::optional<at::Device> device,
     PyObject* data);