Move autograd metadata from VariableImpl to TensorImpl (#13827)
authorWill Feng <willfeng@fb.com>
Thu, 27 Dec 2018 00:31:47 +0000 (16:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 27 Dec 2018 00:34:24 +0000 (16:34 -0800)
Summary:
Changes originally in this PR:
1. Move Variable::Impl data members into TensorImpl as `AutogradMeta` struct
2. Change Variable::Impl functions to use data members in `AutogradMeta` struct
3. Add `shallow_copy_and_detach()` function to each subclass of TensorImpl
4. Do shallow copy when the user calls `make_variable(tensor)` / `make_variable_view(tensor)` / `variable.set_data(tensor)` / `variable.detach()`

Changes moved from https://github.com/pytorch/pytorch/pull/13645:
1. Add a flag to Variable to disallow size/stride/storage_ptr changes from in-place operations such as `resize_` / `resize_as_` / `set_` / `transpose_`, and set this flag to true when people call `tensor.data` in Python.
2. Write text in the docs to actively discourage changing the shape or storage of `tensor_detached` and expecting `tensor` to also be updated.

This is the 1st+2nd PR mentioned in https://github.com/pytorch/pytorch/issues/13638.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13827

Differential Revision: D13507173

Pulled By: yf225

fbshipit-source-id: b177b08438d534a8197e34e1ad4a837e2db0ed6a

18 files changed:
aten/src/ATen/SparseTensorImpl.cpp
aten/src/ATen/SparseTensorImpl.h
aten/src/TH/THTensor.cpp
aten/src/TH/THTensor.hpp
c10/core/TensorImpl.cpp
c10/core/TensorImpl.h
test/common_methods_invocations.py
test/common_utils.py
test/test_jit.py
test/test_nn.py
test/test_sparse.py
test/test_torch.py
torch/csrc/autograd/functions/accumulate_grad.cpp
torch/csrc/autograd/python_variable.cpp
torch/csrc/autograd/variable.cpp
torch/csrc/autograd/variable.h
torch/csrc/jit/python_ir.cpp
torch/tensor.py

index 0cc5456..779ac5d 100644 (file)
@@ -79,6 +79,7 @@ int64_t SparseTensorImpl::storage_offset() const {
   AT_ERROR("sparse tensors do not have storage");
 }
 void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) {
+  AT_CHECK(allow_tensor_metadata_change(), "set_indices_and_values_unsafe is not allowed on Tensor created from .data or .detach()");
   AT_ASSERT(!indices.is_variable() && !values.is_variable());  // They should be plain tensors!
 
   AT_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
index 3ec436a..84ca155 100644 (file)
@@ -57,6 +57,7 @@ public:
   // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim with
   // respect to indices and values
   void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+    AT_CHECK(allow_tensor_metadata_change(), "raw_resize_ is not allowed on Tensor created from .data or .detach()");
     sizes_ = size.vec();
     sparse_dim_ = sparse_dim;
     dense_dim_ = dense_dim;
@@ -86,6 +87,7 @@ public:
   // 4. When we attempt to shrink the size of any of the sparse dimensions on a non-empty sparse tensor
   // (this could make some of the stored indices out-of-bound and thus unsafe).
   void resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+    AT_CHECK(allow_tensor_metadata_change(), "resize_ is not allowed on Tensor created from .data or .detach()");
     AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
     if (nnz() > 0) {
       auto alt_options_msg = "You could try the following options:\n\
@@ -143,6 +145,7 @@ public:
 
   // NOTE: this function will resize the sparse tensor and also set `indices` and `values` to empty.
   void resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+    AT_CHECK(allow_tensor_metadata_change(), "resize_and_clear_ is not allowed on Tensor created from .data or .detach()");
     AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
 
     sizes_ = size.vec();
@@ -158,10 +161,14 @@ public:
     refresh_numel();
   }
 
-  void set_coalesced(bool coalesced) { coalesced_ = coalesced; }
+  void set_coalesced(bool coalesced) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_coalesced is not allowed on Tensor created from .data or .detach()");
+    coalesced_ = coalesced;
+  }
 
   // NOTE: this function is only used internally and not exposed to Python frontend
   void set_nnz_and_narrow(int64_t new_nnz) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_nnz_and_narrow is not allowed on Tensor created from .data or .detach()");
     AT_ASSERT(new_nnz <= nnz());
     indices_ = indices_.narrow(1, 0, new_nnz);
     values_ = values_.narrow(0, 0, new_nnz);
@@ -176,6 +183,32 @@ public:
   // make it happen
   void set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values);
 
+  // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
+  // because it is unique for each Variable.
+  // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because there are call sites
+  // to this function that need to change the shallow copy's size or storage afterwards, and setting
+  // `allow_tensor_metadata_change_` to false would prevent those changes from happening and is
+  // undesirable.
+  c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const override {
+    auto impl = c10::make_intrusive<SparseTensorImpl>(type_id(), dtype());
+    // TensorImpl general fields
+    // Note that these fields are not used in sparse tensor code, and we copy them here only for completeness.
+    impl->sizes_ = sizes_;
+    impl->strides_ = strides_;
+    impl->storage_offset_ = storage_offset_;
+    impl->is_contiguous_ = is_contiguous_;
+    impl->is_wrapped_number_ = is_wrapped_number_;
+    impl->reserved_ = reserved_;
+
+    // Sparse-specific fields
+    impl->sparse_dim_ = sparse_dim();
+    impl->dense_dim_ = dense_dim();
+    impl->indices_ = indices();
+    impl->values_ = values();
+    impl->coalesced_ = coalesced();
+    impl->refresh_numel();
+    return impl;
+  }
  private:
   int64_t get_device_slow() const override {
     return values_.get_device();
index 5cf3e8c..2b51a37 100644 (file)
@@ -157,5 +157,5 @@ void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage) {
   // Caffe2 might have tensors whose storages are null, but we
   // don't allow it in PyTorch.
   AT_ASSERT(storage);
-  tensor->storage_ = at::Storage(c10::intrusive_ptr<THStorage>::reclaim(storage));
+  tensor->set_storage(at::Storage(c10::intrusive_ptr<THStorage>::reclaim(storage)));
 }
index 4001bc4..dc22739 100644 (file)
@@ -34,10 +34,10 @@ inline THStorage* THTensor_getStoragePtr(const THTensor* tensor) {
   // for the first time (providing the necessary type).  It is an ERROR to
   // invoke any PyTorch operations on such a half-constructed storage,
   // and this check tests for that case.
-  AT_CHECK(tensor->storage_, "Cannot use PyTorch operations on a half-constructed "
+  AT_CHECK(tensor->storage(), "Cannot use PyTorch operations on a half-constructed "
            "tensor.  If this tensor came from Caffe2, please call GetMutableData on "
            "it first; otherwise, this is a bug, please report it.");
-  return tensor->storage_.unsafeGetStorageImpl();
+  return tensor->storage().unsafeGetStorageImpl();
 }
 
 inline void THTensor_maybe_zero_dim(THTensor *tensor, bool condition_when_zero_dim) {
index e417bef..f6cec73 100644 (file)
@@ -122,4 +122,6 @@ at::DataPtr PlacementDeleteContext::makeDataPtr(
           device};
 }
 
+AutogradMetaInterface::~AutogradMetaInterface() {}
+
 } // namespace c10
index 8c4be48..d5f248e 100644 (file)
@@ -127,6 +127,10 @@ struct C10_API PlacementDeleteContext {
   }
 };
 
+struct C10_API AutogradMetaInterface {
+  virtual ~AutogradMetaInterface();
+};
+
 /**
  * The low-level representation of a tensor, which contains a pointer
  * to a storage (which contains the actual data) and metadata (e.g., sizes and
@@ -646,6 +650,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * which is harder to misuse.
    */
   virtual void resize_dim(int64_t ndim) {
+    AT_CHECK(allow_tensor_metadata_change(), "resize_dim is not allowed on Tensor created from .data or .detach()");
     sizes_.resize(ndim, 0);
     strides_.resize(ndim, 0);
     refresh_numel();
@@ -661,6 +666,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * which is harder to misuse.
    */
   virtual void set_size(int64_t dim, int64_t new_size) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_size is not allowed on Tensor created from .data or .detach()");
     sizes_.at(dim) = new_size;
     refresh_numel();
     refresh_contiguous();
@@ -673,6 +679,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * which is harder to misuse.
    */
   virtual void set_stride(int64_t dim, int64_t new_stride) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_stride is not allowed on Tensor created from .data or .detach()");
     strides_[dim] = new_stride;
     refresh_numel();
     refresh_contiguous();
@@ -686,6 +693,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * (and resizing if necessary.)
    */
   virtual void set_storage_offset(int64_t storage_offset) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_storage_offset is not allowed on Tensor created from .data or .detach()");
     storage_offset_ = storage_offset;
   }
 
@@ -700,6 +708,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * See Note [We regret making Variable hold a Tensor]
    */
   void set_sizes_contiguous(IntList new_size) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()");
     AT_ASSERT(!is_variable());
     auto old_dim = sizes_.size();
     auto new_dim = new_size.size();
@@ -724,6 +733,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * See Note [We regret making Variable hold a Tensor]
    */
   void set_sizes_and_strides(IntList new_size, IntList new_stride) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides is not allowed on Tensor created from .data or .detach()");
     AT_ASSERT(!is_variable());
     AT_CHECK(
         new_size.size() == new_stride.size(),
@@ -778,6 +788,58 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    */
   bool is_variable() const { return is_variable_; };
 
+  /**
+   * Set whether a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
+   */
+  virtual void set_allow_tensor_metadata_change(bool value) {
+    allow_tensor_metadata_change_ = value;
+  }
+
+  /**
+   * True if a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
+   */
+  virtual bool allow_tensor_metadata_change() const {
+    return allow_tensor_metadata_change_;
+  }
+
+  /**
+   * Set the pointer to autograd metadata.
+   */
+  void set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
+    autograd_meta_ = std::move(autograd_meta);
+  }
+
+  /**
+   * Return the pointer to autograd metadata.
+   */
+  c10::AutogradMetaInterface* autograd_meta() const {
+    return autograd_meta_.get();
+  }
+
+  /**
+   * Detach the autograd metadata unique_ptr from this tensor, and return it.
+   */
+  std::unique_ptr<c10::AutogradMetaInterface> detach_autograd_meta() {
+    return std::move(autograd_meta_);
+  }
+
+  // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
+  // because it is unique for each Variable.
+  // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because there are call sites
+  // to this function that need to change the shallow copy's size or storage afterwards, and setting
+  // `allow_tensor_metadata_change_` to false would prevent those changes from happening and is
+  // undesirable.
+  virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const {
+    auto impl = c10::make_intrusive<TensorImpl>(Storage(storage()), type_id(), is_variable());
+    impl->set_sizes_and_strides(sizes(), strides());
+    impl->storage_offset_ = storage_offset_;
+    impl->is_wrapped_number_ = is_wrapped_number_;
+    impl->reserved_ = reserved_;
+    impl->refresh_numel();
+    impl->refresh_contiguous();
+    return impl;
+  }
+
  private:
   // As an optimization, get_device handles the typical CUDA Tensor case and
   // calls get_device_slow if the tensor stores its device somewhere else
@@ -1157,6 +1219,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   }
 
   void set_storage(at::Storage storage) {
+    AT_CHECK(allow_tensor_metadata_change(), "set_storage is not allowed on Tensor created from .data or .detach()");
     storage_ = std::move(storage);
     data_type_ = storage_.dtype();
   }
@@ -1268,10 +1331,14 @@ protected:
     is_contiguous_ = compute_contiguous();
   }
 
-public:
-  Storage storage_; // TODO: Fix visibility on me
-
 protected:
+  Storage storage_;
+  // This pointer points to an AutogradMeta struct that stores autograd-specific fields
+  // (such as grad_ / grad_fn_ / grad_accumulator_).
+  // This pointer always has unique ownership (meaning only one TensorImpl can own it
+  // at a time).
+  std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
+
   // We could save a word or two by combining the SmallVector structs,
   // since their size is redundant, and if we need to overflow the buffer space
   // we could keep the two pointers together. However, that would require
@@ -1296,6 +1363,18 @@ protected:
   bool is_contiguous_ = true;
   bool is_variable_ = false;
   bool is_wrapped_number_ = false;
+
+  // Previously, if we change the tensor metadata (e.g. sizes / strides / storage / storage_offset)
+  // of a derived tensor (i.e. tensors created from Python `tensor.data` or Python/C++ `tensor.detach()`),
+  // those metadata in the original tensor will also be updated. However, the new behavior is that
+  // those metadata changes to a derived tensor will not update the original tensor anymore, and we
+  // need this flag to make such changes explicitly illegal, to prevent users from changing metadata of
+  // the derived tensor and expecting the original tensor to also be updated.
+  //
+  // NOTE: For a full list of tensor metadata fields, please see `shallow_copy_and_detach()` in TensorImpl
+  // and its subclasses to find which fields are copied by value.
+  bool allow_tensor_metadata_change_ = true;
+
   // we decide to keep reserved_ and it will
   // live in Tensor after the split
   // The logic is that if Extend() or ReserveSpace() were ever called,
@@ -1352,10 +1431,11 @@ protected:
 //    storage offset
 //    numel
 //    data type pointer
+//    autograd metadata pointer
 //    miscellaneous bitfield
 //
 static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
-              sizeof(TensorImpl) == sizeof(int64_t) * 24,
+              sizeof(TensorImpl) == sizeof(int64_t) * 25,
               "You changed the size of TensorImpl on 64-bit arch."
               "See Note [TensorImpl size constraints] on how to proceed.");
 
index e442df8..8d2b790 100644 (file)
@@ -821,7 +821,8 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
         elif isinstance(arg, torch.Tensor):
             if arg.dtype == torch.float:
                 arg = arg.double()
-            v = maybe_non_contig(arg).detach()
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
+            v = maybe_non_contig(arg).detach().clone()
             v.requires_grad = requires_grad and v.is_floating_point()
             return v
         elif callable(arg):
index 64c0248..940e55d 100644 (file)
@@ -327,7 +327,8 @@ class TestCase(expecttest.TestCase):
             #        needed for inplace operations done on `x`, e.g., copy_().
             #        Remove after implementing something equivalent to CopySlice
             #        for sparse views.
-            x = x.detach()
+            # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards
+            x = x.detach().clone()
         return x, x._indices().clone(), x._values().clone()
 
     def safeToDense(self, t):
index 4104249..9ea5d14 100644 (file)
@@ -4894,7 +4894,7 @@ a")
         def __init__(self):
             super(TestScript.DerivedStateModule, self).__init__()
             self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
-            self.register_buffer('derived', torch.neg(self.param).detach())
+            self.register_buffer('derived', torch.neg(self.param).detach().clone())
 
             # This is a flag so we can test that the pack method was called
             self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
index b08930c..c9fc07f 100644 (file)
@@ -537,7 +537,8 @@ class TestNN(NNTestCase):
     def _zero_grad_parameters(self, module):
         for p in module.parameters():
             if p.grad is not None:
-                p.grad.data.zero_()
+                with torch.no_grad():
+                    p.grad.zero_()
                 p.grad.detach_()
 
     def _get_parameters(self, module):
@@ -4519,7 +4520,8 @@ class TestNN(NNTestCase):
             # Weights will no longer view onto the same chunk of memory
             weight = all_vars[4]
             weight_data = weight.data.clone()
-            weight.data.set_(weight_data)
+            with torch.no_grad():
+                weight.set_(weight_data)
 
             for i in range(2):
                 with warnings.catch_warnings(record=True) as w:
index ee014d7..44d8790 100644 (file)
@@ -1829,6 +1829,41 @@ class TestSparse(TestCase):
         with self.assertRaisesRegex(RuntimeError, "bool value of Tensor with no values is ambiguous"):
             torch.sparse_coo_tensor(([0, 1],), self.ValueTensor(2, 0), (4, 0)).is_nonzero()
 
+    def test_allow_tensor_metadata_change(self):
+        def do_test(t):
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "raw_resize_ is not allowed on Tensor created from .data or .detach()"):
+                t.transpose_(0, 1)
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "resize_ is not allowed on Tensor created from .data or .detach()"):
+                t.resize_as_(self.SparseTensor(3, 3))
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "resize_and_clear_ is not allowed on Tensor created from .data or .detach()"):
+                t.mul_(t)
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "set_coalesced is not allowed on Tensor created from .data or .detach()"):
+                t._coalesced_(True)
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "set_indices_and_values_unsafe is not allowed on Tensor created from .data or .detach()"):
+                a = self.SparseTensor(torch.tensor([[0, 1, 1], [2, 0, 2]]), torch.tensor([3., 4., 5.])).data
+                a.add_(a)
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "resize_and_clear_ is not allowed on Tensor created from .data or .detach()"):
+                a.zero_()
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "resize_ is not allowed on Tensor created from .data or .detach()"):
+                a.copy_(self.SparseTensor(3, 3))
+
+        do_test(self.SparseTensor(3, 0).data)
+        do_test(self.SparseTensor(3, 0).detach())
+
 
 class TestUncoalescedSparse(TestSparse):
     def setUp(self):
index 5a9c104..8762047 100644 (file)
@@ -9524,6 +9524,24 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         with self.assertRaisesRegex(RuntimeError, "expected both inputs to be on same device"):
             torch.tensor(2).to("cuda:1") // torch.tensor(3).to("cuda:0")
 
+    def test_allow_tensor_metadata_change(self):
+        def do_test(t):
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()"):
+                t.resize_((2, 1))
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "set_storage is not allowed on Tensor created from .data or .detach()"):
+                t.set_()
+            with self.assertRaisesRegex(
+                    RuntimeError,
+                    "set_storage_offset is not allowed on Tensor created from .data or .detach()"):
+                t.set_(t.storage(), 0, t.size(), list(t.stride()))
+
+        do_test(torch.tensor([[1, 2]]).data)
+        do_test(torch.tensor([[1, 2]]).detach())
+
 # Functions to test negative dimension wrapping
 METHOD = 1
 INPLACE_METHOD = 2
index b12076f..85a013f 100644 (file)
@@ -64,7 +64,7 @@ auto AccumulateGrad::apply(variable_list&& grads) -> variable_list {
     // a thing never promised and documented, but used in some hacks seen
     // on the internet.
     if (grad_variable.is_sparse() && !new_grad.is_sparse()) {
-      grad_variable.data() = new_grad.data() + grad_variable.data();
+      grad_variable.set_data(new_grad.data() + grad_variable.data());
     } else {
       grad_variable.data() += new_grad.data();
     }
index 5d5ea8a..f747c6f 100644 (file)
@@ -200,7 +200,16 @@ static PyObject *THPVariable_is_leaf(THPVariable *self)
 static PyObject * THPVariable_get_data(THPVariable *self)
 {
   HANDLE_TH_ERRORS
-  return THPVariable_Wrap(make_variable(self->cdata.data(), false));
+  /// NOTE: Previously, if we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a tensor created from `.data`, those metadata
+  /// in the original tensor will also be updated. However, the new behavior is that
+  /// those metadata changes to the `.data` tensor will not update the original tensor
+  /// anymore, and here we need to set `allow_tensor_metadata_change_` to false to
+  /// make such changes explicitly illegal, in order to prevent users from changing
+  /// metadata of the `.data` tensor and expecting the original tensor to also
+  /// be updated.
+  auto var = make_variable(self->cdata.data(), /*requires_grad=*/false, /*allow_tensor_metadata_change=*/false);
+  return THPVariable_Wrap(var);
   END_HANDLE_TH_ERRORS
 }
 
index 7b133b9..baff363 100644 (file)
 
 namespace torch {
 namespace autograd {
-Variable::Impl::Impl(at::Tensor data, bool requires_grad, Edge gradient_edge)
+Variable::Impl::Impl(at::Tensor data, std::unique_ptr<Variable::AutogradMeta> autograd_meta, bool requires_grad, Edge gradient_edge)
     : TensorImpl(data.type_id(), data.dtype(), /*allocator=*/nullptr, /* is variable */ true),
-      data_(std::move(data)),
-      grad_fn_(std::move(gradient_edge.function)),
-      requires_grad_(false),
-      is_view_(false),
-      output_nr_(gradient_edge.input_nr),
-      pyobj_(nullptr) {
+      data_(std::move(data)) {
+  autograd_meta->grad_fn_ = std::move(gradient_edge.function);
+  autograd_meta->requires_grad_ = false;
+  autograd_meta->is_view_ = false;
+  autograd_meta->output_nr_ = gradient_edge.input_nr;
+  autograd_meta->pyobj_ = nullptr;
+  set_autograd_meta(std::move(autograd_meta));
+
   // set_requires_grad also checks error conditions.
   set_requires_grad(requires_grad);
   AT_CHECK(
-      !grad_fn_ || !requires_grad_,
+      !get_autograd_meta()->grad_fn_ || !get_autograd_meta()->requires_grad_,
       "requires_grad should be false if grad_fn is set");
   if (!data_.defined()) {
     throw std::runtime_error("data is undefined");
@@ -102,42 +104,45 @@ int64_t Variable::Impl::get_device_slow() const {
 }
 
 std::shared_ptr<Function> Variable::Impl::get_grad_accumulator() {
-  if (grad_fn_) {
+  auto autograd_meta = get_autograd_meta();
+  if (autograd_meta->grad_fn_) {
     throw std::logic_error(
         "get_grad_accumulator() should be only called on leaf Variables");
   }
-  if (!requires_grad_) {
+  if (!autograd_meta->requires_grad_) {
     return nullptr;
   }
 
-  std::lock_guard<std::mutex> lock(mutex_);
+  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
 
-  auto result = grad_accumulator_.lock();
+  auto result = autograd_meta->grad_accumulator_.lock();
   if (result)
     return result;
 
   c10::raw::intrusive_ptr::incref(this);
   auto intrusive_from_this = c10::intrusive_ptr<Variable::Impl>::reclaim(this);
   result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this)));
-  grad_accumulator_ = result;
+  autograd_meta->grad_accumulator_ = result;
   return result;
 }
 
 void Variable::Impl::detach_() {
-  if (is_view_) {
+  auto autograd_meta = get_autograd_meta();
+  if (autograd_meta->is_view_) {
     AT_ERROR("Can't detach views in-place. Use detach() instead");
   }
   set_requires_grad(false);
-  grad_fn_.reset();
-  output_nr_ = 0;
+  autograd_meta->grad_fn_.reset();
+  autograd_meta->output_nr_ = 0;
 }
 
 void Variable::Impl::backward(
     c10::optional<Tensor> gradient,
     bool keep_graph,
     bool create_graph) {
+  auto autograd_meta = get_autograd_meta();
   std::vector<Edge> edges;
-  edges.emplace_back(grad_fn_, output_nr_);
+  edges.emplace_back(autograd_meta->grad_fn_, autograd_meta->output_nr_);
 
   std::vector<Variable> inputs;
   if (!gradient.has_value()) {
@@ -149,14 +154,15 @@ void Variable::Impl::backward(
 
 void Variable::Impl::set_data(Tensor new_data) {
   // Resets gradient accumulator if metadata is out of date
-  std::lock_guard<std::mutex> lock(mutex_);
-  auto prior_accumulator = grad_accumulator_.lock();
+  auto autograd_meta = get_autograd_meta();
+  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
+  auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
   if (prior_accumulator) {
     const auto prior_device = prior_accumulator->input_metadata(0).device();
     const auto new_device = new_data.is_cuda() ? new_data.get_device() : -1;
 
     if (new_data.type() != data_.type() || prior_device != new_device) {
-      grad_accumulator_.reset();
+      autograd_meta->grad_accumulator_.reset();
     }
   }
 
@@ -164,68 +170,72 @@ void Variable::Impl::set_data(Tensor new_data) {
   data_type_ = new_data.type().typeMeta();
   type_id_ = new_data.type().type_id();
   is_variable_ = true;
-  data_ = std::move(new_data);
+
+  auto new_data_copy = at::Tensor(new_data.getIntrusivePtr()->shallow_copy_and_detach());
+  data_ = std::move(new_data_copy);
 }
 
 void Variable::Impl::release_resources() {
+  autograd_meta_.reset();
   data_.reset();
-  grad_.reset();
-  grad_fn_.reset();
-  hooks_.clear();
 }
 
-Variable::DifferentiableViewImpl::DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
-    : Variable::Impl(std::move(data), false, std::move(gradient_edge)),
-      base_(std::move(base)) {
-  AT_CHECK(base_.defined(), "base is undefined");
-  if (base_.is_view()) {
-    base_ = base_.base();
+Variable::DifferentiableViewImpl::DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge, std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta)
+    : Variable::Impl(std::move(data), std::move(autograd_meta), false, std::move(gradient_edge)) {
+  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
+  diff_view_meta->base_ = std::move(base);
+  AT_CHECK(diff_view_meta->base_.defined(), "base is undefined");
+  if (diff_view_meta->base_.is_view()) {
+    diff_view_meta->base_ = diff_view_meta->base_.base();
   }
-  is_view_ = true;
-  version_counter_ = base_.version_counter();
-  attr_version = version_counter_.current_version();
+  diff_view_meta->is_view_ = true;
+  diff_view_meta->version_counter_ = diff_view_meta->base_.version_counter();
+  diff_view_meta->attr_version = diff_view_meta->version_counter_.current_version();
 }
 
 std::shared_ptr<Function>& Variable::DifferentiableViewImpl::get_grad_fn() {
-  std::lock_guard<std::mutex> lock(mutex_);
-  if (!grad_fn_ && !base_.requires_grad()) {
-    return grad_fn_;
+  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
+  std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
+  if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) {
+    return diff_view_meta->grad_fn_;
   }
-  auto current_version = version_counter_.current_version();
-  if (attr_version != current_version) {
-    AT_ASSERT(output_nr_ == 0);
+  auto current_version = diff_view_meta->version_counter_.current_version();
+  if (diff_view_meta->attr_version != current_version) {
+    AT_ASSERT(diff_view_meta->output_nr_ == 0);
     auto fn = std::make_shared<generated::AsStridedBackward>();
-    fn->self_geometry = at::TensorGeometry(base_);
+    fn->self_geometry = at::TensorGeometry(diff_view_meta->base_);
     fn->size = sizes().vec();
     fn->stride = strides().vec();
     fn->storage_offset = data_.storage_offset();
-    fn->set_next_edges(collect_next_edges(base_));
+    fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
     fn->add_input_metadata(
-      base_.type()
+      diff_view_meta->base_.type()
     , sizes() // Note: sizes(), not base_.sizes(), is intentional
-    , base_.is_cuda() ? base_.get_device() : -1);
-    grad_fn_ = std::move(fn);
-    attr_version = current_version;
+    , diff_view_meta->base_.is_cuda() ? diff_view_meta->base_.get_device() : -1);
+    diff_view_meta->grad_fn_ = std::move(fn);
+    diff_view_meta->attr_version = current_version;
   }
-  return grad_fn_;
+  return diff_view_meta->grad_fn_;
 }
 
 void Variable::DifferentiableViewImpl::rebase_history(Edge gradient_edge) {
+  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
   AT_ASSERT(gradient_edge.input_nr == 0);
   AT_ASSERT(gradient_edge.function);
   AT_CHECK(
       gradient_edge.function->num_inputs() == 1,
       "Functions which modify views in-place must return a single Variable");
-  this->output_nr_ = gradient_edge.input_nr;
+  diff_view_meta->output_nr_ = gradient_edge.input_nr;
   auto copy_slices = std::make_shared<CopySlices>(
-      base_, at::TensorGeometry(data_), std::move(gradient_edge.function));
-  base_.set_gradient_edge({std::move(copy_slices), 0});
+      diff_view_meta->base_, at::TensorGeometry(data_), std::move(gradient_edge.function));
+  diff_view_meta->base_.set_gradient_edge({std::move(copy_slices), 0});
   get_grad_fn(); // trigger an update to the view's grad_fn
 }
 
 void Variable::DifferentiableViewImpl::release_resources() {
+  auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
+  diff_view_meta->base_.reset();
   Variable::Impl::release_resources();
-  base_.reset();
 }
 
 void Variable::rebase_history(Edge gradient_edge) {
index b67a810..7b14803 100644 (file)
@@ -101,19 +101,26 @@ struct TORCH_API Variable : public at::Tensor {
       Variable base,
       at::Tensor data,
       bool is_differentiable,
+      bool allow_tensor_metadata_change,
       Edge gradient_edge);
 
   /// Creates a `Variable` from the given `Tensor`. `requires_grad` should be
   /// set only for leaves, and determines whether the `Variable` will accumulate
   /// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic
   /// type *must* be `Tensor`.
-  friend Variable make_variable(at::Tensor data, bool requires_grad);
+  friend Variable make_variable(
+      at::Tensor data,
+      bool requires_grad,
+      bool allow_tensor_metadata_change);
 
   /// Creates a `Variable` from the given `Tensor` and specify a
   /// `gradient_edge`, i.e. a (function, input_nr) pair specifying the function
   /// in the autograd graph, and what particular input of that function, this
   /// variable is connected to.
-  friend Variable make_variable(at::Tensor data, Edge gradient_edge);
+  friend Variable make_variable(
+      at::Tensor data,
+      Edge gradient_edge,
+      bool allow_tensor_metadata_change);
 
   // Tensor Conversions
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -186,6 +193,14 @@ struct TORCH_API Variable : public at::Tensor {
   /// Returns a copy of this `Variable` that is detached from its autograd graph
   /// and has a blank version. This method is OK to call if the `Variable` is a
   /// view.
+  /// NOTE: Previously, if we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a tensor created from `detach()`, those metadata
+  /// in the original tensor will also be updated. However, the new behavior is that
+  /// those metadata changes to the detached tensor will not update the original tensor
+  /// anymore, and in the `detach()` function we need to set `allow_tensor_metadata_change_`
+  /// to false to make such changes explicitly illegal, in order to prevent users from
+  /// changing metadata of the detached tensor and expecting the original tensor to also
+  /// be updated.
   Variable detach() const;
 
   /// Like `detach()`, but removes this `Variable` in-place. This method may
@@ -264,12 +279,16 @@ struct TORCH_API Variable : public at::Tensor {
   PyObject* pyobj() const noexcept;
   void set_pyobj(PyObject* pyobj) noexcept;
 
+  struct AutogradMeta;
+  Variable::AutogradMeta* get_autograd_meta() const noexcept;
+
  private:
   /// Private implementation struct of the `Variable`. This struct declaration
   /// and the `get()` method which exposes it shall forever remain private and
   /// never be exposed to the public interface of this class.
   struct Impl;
   struct DifferentiableViewImpl;
+  struct DifferentiableViewMeta;
 
   // Private Methods
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -279,12 +298,62 @@ struct TORCH_API Variable : public at::Tensor {
 };
 
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+//                            Variable::AutogradMeta
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
+/// metadata fields that are necessary for tracking the Variable's autograd history.
+
+struct TORCH_API Variable::AutogradMeta : public c10::AutogradMetaInterface {
+  std::string name;
+
+  Variable grad_;
+  std::shared_ptr<Function> grad_fn_;
+  std::weak_ptr<Function> grad_accumulator_;
+
+  VariableVersion version_counter_;
+  std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
+
+  // Only meaningful on leaf variables (must be false otherwise)
+  bool requires_grad_;
+
+  bool is_view_;
+
+  // The "output number" of this variable; e.g., if this variable
+  // was the second output of a function, then output_nr == 1.
+  // We use this to make sure we can setup the backwards trace
+  // correctly when this variable is passed to another function.
+  uint32_t output_nr_;
+  PyObject* pyobj_ = nullptr; // weak reference
+
+  // Mutex to ensure that concurrent read operations that modify internal
+  // state are still thread-safe. Used by get_grad_fn and
+  // get_grad_accumulator.
+  std::mutex mutex_;
+};
+
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+//                        Variable::DifferentiableViewMeta
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+struct TORCH_API Variable::DifferentiableViewMeta : public Variable::AutogradMeta {
+  /// The base `Variable` (never a view).
+  Variable base_;
+
+  /// The value of the version_counter at the time grad_fn was created. The
+  /// grad_fn field is stale if attr_version !=
+  /// version_counter.current_version().
+  uint32_t attr_version;
+};
+
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 //                            Variable::Impl
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 struct TORCH_API Variable::Impl : public at::TensorImpl {
   explicit Impl(
       at::Tensor data,
+      std::unique_ptr<Variable::AutogradMeta> autograd_meta,
       bool requires_grad = false,
       Edge gradient_edge = Edge());
 
@@ -307,7 +376,7 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
 
   std::shared_ptr<Function> get_grad_accumulator();
   virtual std::shared_ptr<Function>& get_grad_fn() {
-    return grad_fn_;
+    return get_autograd_meta()->grad_fn_;
   }
 
   virtual const Variable& base() const {
@@ -321,19 +390,19 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
     AT_CHECK(
         !requires_grad || at::isFloatingType(at::typeMetaToScalarType(dtype())),
         "Only Tensors of floating point dtype can require gradients");
-    requires_grad_ = requires_grad;
+    get_autograd_meta()->requires_grad_ = requires_grad;
   }
 
   bool requires_grad() const override {
-    return requires_grad_ || grad_fn_ || (is_view_ && base().requires_grad());
+    return get_autograd_meta()->requires_grad_ || get_autograd_meta()->grad_fn_ || (get_autograd_meta()->is_view_ && base().requires_grad());
   }
 
   /// Accesses the gradient `Variable` of this `Variable`.
   Variable& grad() override {
-    return grad_;
+    return get_autograd_meta()->grad_;
   }
   const Variable& grad() const override {
-    return grad_;
+    return get_autograd_meta()->grad_;
   }
 
   void detach_();
@@ -348,34 +417,16 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
   /// Reset all expensive fields to free up resources
   void release_resources() override;
 
-  std::string name;
-  at::Tensor data_;
-
-  Variable grad_;
-  std::shared_ptr<Function> grad_fn_;
-  std::weak_ptr<Function> grad_accumulator_;
-
-  VariableVersion version_counter_;
-  std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
-
-  // Only meaningful on leaf variables (must be false otherwise)
-  bool requires_grad_;
-
-  bool is_view_;
+  Variable::AutogradMeta* get_autograd_meta() const {
+    return static_cast<Variable::AutogradMeta*>(autograd_meta());
+  }
 
-  // The "output number" of this variable; e.g., if this variable
-  // was the second output of a function, then output_nr == 1.
-  // We use this to make sure we can setup the backwards trace
-  // correctly when this variable is passed to another function.
-  uint32_t output_nr_;
-  PyObject* pyobj_; // weak reference
+  int64_t storage_offset() const override;
 
-  // Mutex to ensure that concurrent read operations that modify internal
-  // state are still thread-safe. Used by get_grad_fn and
-  // get_grad_accumulator.
-  std::mutex mutex_;
+  /// The underlying data tensor for this Variable.
+  /// This field will be removed once VariableImpl and TensorImpl are merged.
+  at::Tensor data_;
 
-  int64_t storage_offset() const override;
  private:
   int64_t get_device_slow() const override;
 };
@@ -454,7 +505,11 @@ struct TORCH_API Variable::Impl : public at::TensorImpl {
 /// Relevant logic for non-differentiable views is implemented in
 /// make_variable_view below, and wrap_output of gen_variable_type.py.
 struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
-  DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge);
+  DifferentiableViewImpl(
+    Variable base,
+    at::Tensor data,
+    Edge gradient_edge,
+    std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta);
 
   /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
   /// re-create the grad_fn to express the up-to-date view relationship between
@@ -462,7 +517,7 @@ struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
   std::shared_ptr<Function>& get_grad_fn() override;
 
   const Variable& base() const override {
-    return base_;
+    return static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta())->base_;
   }
 
   /// Reset all expensive fields to free up resources
@@ -471,14 +526,6 @@ struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
   /// Called after in-place modifications. Modifies the grad_fn of the base
   /// Variable.
   void rebase_history(Edge gradient_edge);
-
-  /// The base `Variable` (never a view).
-  Variable base_;
-
-  /// The value of the version_counter at the time grad_fn was created. The
-  /// grad_fn field is stale if attr_version !=
-  /// version_counter.current_version().
-  uint32_t attr_version;
 };
 
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -488,21 +535,37 @@ struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
 // Factory Functions
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
+/// NOTE: `allow_tensor_metadata_change` is set to true by default, because there
+/// are a lot of call sites to these factory functions that need to change the
+/// variable's size or storage afterwards, and they don't expect the original
+/// tensor (where the variable is created from) to be updated. Setting
+/// `allow_tensor_metadata_change_` to false by default would unnecessarily
+/// prevent those changes from happening and is undesirable.
+
 // See NOTE [ Autograd View Variables ] for details.
 inline Variable make_variable_view(
     Variable base,
     at::Tensor data,
     bool is_differentiable = true,
+    bool allow_tensor_metadata_change = true,
     Edge gradient_edge = Edge()) {
   if (data.defined()) {
     if (is_differentiable) {
       /// Differentiable view. Track history with DifferentiableViewImpl.
+      auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+      data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+      auto data_copy = at::Tensor(data_impl_copy);
+      auto diff_view_meta = c10::guts::make_unique<Variable::DifferentiableViewMeta>();
       return Variable(c10::make_intrusive<Variable::DifferentiableViewImpl>(
-              std::move(base), std::move(data), std::move(gradient_edge)));
+              std::move(base), std::move(data_copy), std::move(gradient_edge), std::move(diff_view_meta)));
     } else {
       /// Non-differentiable view. Just share version counter.
+      auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+      data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+      auto data_copy = at::Tensor(data_impl_copy);
+      auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
       auto var = Variable(c10::make_intrusive<Variable::Impl>(
-              std::move(data), false, std::move(gradient_edge)));
+              std::move(data_copy), std::move(autograd_meta), false, std::move(gradient_edge)));
       var.set_version_counter(base.version_counter());
       return var;
     }
@@ -510,22 +573,36 @@ inline Variable make_variable_view(
   return Variable();
 }
 
-inline Variable make_variable(at::Tensor data, bool requires_grad = false) {
+inline Variable make_variable(
+    at::Tensor data,
+    bool requires_grad = false,
+    bool allow_tensor_metadata_change = true) {
   AT_CHECK(
       !data.is_variable(),
       "Must not create a new variable from a variable, use its .data()");
   if (data.defined()) {
-    return Variable(c10::make_intrusive<Variable::Impl>(data, requires_grad));
+    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+    data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+    auto data_copy = at::Tensor(data_impl_copy);
+    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
+    return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), requires_grad));
   }
   return Variable();
 }
 
-inline Variable make_variable(at::Tensor data, Edge gradient_edge) {
+inline Variable make_variable(
+    at::Tensor data,
+    Edge gradient_edge,
+    bool allow_tensor_metadata_change = true) {
   AT_CHECK(
       !data.is_variable(),
       "Must not create a new variable from a variable, use its .data()");
   if (data.defined()) {
-    return Variable(c10::make_intrusive<Variable::Impl>(data, false, std::move(gradient_edge)));
+    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+    data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+    auto data_copy = at::Tensor(data_impl_copy);
+    auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
+    return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), false, std::move(gradient_edge)));
   }
   return Variable();
 }
@@ -568,16 +645,16 @@ inline const std::shared_ptr<Function>& Variable::grad_fn() const {
 }
 
 inline Function* Variable::grad_fn_unsafe() const {
-  return get()->grad_fn_.get();
+  return get_autograd_meta()->grad_fn_.get();
 }
 
 inline void Variable::set_grad_accumulator(
     std::weak_ptr<Function> grad_accumulator) {
-  get()->grad_accumulator_ = std::move(grad_accumulator);
+  get_autograd_meta()->grad_accumulator_ = std::move(grad_accumulator);
 }
 
 inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
-  return get()->grad_accumulator_.lock();
+  return get_autograd_meta()->grad_accumulator_.lock();
 }
 
 inline std::shared_ptr<Function> Variable::grad_accumulator() const {
@@ -585,7 +662,8 @@ inline std::shared_ptr<Function> Variable::grad_accumulator() const {
 }
 
 inline Variable Variable::detach() const {
-  return make_variable_view(*this, get()->data_, /*is_differentiable=*/false);
+  auto var = make_variable_view(*this, get()->data_, /*is_differentiable=*/false, /*allow_tensor_metadata_change=*/false, Edge());
+  return var;
 }
 
 inline void Variable::detach_() {
@@ -604,16 +682,16 @@ inline void Variable::set_data(Tensor new_data) const {
 }
 
 inline void Variable::set_gradient_edge(Edge edge) noexcept {
-  get()->grad_fn_ = std::move(edge.function);
-  get()->output_nr_ = edge.input_nr;
+  get_autograd_meta()->grad_fn_ = std::move(edge.function);
+  get_autograd_meta()->output_nr_ = edge.input_nr;
 }
 
 inline uint32_t Variable::output_nr() const noexcept {
-  return get()->output_nr_;
+  return get_autograd_meta()->output_nr_;
 }
 
 inline bool Variable::is_leaf() const noexcept {
-  return get()->grad_fn_ == nullptr;
+  return get_autograd_meta()->grad_fn_ == nullptr;
 }
 
 // Versions
@@ -621,42 +699,42 @@ inline bool Variable::is_leaf() const noexcept {
 
 inline void Variable::set_version_counter(
     const VariableVersion& version_counter) noexcept {
-  get()->version_counter_ = version_counter;
+  get_autograd_meta()->version_counter_ = version_counter;
 }
 
 inline void Variable::bump_version() noexcept {
-  get()->version_counter_.bump();
+  get_autograd_meta()->version_counter_.bump();
 }
 
 inline uint32_t Variable::current_version() const noexcept {
-  return get()->version_counter_.current_version();
+  return get_autograd_meta()->version_counter_.current_version();
 }
 
 inline const VariableVersion& Variable::version_counter() const noexcept {
-  return get()->version_counter_;
+  return get_autograd_meta()->version_counter_;
 }
 
 // Hooks
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
-  get()->hooks_.push_back(std::move(hook));
+  get_autograd_meta()->hooks_.push_back(std::move(hook));
 }
 
 inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
     const noexcept {
-  return get()->hooks_;
+  return get_autograd_meta()->hooks_;
 }
 
 inline void Variable::clear_hooks() {
-  get()->hooks_.clear();
+  get_autograd_meta()->hooks_.clear();
 }
 
 // View Variables
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 inline bool Variable::is_view() const noexcept {
-  return get()->is_view_;
+  return get_autograd_meta()->is_view_;
 }
 
 inline const Variable& Variable::base() const {
@@ -667,19 +745,23 @@ inline const Variable& Variable::base() const {
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 inline void Variable::set_name(const std::string& name) {
-  get()->name = name;
+  get_autograd_meta()->name = name;
 }
 
 inline const std::string& Variable::name() const noexcept {
-  return get()->name;
+  return get_autograd_meta()->name;
 }
 
 inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
-  get()->pyobj_ = pyobj;
+  get_autograd_meta()->pyobj_ = pyobj;
 }
 
 inline PyObject* Variable::pyobj() const noexcept {
-  return get()->pyobj_;
+  return get_autograd_meta()->pyobj_;
+}
+
+inline Variable::AutogradMeta* Variable::get_autograd_meta() const noexcept {
+  return get()->get_autograd_meta();
 }
 
 // Private Methods
index db832f0..f829452 100644 (file)
@@ -409,7 +409,7 @@ void initPythonIRBindings(PyObject* module_) {
       .def(
           "t_",
           [](Node& n, const char* name, torch::autograd::Variable v) {
-            return n.t_(Symbol::attr(name), std::move(v.data()));
+            return n.t_(Symbol::attr(name), v.data());
           })
       .def(
           "t",
@@ -426,7 +426,7 @@ void initPythonIRBindings(PyObject* module_) {
             std::vector<at::Tensor> tensors;
             tensors.reserve(vs.size());
             for (auto& variable : vs) {
-              tensors.push_back(std::move(variable.data()));
+              tensors.push_back(variable.data());
             }
             return n.ts_(Symbol::attr(name), std::move(tensors));
           })
index 031f1d1..eedd641 100644 (file)
@@ -175,9 +175,17 @@ class Tensor(torch._C._TensorBase):
 
     .. note::
 
-      Returned Tensor uses the same data tensor as the original one.
+      Returned Tensor shares the same storage with the original one.
       In-place modifications on either of them will be seen, and may trigger
       errors in correctness checks.
+      IMPORTANT NOTE: Previously, in-place size / stride / storage changes
+      (such as `resize_` / `resize_as_` / `set_` / `transpose_`) to the returned tensor
+      also update the original tensor. Now, these in-place changes will not update the
+      original tensor anymore, and will instead trigger an error.
+      For sparse tensors:
+      In-place indices / values changes (such as `zero_` / `copy_` / `add_`) to the
+      returned tensor will not update the original tensor anymore, and will instead
+      trigger an error.
     """)
 
     detach_ = _add_docstr(_C._TensorBase.detach_, r"""