Error out on in-place (unary) ops on tensors that have internal overlap (#17927)
author Richard Zou <rzou@fb.com>
Fri, 15 Mar 2019 14:41:08 +0000 (07:41 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Mar 2019 14:50:19 +0000 (07:50 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17927
ghimport-source-id: 626d321e430b6b5c0ea3aa1eb9df8c1e2d058bf8

Stack:
* #17926 Implement at::has_internal_overlap helper function
* **#17927 Error out on in-place (unary) ops on tensors that have internal overlap**

On the way to #17935.

Works for CPU and CUDA on the following ops:
- abs_, acos_, asin_, atan_, ceil_, cos_, erf_, erfc_, exp_, expm1_
- floor_, log_, log10_, log1p_, log2_, round_, rsqrt_
- sin_, sqrt_, tan_, tanh_, trunc_

This PR adds a check for whether the out/result tensor has internal
overlap, i.e., whether more than one of its indices refers to the same
memory location. If it does, we error out, because the result **may** be
incorrect.
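
For illustration, a minimal sketch of the new behavior (mirroring the
added tests below; sin_ stands in for any of the ops listed above):

    import torch

    # expand() produces a view in which all nine elements alias a single
    # storage location, i.e. a tensor with internal overlap.
    t = torch.tensor(-0.5).expand(3, 3)
    t.sin_()  # now raises RuntimeError mentioning 'single memory location'

    # clone() materializes each element separately, so in-place ops are fine.
    t.clone().sin_()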

This is overly conservative: some of these ops are idempotent, so
applying them in-place to an overlapping tensor still yields the correct
result (floor_, round_, and trunc_, for example). However, the current
code isn't organized in a way that makes this easy to check, so enabling
those cases is left for future work.
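
For example, floor_ is idempotent: floor(floor(x)) == floor(x), so even
if an aliased element is read and written more than once, the stored
value is still correct. A small sketch of that property:

    import torch

    x = torch.tensor([-0.5, 1.7, 2.0])
    # Applying floor twice is the same as applying it once, which is why
    # floor_ on an overlapping tensor cannot produce a wrong value.
    assert torch.equal(x.floor().floor(), x.floor())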

Reviewed By: ezyang

Differential Revision: D14438871

fbshipit-source-id: 15e12bf1fdb2ab7f74bb806e22bc74840bd6abd1

aten/src/ATen/MemoryOverlap.cpp
aten/src/ATen/MemoryOverlap.h
aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
aten/src/THC/generic/THCTensorMathPointwise.cu
test/test_cuda.py
test/test_torch.py

diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp
index 5080221..98aeac3 100644
@@ -4,17 +4,19 @@
 namespace at {
 
 MemOverlap has_internal_overlap(const Tensor& tensor) {
-  auto* t = tensor.unsafeGetTensorImpl();
+  return has_internal_overlap(tensor.unsafeGetTensorImpl());
+}
 
-  AT_ASSERT(tensor.layout() == kStrided);
+MemOverlap has_internal_overlap(TensorImpl* t) {
+  AT_ASSERT(t->layout() == kStrided);
 
   if (t->is_contiguous()) {
     return MemOverlap::NO;
   }
 
   auto strides = t->strides();
-  if (std::find_if(
-        strides.begin(), strides.end(), [](int s) { return s == 0; })) {
+  if (strides.end() != std::find_if(
+        strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
     return MemOverlap::YES;
   }
 
@@ -22,6 +24,10 @@ MemOverlap has_internal_overlap(const Tensor& tensor) {
 }
 
 void assert_no_internal_overlap(const Tensor& t, std::string op) {
+  assert_no_internal_overlap(t.unsafeGetTensorImpl(), op);
+}
+
+void assert_no_internal_overlap(TensorImpl* t, std::string op) {
   if (has_internal_overlap(t) == MemOverlap::YES) {
     AT_ERROR(
         op, ": unsupported operation: more than one element of the written-to "
diff --git a/aten/src/ATen/MemoryOverlap.h b/aten/src/ATen/MemoryOverlap.h
index bb02275..e2f6013 100644
@@ -13,8 +13,10 @@ namespace at {
 // NB: Please update the python test for these if you renumber them.
 enum class MemOverlap { NO, YES, TOO_HARD };
 
-MemOverlap has_internal_overlap(const Tensor& t);
+CAFFE2_API MemOverlap has_internal_overlap(const Tensor& t);
+CAFFE2_API MemOverlap has_internal_overlap(TensorImpl* t);
 
-void assert_no_internal_overlap(const Tensor& t, std::string op);
+CAFFE2_API void assert_no_internal_overlap(const Tensor& t, std::string op);
+CAFFE2_API void assert_no_internal_overlap(TensorImpl* t, std::string op);
 
 }
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index aaa566c..f53733b 100644
@@ -7,6 +7,7 @@
 #include <ATen/CPUGenerator.h>
 #include <ATen/CheckGenerator.h>
 #include <ATen/Generator.h>
+#include <ATen/MemoryOverlap.h>
 #include <ATen/cpu/vml.h>
 #include <ATen/CPUApplyUtils.h>
 #include <ATen/native/DispatchStub.h>
@@ -183,6 +184,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
             result.data<scalar_t>(), self.data<scalar_t>(), self.numel()); \
                                                                            \
       } else {                                                             \
+        assert_no_internal_overlap(result, #op);                           \
         static constexpr int64_t WIDTH = 131072 / sizeof(scalar_t);        \
         CPU_tensor_parallel_kernel_apply2<scalar_t, scalar_t>(             \
             result,                                                        \
@@ -211,7 +213,6 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
     });                                                                    \
   }                                                                        \
   REGISTER_DISPATCH(op##Impl, &op##_kernel)
-
 } // anonymous namespace
 
 REGISTER_DISPATCH(sigmoidImpl, &sigmoid_kernel)
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu
index 517633b..0c128a5 100644
@@ -2,6 +2,8 @@
 #define THC_GENERIC_FILE "THC/generic/THCTensorMathPointwise.cu"
 #else
 
+#include <ATen/MemoryOverlap.h>
+
 #define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_(NAME, CFUNC, REAL)             \
   struct Tensor_##NAME##_##REAL##_Op {                                  \
     __device__ __forceinline__ void operator()(scalar_t* out, scalar_t* in) const { \
@@ -15,6 +17,7 @@
                                                                         \
   void THCTensor_(NAME)(THCState* state, THCTensor* self_, THCTensor* src) { \
     THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));               \
+    at::assert_no_internal_overlap(self_, #NAME);                       \
     if (self_ == src) {                                                 \
       if (!THC_pointwiseApply1<scalar_t>(state, self_, Tensor_##NAME##_##REAL##_Op())) { \
         THArgCheck(false, 2, CUTORCH_DIM_WARNING);                      \
diff --git a/test/test_cuda.py b/test/test_cuda.py
index d1462de..98e1fe2 100644
@@ -1049,6 +1049,9 @@ class TestCuda(TestCase):
     def test_isinf(self):
         _TestTorchMixin._test_isinf(self, lambda t: t.cuda())
 
+    def test_inplace_unary_mem_overlap(self):
+        _TestTorchMixin._test_inplace_unary_mem_overlap(self, device='cuda')
+
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     def test_arithmetic_large_tensor(self):
         x = torch.empty(2**30, device='cuda')
diff --git a/test/test_torch.py b/test/test_torch.py
index afc6309..63b4c07 100644
@@ -10401,6 +10401,51 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         b_expanded = b.expand(4, 3)
         self.assertEqual(torch._debug_has_internal_overlap(b_expanded), OVERLAP_YES)
 
+    @staticmethod
+    def unary_check_mem_overlap(self, inplace_op, value=-0.5, device='cpu'):
+        tensor = torch.tensor(value, device=device).expand(3, 3)
+        with self.assertRaisesRegex(RuntimeError, 'single memory location'):
+            inplace_op(tensor)
+
+    @staticmethod
+    def _test_inplace_unary_mem_overlap(self, device='cpu'):
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.acos_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.asin_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.atan_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.ceil_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.cos_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.erf_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.erfc_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.exp_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.expm1_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.floor_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.log_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.log10_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.log1p_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.log2_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.round_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.rsqrt_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.sin_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.sqrt_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.tan_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.tanh_(), device=device)
+        TestTorch.unary_check_mem_overlap(self, lambda t: t.trunc_(), device=device)
+
+    def test_inplace_unary_mem_overlap(self):
+        return self._test_inplace_unary_mem_overlap(self)
+
+    @unittest.expectedFailure
+    def test_abs_unary_mem_overlap(self):
+        self.unary_check_mem_overlap(lambda t: t.abs_())
+
+    @unittest.expectedFailure
+    def test_sinh_unary_mem_overlap(self):
+        self.unary_check_mem_overlap(lambda t: t.sinh_())
+
+    @unittest.expectedFailure
+    def test_cosh_unary_mem_overlap(self):
+        self.unary_check_mem_overlap(lambda t: t.cosh_())
+
     @unittest.skipIf(torch.cuda.device_count() < 2, 'only one GPU detected')
     def test_reverse_binary_ops_multiple_device(self):
         self.assertEqual(2 + torch.tensor(3), 2 + torch.tensor(3).to("cuda:1"))    # __radd__