Make tril_ and triu_ actually in-place (#17031)
authorWill Feng <willfeng@fb.com>
Tue, 19 Feb 2019 22:31:34 +0000 (14:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 19 Feb 2019 22:47:17 +0000 (14:47 -0800)
Summary:
Currently, when the input tensor `self` is not contiguous, `tril_` and `triu_` calls `self = self.contiguous()`, which allocates a new contiguous tensor and assign it to `self`. This effectively changes the input tensor `self`'s pointer and will break downstream code after Variable/Tensor merge.

This PR fixes it so that `tril_` and `triu_` always update the input tensor in-place and preserve the input tensor's TensorImpl.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17031

Differential Revision: D14069592

Pulled By: yf225

fbshipit-source-id: d188218f426446a44ccc1d33fc28ac3f828c6a05

aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
test/test_torch.py
tools/autograd/gen_variable_type.py

index e067878..b373181 100644 (file)
@@ -391,9 +391,9 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
   return result;
 }
 
-template <typename scalar_t, bool inplace, bool upper>
+template <typename scalar_t, bool upper>
 static void apply_triu_tril_single(
-    scalar_t* result, scalar_t* self,
+    scalar_t* result, scalar_t* self, bool inplace,
     int64_t k, int64_t n, int64_t m,
     int64_t res_row_stride, int64_t res_col_stride,
     int64_t self_row_stride, int64_t self_col_stride) {
@@ -428,8 +428,8 @@ static void apply_triu_tril_single(
   }
 }
 
-template <typename scalar_t, bool inplace, bool upper>
-void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
+template <typename scalar_t, bool upper>
+void apply_triu_tril(Tensor& result, const Tensor& self, bool inplace, int64_t k) {
   auto n = self.size(-2);
   auto m = self.size(-1);
   auto self_data = self.data<scalar_t>();
@@ -455,8 +455,8 @@ void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
   for (b = 0; b < batchsize; b++) {
     scalar_t* self_batch = &self_data[b * self_stride];
     scalar_t* result_batch = &result_data[b * result_stride];
-    apply_triu_tril_single<scalar_t, inplace, upper>(
-        result_batch, self_batch, k, n, m,
+    apply_triu_tril_single<scalar_t, upper>(
+        result_batch, self_batch, inplace, k, n, m,
         result_row_stride, result_column_stride, self_row_stride, self_column_stride);
   }
 }
@@ -471,10 +471,13 @@ Tensor& tril_cpu_(Tensor &self, int64_t k) {
   if (self.numel() == 0) {
     return self;
   }
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
   AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
-    apply_triu_tril<scalar_t, true, false>(self, self, k);
+    apply_triu_tril<scalar_t, false>(result, self_c, inplace, k);
   });
+  if (!inplace) self.copy_(result);
   return self;
 }
 
@@ -487,7 +490,7 @@ Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
   AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
-    apply_triu_tril<scalar_t, false, false>(result, self_c, k);
+    apply_triu_tril<scalar_t, false>(result, self_c, false, k);
   });
   return result;
 }
@@ -502,10 +505,13 @@ Tensor& triu_cpu_(Tensor &self, int64_t k) {
   if (self.numel() == 0) {
     return self;
   }
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
   AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
-    apply_triu_tril<scalar_t, true, true>(self, self, k);
+    apply_triu_tril<scalar_t, true>(result, self_c, inplace, k);
   });
+  if (!inplace) self.copy_(result);
   return self;
 }
 
@@ -518,7 +524,7 @@ Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
   }
   Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
   AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
-    apply_triu_tril<scalar_t, false, true>(result, self_c, k);
+    apply_triu_tril<scalar_t, true>(result, self_c, false, k);
   });
   return result;
 }
index b45b4d8..2ac7f5f 100644 (file)
@@ -505,8 +505,12 @@ Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, c
 }
 
 Tensor& tril_cuda_(Tensor &self, int64_t k) {
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
-  return tril_cuda_out(self, self, k);
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
+  tril_cuda_out(result, self_c, k);
+  if (!inplace) self.copy_(result);
+  return self;
 }
 
 Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
@@ -521,8 +525,12 @@ Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
 }
 
 Tensor& triu_cuda_(Tensor &self, int64_t k) {
-  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
-  return triu_cuda_out(self, self, k);
+  bool inplace = checkTrilTriuBatchContiguous(self);
+  Tensor self_c = inplace ? self : self.contiguous();
+  Tensor result = inplace ? self : at::empty_like(self);
+  triu_cuda_out(result, self_c, k);
+  if (!inplace) self.copy_(result);
+  return self;
 }
 
 Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
index 28494ac..cc76a22 100644 (file)
@@ -4149,18 +4149,14 @@ class _TestTorchMixin(object):
                         assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
                         exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
                         self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
+                        x_nc_is_contiguous = x_nc.is_contiguous()
                         if upper:
                             self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
                         else:
                             self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
 
-                        # any 3-dimensional tensor should be fine
-                        if len(shape) <= 3 or s == -2:
-                            self.assertFalse(x_nc.is_contiguous(),
-                                             "x_nc should remain non-contiguous")
-                        elif s < -3:
-                            self.assertTrue(x_nc.is_contiguous(),
-                                            "x_nc should become contiguous")
+                        self.assertTrue(x_nc.is_contiguous() == x_nc_is_contiguous,
+                                        "contiguity of x_nc should not be changed")
 
                     # expanded tensors
                     expanded_size = (x.size(0),) + x.size()
index f5e5cc6..73d5a58 100644 (file)
@@ -136,8 +136,6 @@ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
 DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
     # These functions are expected to change impl or storage of input tensors
     '_th_set_', '_cudnn_rnn_flatten_weight',
-    # TODO: Fix these functions to update input tensor in-place
-    'tril_', 'triu_',
 }
 # END CHECKS FOR [ Invariant: TensorImpl and Storage Pointer Equality ]