Issue 14984: Remove divide by zero error in index_put_ (#14986)
authorJosef Lindman Hörnlund <jotsif@gmail.com>
Tue, 11 Dec 2018 21:36:00 +0000 (13:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 11 Dec 2018 21:38:12 +0000 (13:38 -0800)
Summary:
No check for zero index tensor was done in the accumulate=True (serial) case in the new TensorIterator code since https://github.com/pytorch/pytorch/pull/13420.

https://github.com/pytorch/pytorch/issues/14984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14986

Differential Revision: D13417861

Pulled By: colesbury

fbshipit-source-id: e6ed1af8f708b53a35803fc157ed1f043169ec89

aten/src/ATen/native/TensorIterator.cpp
test/test_indexing.py

index ccf0bf2..4570833 100644 (file)
@@ -379,6 +379,9 @@ void TensorIterator::serial_for_each(const loop_t& loop, Range range) const {
 }
 
 void TensorIterator::serial_for_each(const loop2d_t& loop, Range range) const {
+  if (range.size() == 0) {
+    return;
+  }
   auto strides = get_strides();
   while (strides.size() < 2 * ntensors()) {
     strides.push_back(0);
@@ -682,8 +685,10 @@ DimCounter::DimCounter(IntList shape, Range range)
   int64_t ndim = values.size();
   for (int dim = 0; dim < ndim; dim++) {
     int64_t size = shape[dim];
-    values[dim] = linear_offset % size;
-    linear_offset /= size;
+    if (size > 0) {
+      values[dim] = linear_offset % size;
+      linear_offset /= size;
+    }
   }
   AT_ASSERT(linear_offset == 0);
 }
index b1b7596..1cf9e25 100644 (file)
@@ -45,6 +45,12 @@ class TestIndexing(TestCase):
         v = torch.tensor([1.])
         self.assertEqual(v[v == 0], torch.tensor([]))
 
+    def test_byte_mask_accumulate(self):
+        mask = torch.zeros(size=(10, ), dtype=torch.uint8)
+        y = torch.ones(size=(10, 10))
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10)))
+
     def test_multiple_byte_mask(self):
         v = torch.randn(5, 7, 3)
         # note: these broadcast together and are transposed to the first dim