Use IndexError instead of RuntimeError in ATen CPU kernels
authorStefan Krah <skrah@bytereef.org>
Wed, 13 Feb 2019 17:24:04 +0000 (09:24 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 13 Feb 2019 18:19:28 +0000 (10:19 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17049

Reviewed By: ezyang

Differential Revision: D14064700

Pulled By: fmassa

fbshipit-source-id: 3575db103bba5a7d82f574cbb082beca419151ec

aten/src/ATen/native/Indexing.cpp
aten/src/ATen/native/TensorShape.cpp
aten/src/ATen/native/cpu/IndexKernel.cpp
c10/core/WrapDimMinimal.h
c10/util/Exception.h
test/test_indexing.py
test/test_torch.py
torch/csrc/Exceptions.h

index ef5c124..12cdd48 100644 (file)
@@ -72,7 +72,7 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
   ss << "The shape of the mask " << mask.sizes() << " at index " << maskIdx;
   ss << " does not match the shape of the indexed tensor " << self.sizes();
   ss << " at index " << idx;
-  AT_ERROR(ss.str());
+  AT_INDEX_ERROR(ss.str());
 }
 
 static void checkIndexTensorTypes(TensorList indices) {
@@ -80,8 +80,9 @@ static void checkIndexTensorTypes(TensorList indices) {
     if (tensor.defined()) {
       auto& type = tensor.type();
       auto scalarType = type.scalarType();
-      AT_CHECK(scalarType == kLong || scalarType == kByte,
-               "tensors used as indices must be long or byte tensors");
+      if (scalarType != kLong && scalarType != kByte) {
+          AT_INDEX_ERROR("tensors used as indices must be long or byte tensors");
+      }
     }
   }
 }
@@ -350,7 +351,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
   // restride_src with an unhelpful error message.
   if (std::find(indexed_sizes.begin(), indexed_sizes.end(), 0) != indexed_sizes.end() &&
       std::find(replacement_shape.begin(), replacement_shape.end(), 0) == replacement_shape.end()) {
-    AT_ERROR("index is out of bounds for dim with size 0");
+    AT_INDEX_ERROR("index is out of bounds for dimension with size 0");
   }
 
   this->dims_before = dims_before;
@@ -382,8 +383,8 @@ static AdvancedIndex make_info(Tensor self, TensorList orig) {
   try {
     indices = expand_outplace(indices);
   } catch (std::exception& e) {
-    AT_ERROR("shape mismatch: indexing tensors could not be broadcast together"
-             " with shapes ", shapes_as_str(indices));
+    AT_INDEX_ERROR("shape mismatch: indexing tensors could not be broadcast together"
+                   " with shapes ", shapes_as_str(indices));
   }
   // add missing null Tensors so that it matches self.dim()
   while (indices.size() < (size_t)self.dim()) {
index 5e31897..2e7b5e6 100644 (file)
@@ -431,12 +431,15 @@ Tensor reshape_as(const Tensor& self, const Tensor& other) {
 
 Tensor select(const Tensor& self, int64_t dim, int64_t index) {
   int64_t ndim = self.dim();
-  AT_CHECK(ndim > 0, "select() cannot be applied to a 0-dim tensor.");
+  if (ndim == 0) {
+    AT_INDEX_ERROR("select() cannot be applied to a 0-dim tensor.");
+  }
   dim = maybe_wrap_dim(dim, ndim);
   auto size = self.size(dim);
-  AT_CHECK(index >= -size && index < size,
-           "select(): index ", index, " out of range for tensor of size ",
-           self.sizes(), " at dimension ", dim);
+  if (index < -size || index >= size) {
+    AT_INDEX_ERROR("select(): index ", index, " out of range for tensor of size ",
+                   self.sizes(), " at dimension ", dim);
+  }
   if (index < 0) {
     index += size;
   }
@@ -450,7 +453,9 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) {
 
 Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
   int64_t ndim = self.dim();
-  AT_CHECK(ndim > 0, "slice() cannot be applied to a 0-dim tensor.");
+  if (ndim == 0) {
+    AT_INDEX_ERROR("slice() cannot be applied to a 0-dim tensor.");
+  }
   dim = maybe_wrap_dim(dim, ndim);
   auto sizes = self.sizes().vec();
   auto strides = self.strides().vec();
index 1bab204..aa41497 100644 (file)
@@ -36,7 +36,7 @@ struct Indexer {
       int64_t value = *(int64_t*)&indexers[j][idx * indexer_strides[j]];
       int64_t size = original_sizes[j];
       if (value < -size || value >= size) {
-        AT_ERROR("index ", value, " is out of bounds for dim with size ", size);
+        AT_INDEX_ERROR("index ", value, " is out of bounds for dimension ", j, " with size ", size);
       }
       if (value < 0) {
         value += size;
index bebd24a..63d7152 100644 (file)
@@ -6,7 +6,9 @@ namespace c10 {
 
 static inline int64_t maybe_wrap_dim(int64_t dim, int64_t dim_post_expr, bool wrap_scalar=true) {
   if (dim_post_expr <= 0) {
-    AT_CHECK(wrap_scalar, "dimension specified as ", dim, " but tensor has no dimensions");
+    if (!wrap_scalar) {
+      AT_INDEX_ERROR("dimension specified as ", dim, " but tensor has no dimensions");
+    }
     dim_post_expr = 1; // this will make range [-1, 0]
   }
 
index 94cb9aa..8c4fbed 100644 (file)
@@ -101,6 +101,13 @@ class C10_API Warning {
   static handler_t warning_handler_;
 };
 
+// Used in ATen for out-of-bound indices that can reasonably only be detected
+// lazily inside a kernel (See: advanced indexing).
+class C10_API IndexError : public Error {
+  using Error::Error;
+};
+
+
 // A utility function to return an exception std::string by prepending its
 // exception type before its what() content
 C10_API std::string GetExceptionString(const std::exception& e);
@@ -126,6 +133,9 @@ C10_API std::string GetExceptionString(const std::exception& e);
 #define AT_ERROR(...) \
   throw ::c10::Error({__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, ::c10::str(__VA_ARGS__))
 
+#define AT_INDEX_ERROR(...) \
+  throw ::c10::IndexError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, ::c10::str(__VA_ARGS__))
+
 #define AT_WARN(...) \
   ::c10::Warning::warn({__func__, __FILE__, static_cast<uint32_t>(__LINE__)}, ::c10::str(__VA_ARGS__))
 
index b2c1804..adfa778 100644 (file)
@@ -113,7 +113,7 @@ class TestIndexing(TestCase):
         x = torch.empty(10, 0)
         self.assertEqual(x[[1, 2]].shape, (2, 0))
         self.assertEqual(x[[], []].shape, (0,))
-        with self.assertRaisesRegex(RuntimeError, 'for dim with size 0'):
+        with self.assertRaisesRegex(IndexError, 'for dimension with size 0'):
             x[:, [0, 1]]
 
     def test_empty_ndim_index_bool(self):
@@ -171,7 +171,7 @@ class TestIndexing(TestCase):
             a[...] = neg_ones_expanded * 4
             self.assertEqual(a, neg_ones * 4)
             if a.dim() == 0:
-                with self.assertRaises(RuntimeError):
+                with self.assertRaises(IndexError):
                     a[:] = neg_ones_expanded * 5
 
     def test_setitem_expansion_error(self):
@@ -179,6 +179,7 @@ class TestIndexing(TestCase):
         a = torch.randn(2, 3)
         # check prefix with  non-1s doesn't work
         a_expanded = a.expand(torch.Size([5, 1]) + a.size())
+        # NumPy: ValueError
         with self.assertRaises(RuntimeError):
             a[True] = a_expanded
         with self.assertRaises(RuntimeError):
@@ -202,7 +203,7 @@ class TestIndexing(TestCase):
 
         # scalar indexed with scalar
         r = torch.randn(())
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(IndexError):
             r[:]
         with self.assertRaises(IndexError):
             r[zero]
@@ -225,7 +226,7 @@ class TestIndexing(TestCase):
 
         # scalar indexed with scalars
         r = torch.randn(())
-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(IndexError):
             r[:] = 8.8
         with self.assertRaises(IndexError):
             r[zero] = 8.8
@@ -388,7 +389,7 @@ class NumpyTests(TestCase):
         self.assertEqual(a[[]], torch.tensor([], dtype=torch.long))
 
         b = tensor([]).float()
-        self.assertRaises(RuntimeError, lambda: a[b])
+        self.assertRaises(IndexError, lambda: a[b])
 
     def test_ellipsis_index(self):
         a = tensor([[1, 2, 3],
@@ -440,17 +441,16 @@ class NumpyTests(TestCase):
     def test_boolean_shape_mismatch(self):
         arr = torch.ones((5, 4, 3))
 
-        # TODO: prefer IndexError
         index = tensor([True])
-        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
+        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
 
         index = tensor([False] * 6)
-        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
+        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
 
         index = torch.ByteTensor(4, 4).zero_()
-        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
+        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[index])
 
-        self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[(slice(None), index)])
+        self.assertRaisesRegex(IndexError, 'mask', lambda: arr[(slice(None), index)])
 
     def test_boolean_indexing_onedim(self):
         # Indexing a 2-dimensional array with
@@ -498,7 +498,7 @@ class NumpyTests(TestCase):
         a = torch.ones((2, 3, 4))
         self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
         self.assertEqual(torch.ones(1, 2), a[True, [0, 1], True, True, [1], [[2]]])
-        self.assertRaises(RuntimeError, lambda: a[False, [0, 1], ...])
+        self.assertRaises(IndexError, lambda: a[False, [0, 1], ...])
 
     def test_boolean_indexing_weirdness_tensors(self):
         # Weird boolean indexing things
@@ -507,7 +507,7 @@ class NumpyTests(TestCase):
         a = torch.ones((2, 3, 4))
         self.assertEqual((0, 2, 3, 4), a[False, True, ...].shape)
         self.assertEqual(torch.ones(1, 2), a[true, [0, 1], true, true, [1], [[2]]])
-        self.assertRaises(RuntimeError, lambda: a[false, [0, 1], ...])
+        self.assertRaises(IndexError, lambda: a[false, [0, 1], ...])
 
     def test_boolean_indexing_alldims(self):
         true = torch.tensor(True)
@@ -538,8 +538,8 @@ class NumpyTests(TestCase):
 
     def test_broaderrors_indexing(self):
         a = torch.zeros(5, 5)
-        self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
-        self.assertRaisesRegex(RuntimeError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
+        self.assertRaisesRegex(IndexError, 'shape mismatch', a.__getitem__, ([0, 1], [0, 1, 2]))
+        self.assertRaisesRegex(IndexError, 'shape mismatch', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
 
     def test_trivial_fancy_out_of_bounds(self):
         a = torch.zeros(5)
@@ -547,12 +547,12 @@ class NumpyTests(TestCase):
         if a.is_cuda:
             raise unittest.SkipTest('CUDA asserts instead of raising an exception')
         ind[-1] = 10
-        self.assertRaises(RuntimeError, a.__getitem__, ind)
-        self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
+        self.assertRaises(IndexError, a.__getitem__, ind)
+        self.assertRaises(IndexError, a.__setitem__, ind, 0)
         ind = torch.ones(20, dtype=torch.int64)
         ind[0] = 11
-        self.assertRaises(RuntimeError, a.__getitem__, ind)
-        self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
+        self.assertRaises(IndexError, a.__getitem__, ind)
+        self.assertRaises(IndexError, a.__setitem__, ind, 0)
 
     def test_index_is_larger(self):
         # Simple case of fancy index broadcasting of the index.
index cd2e741..d6269f2 100644 (file)
@@ -6440,9 +6440,9 @@ class _TestTorchMixin(object):
             for err_idx in (10, -11):
                 with self.assertRaisesRegex(IndexError, r'out of'):
                     reference[err_idx]
-                with self.assertRaisesRegex(RuntimeError, r'out of'):
+                with self.assertRaisesRegex(IndexError, r'out of'):
                     reference[conv_fn(torch.LongTensor([err_idx]))]
-                with self.assertRaisesRegex(RuntimeError, r'out of'):
+                with self.assertRaisesRegex(IndexError, r'out of'):
                     reference[[err_idx]]
 
         if TEST_NUMPY:
index 5471999..0731cd4 100644 (file)
   catch (python_error & e) {                                       \
     return retval;                                                 \
   }                                                                \
+  catch (const c10::IndexError& e) {                               \
+    auto msg = torch::processErrorMsg(e.what_without_backtrace()); \
+    PyErr_SetString(PyExc_IndexError, msg.c_str());                \
+    return retval;                                                 \
+  }                                                                \
   catch (const c10::Error& e) {                                    \
     auto msg = torch::processErrorMsg(e.what_without_backtrace()); \
     PyErr_SetString(PyExc_RuntimeError, msg.c_str());              \