Error out on in-place binops on tensors with internal overlap (#19317)
author    Richard Zou <zou3519@gmail.com>
          Wed, 17 Apr 2019 19:58:04 +0000 (12:58 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
          Wed, 17 Apr 2019 20:02:07 +0000 (13:02 -0700)
Summary:
This adds internal memory overlap checks to `mul_`, `add_`, `sub_`, and
`div_`, the most common binops. See #17935 for more details.
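
For illustration, this is roughly the pattern that now errors out (a
minimal sketch; the exact error message is abbreviated in the comment):

    import torch

    # expand() produces stride-0 dimensions, so all nine logical
    # elements of `t` alias a single memory location (internal overlap).
    t = torch.tensor(-0.5).expand(3, 3)
    other = torch.rand(3, 3)

    # Previously this wrote an order-dependent, effectively undefined
    # result; with this change it raises instead:
    t.add_(other)  # RuntimeError mentioning 'single memory location'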
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19317

Differential Revision: D14972399

Pulled By: zou3519

fbshipit-source-id: b9de331dbdb2544ee859ded725a5b5659bfd11d2

aten/src/ATen/native/BinaryOps.cpp
test/test_cuda.py
test/test_torch.py

aten/src/ATen/native/BinaryOps.cpp
index bdd7299..68808e1 100644
@@ -2,6 +2,7 @@
 
 #include <ATen/ATen.h>
 #include <ATen/Dispatch.h>
+#include <ATen/MemoryOverlap.h>
 #include <ATen/NativeFunctions.h>
 #include <ATen/native/TensorIterator.h>
 
@@ -24,6 +25,7 @@ Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   } else if (self.is_sparse()) {
     AT_ERROR("add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
   }
+  at::assert_no_internal_overlap(result, "add");
   auto iter = TensorIterator::binary_op(result, self, other);
   add_stub(iter->device_type(), *iter, alpha);
   return result;
@@ -52,6 +54,7 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) {
     }
     return at::_sparse_div_zerodim_out(result, self, other);
   }
+  at::assert_no_internal_overlap(result, "div");
   auto iter = TensorIterator::binary_op(result, self, other);
   div_stub(iter->device_type(), *iter);
   return result;
@@ -76,6 +79,7 @@ Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) {
   if (self.is_sparse() || other.is_sparse()) {
     return at::_sparse_mul_out(result, self, other);
   }
+  at::assert_no_internal_overlap(result, "mul");
   auto iter = TensorIterator::binary_op(result, self, other);
   mul_stub(iter->device_type(), *iter);
   return result;
@@ -110,6 +114,7 @@ Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar
   } else if (self.is_sparse()) {
     AT_ERROR("sub(sparse, dense) is not supported. Use sub(dense, sparse) instead.");
   }
+  at::assert_no_internal_overlap(result, "sub");
   auto iter = TensorIterator::binary_op(result, self, other);
   sub_stub(iter->device_type(), *iter, alpha);
   return result;
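
The `at::assert_no_internal_overlap(result, <op name>)` calls above
(declared in ATen/MemoryOverlap.h) throw when the output tensor provably
maps several logical elements onto one storage location. A rough Python
rendering of the stride-based test involved (`self_aliases` is a made-up
name, and the real C++ helper additionally has an "unknown" verdict for
layouts it cannot classify cheaply):

    import torch

    def self_aliases(t: torch.Tensor) -> bool:
        # A contiguous tensor maps every index to a distinct element.
        if t.is_contiguous():
            return False
        # A non-contiguous tensor self-aliases whenever some dimension
        # has stride 0 but size > 1 -- e.g. anything made by expand().
        return any(stride == 0 and size > 1
                   for stride, size in zip(t.stride(), t.size()))

    self_aliases(torch.tensor(-0.5).expand(3, 3))  # True
    self_aliases(torch.rand(3, 3))                 # False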
test/test_cuda.py
index 6774b72..f91c3db 100644
@@ -1049,6 +1049,9 @@ class TestCuda(TestCase):
     def test_inplace_unary_mem_overlap(self):
         _TestTorchMixin._test_inplace_unary_mem_overlap(self, device='cuda')
 
+    def test_inplace_binary_mem_overlap(self):
+        _TestTorchMixin._test_inplace_binary_mem_overlap(self, device='cuda')
+
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     def test_arithmetic_large_tensor(self):
         x = torch.empty(2**30, device='cuda')
test/test_torch.py
index bbad85a..cb6db50 100644
@@ -3907,11 +3907,11 @@ class _TestTorchMixin(object):
                     return t0_fn(1.0, t1, t2)
                 else:
                     return t0_fn(t1)
-            r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
-            r2 = tensorfn_inplace(large_expanded_clone, small, small2)
             # in-place pointwise operations don't actually work if the in-place
             # tensor is 0-strided (numpy has the same issue)
             if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
+                r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
+                r2 = tensorfn_inplace(large_expanded_clone, small, small2)
                 self.assertEqual(r1, r2)
 
             def broadcastable(t0, t1, t2=None):
@@ -11126,6 +11126,15 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
             inplace_op(tensor)
 
     @staticmethod
+    def binary_check_mem_overlap(self, inplace_op, value=-0.5, device='cpu'):
+        if isinstance(inplace_op, str):
+            inplace_op = getattr(torch.Tensor, inplace_op)
+        tensor = torch.tensor(value, device=device).expand(3, 3)
+        other = torch.rand_like(tensor)
+        with self.assertRaisesRegex(RuntimeError, 'single memory location'):
+            inplace_op(tensor, other)
+
+    @staticmethod
     def _test_inplace_unary_mem_overlap(self, device='cpu'):
         TestTorch.unary_check_mem_overlap(self, lambda t: t.acos_(), device=device)
         TestTorch.unary_check_mem_overlap(self, lambda t: t.asin_(), device=device)
@@ -11149,9 +11158,18 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         TestTorch.unary_check_mem_overlap(self, lambda t: t.tanh_(), device=device)
         TestTorch.unary_check_mem_overlap(self, lambda t: t.trunc_(), device=device)
 
+    @staticmethod
+    def _test_inplace_binary_mem_overlap(self, device='cpu'):
+        binary_ops = ['add_', 'mul_', 'div_', 'sub_']
+        for op in binary_ops:
+            TestTorch.binary_check_mem_overlap(self, op, device=device)
+
     def test_inplace_unary_mem_overlap(self):
         return self._test_inplace_unary_mem_overlap(self)
 
+    def test_inplace_binary_mem_overlap(self):
+        return self._test_inplace_binary_mem_overlap(self)
+
     @unittest.expectedFailure
     def test_abs_unary_mem_overlap(self):
         self.unary_check_mem_overlap(lambda t: t.abs_())
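
Assuming the usual unittest entry points in these test files, the new
coverage can be run directly:

    python test/test_torch.py TestTorch.test_inplace_binary_mem_overlap
    python test/test_cuda.py TestCuda.test_inplace_binary_mem_overlap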