#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
+#include <ATen/MemoryOverlap.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
  } else if (self.is_sparse()) {
    AT_ERROR("add(sparse, dense) is not supported. Use add(dense, sparse) instead.");
  }
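+  // Writing in place to a tensor in which several elements share one memory
+  // location (e.g. a view created by expand()) would make the result depend
+  // on the order of the overlapping writes, so reject it up front.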
+  at::assert_no_internal_overlap(result, "add");
  auto iter = TensorIterator::binary_op(result, self, other);
  add_stub(iter->device_type(), *iter, alpha);
  return result;
}
    return at::_sparse_div_zerodim_out(result, self, other);
  }
+  at::assert_no_internal_overlap(result, "div");
  auto iter = TensorIterator::binary_op(result, self, other);
  div_stub(iter->device_type(), *iter);
  return result;
  if (self.is_sparse() || other.is_sparse()) {
    return at::_sparse_mul_out(result, self, other);
  }
+  at::assert_no_internal_overlap(result, "mul");
  auto iter = TensorIterator::binary_op(result, self, other);
  mul_stub(iter->device_type(), *iter);
  return result;
  } else if (self.is_sparse()) {
    AT_ERROR("sub(sparse, dense) is not supported. Use sub(dense, sparse) instead.");
  }
+  at::assert_no_internal_overlap(result, "sub");
  auto iter = TensorIterator::binary_op(result, self, other);
  sub_stub(iter->device_type(), *iter, alpha);
  return result;
    def test_inplace_unary_mem_overlap(self):
        _TestTorchMixin._test_inplace_unary_mem_overlap(self, device='cuda')

+    def test_inplace_binary_mem_overlap(self):
+        _TestTorchMixin._test_inplace_binary_mem_overlap(self, device='cuda')
+
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    def test_arithmetic_large_tensor(self):
        x = torch.empty(2**30, device='cuda')
                return t0_fn(1.0, t1, t2)
            else:
                return t0_fn(t1)
-        r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
-        r2 = tensorfn_inplace(large_expanded_clone, small, small2)
        # in-place pointwise operations don't actually work if the in-place
        # tensor is 0-strided (numpy has the same issue)
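+        # (e.g. torch.tensor(1.).expand(3) has stride (0,): every logical
+        # element maps to the same storage element)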
        if (0 not in large_expanded.stride() and 0 not in large_expanded_clone.stride()):
+            r1 = tensorfn_inplace(large_expanded, small_expanded, small2_expanded)
+            r2 = tensorfn_inplace(large_expanded_clone, small, small2)
            self.assertEqual(r1, r2)
        def broadcastable(t0, t1, t2=None):
            inplace_op(tensor)
    @staticmethod
+    def binary_check_mem_overlap(self, inplace_op, value=-0.5, device='cpu'):
+        if isinstance(inplace_op, str):
+            inplace_op = getattr(torch.Tensor, inplace_op)
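+        # expand() a 0-dim tensor to 3x3: all nine elements alias a single
+        # storage location, so writing to it in place must raise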
+        tensor = torch.tensor(value, device=device).expand(3, 3)
+        other = torch.rand_like(tensor)
+        with self.assertRaisesRegex(RuntimeError, 'single memory location'):
+            inplace_op(tensor, other)
+
+    @staticmethod
    def _test_inplace_unary_mem_overlap(self, device='cpu'):
        TestTorch.unary_check_mem_overlap(self, lambda t: t.acos_(), device=device)
        TestTorch.unary_check_mem_overlap(self, lambda t: t.asin_(), device=device)
        TestTorch.unary_check_mem_overlap(self, lambda t: t.tanh_(), device=device)
        TestTorch.unary_check_mem_overlap(self, lambda t: t.trunc_(), device=device)
+    @staticmethod
+    def _test_inplace_binary_mem_overlap(self, device='cpu'):
+        binary_ops = ['add_', 'mul_', 'div_', 'sub_']
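+        # each name is resolved to the corresponding torch.Tensor method
+        # inside binary_check_mem_overlap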
+        for op in binary_ops:
+            TestTorch.binary_check_mem_overlap(self, op, device=device)
+
    def test_inplace_unary_mem_overlap(self):
        return self._test_inplace_unary_mem_overlap(self)

+    def test_inplace_binary_mem_overlap(self):
+        return self._test_inplace_binary_mem_overlap(self)
+
    @unittest.expectedFailure
    def test_abs_unary_mem_overlap(self):
        self.unary_check_mem_overlap(lambda t: t.abs_())