From 04b65dfd1ff78cdda327aeb6ad33ce1bb444fb9d Mon Sep 17 00:00:00 2001
From: =?utf8?q?Josef=20Lindman=20H=C3=B6rnlund?=
Date: Tue, 11 Dec 2018 13:36:00 -0800
Subject: [PATCH] Issue 14984: Remove divide by zero error in index_put_ (#14986)

Summary:
No check for a zero-sized index tensor was done in the accumulate=True
(serial) case in the new TensorIterator code since
https://github.com/pytorch/pytorch/pull/13420.

https://github.com/pytorch/pytorch/issues/14984

Pull Request resolved: https://github.com/pytorch/pytorch/pull/14986

Differential Revision: D13417861

Pulled By: colesbury

fbshipit-source-id: e6ed1af8f708b53a35803fc157ed1f043169ec89
---
 aten/src/ATen/native/TensorIterator.cpp | 9 +++++++--
 test/test_indexing.py                   | 6 ++++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index ccf0bf2..4570833 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -379,6 +379,9 @@ void TensorIterator::serial_for_each(const loop_t& loop, Range range) const {
 }
 
 void TensorIterator::serial_for_each(const loop2d_t& loop, Range range) const {
+  if (range.size() == 0) {
+    return;
+  }
   auto strides = get_strides();
   while (strides.size() < 2 * ntensors()) {
     strides.push_back(0);
@@ -682,8 +685,10 @@ DimCounter::DimCounter(IntList shape, Range range)
   int64_t ndim = values.size();
   for (int dim = 0; dim < ndim; dim++) {
     int64_t size = shape[dim];
-    values[dim] = linear_offset % size;
-    linear_offset /= size;
+    if (size > 0) {
+      values[dim] = linear_offset % size;
+      linear_offset /= size;
+    }
   }
   AT_ASSERT(linear_offset == 0);
 }
diff --git a/test/test_indexing.py b/test/test_indexing.py
index b1b7596..1cf9e25 100644
--- a/test/test_indexing.py
+++ b/test/test_indexing.py
@@ -45,6 +45,12 @@ class TestIndexing(TestCase):
         v = torch.tensor([1.])
         self.assertEqual(v[v == 0], torch.tensor([]))
 
+    def test_byte_mask_accumulate(self):
+        mask = torch.zeros(size=(10, ), dtype=torch.uint8)
+        y = torch.ones(size=(10, 10))
+        y.index_put_((mask, ), y[mask], accumulate=True)
+        self.assertEqual(y, torch.ones(size=(10, 10)))
+
     def test_multiple_byte_mask(self):
         v = torch.randn(5, 7, 3)
         # note: these broadcast together and are transposed to the first dim
--
2.7.4
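
Editor's note: a minimal standalone repro of the failure this patch fixes,
as a sketch assuming a PyTorch build contemporary with the patch (before the
fix, and while byte-mask indexing was still supported); the names `mask` and
`y` are illustrative. An all-zero byte mask selects no elements, so
index_put_ with accumulate=True drives the serial TensorIterator path over a
zero-element range, and DimCounter then computed `linear_offset % size` with
size == 0.

    import torch

    mask = torch.zeros(10, dtype=torch.uint8)  # byte mask selecting nothing
    y = torch.ones(10, 10)
    # accumulate=True takes the serial TensorIterator path; before this
    # patch the empty index triggered a division/modulo by zero in DimCounter
    y.index_put_((mask,), y[mask], accumulate=True)
    # with the fix applied this is a no-op and y is unchanged
    assert torch.equal(y, torch.ones(10, 10))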
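
For context on the second hunk: DimCounter unflattens a linear offset into
per-dimension counters by peeling off one dimension size at a time, which is
why a zero-sized dimension must be skipped. A sketch of the same idea in
Python (`unflatten` is an illustrative name, not a PyTorch API):

    def unflatten(linear_offset, shape):
        # mirrors DimCounter: peel off one dimension at a time,
        # skipping zero-sized dims to avoid a modulo by zero
        values = []
        for size in shape:
            if size > 0:
                values.append(linear_offset % size)
                linear_offset //= size
            else:
                values.append(0)
        return values

    unflatten(7, [3, 4])  # -> [1, 2], since 7 == 1 + 2 * 3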