AT_ERROR("sparse tensors do not have storage");
}
void SparseTensorImpl::set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_indices_and_values_unsafe is not allowed on Tensor created from .data or .detach()");
AT_ASSERT(!indices.is_variable() && !values.is_variable()); // They should be plain tensors!
AT_CHECK(!indices.is_sparse(), "expected indices to be a dense tensor, but got indices of layout ", indices.layout());
// WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim with
// respect to indices and values
void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+ AT_CHECK(allow_tensor_metadata_change(), "raw_resize_ is not allowed on Tensor created from .data or .detach()");
sizes_ = size.vec();
sparse_dim_ = sparse_dim;
dense_dim_ = dense_dim;
// 4. When we attempt to shrink the size of any of the sparse dimensions on a non-empty sparse tensor
// (this could make some of the stored indices out-of-bound and thus unsafe).
void resize_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+ AT_CHECK(allow_tensor_metadata_change(), "resize_ is not allowed on Tensor created from .data or .detach()");
AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
if (nnz() > 0) {
auto alt_options_msg = "You could try the following options:\n\
// NOTE: this function will resize the sparse tensor and also set `indices` and `values` to empty.
void resize_and_clear_(int64_t sparse_dim, int64_t dense_dim, IntList size) {
+ AT_CHECK(allow_tensor_metadata_change(), "resize_and_clear_ is not allowed on Tensor created from .data or .detach()");
AT_CHECK(sparse_dim + dense_dim == size.size(), "number of dimensions must be sparse_dim (", sparse_dim, ") + dense_dim (", dense_dim, "), but got ", size.size());
sizes_ = size.vec();
refresh_numel();
}
- void set_coalesced(bool coalesced) { coalesced_ = coalesced; }
+ void set_coalesced(bool coalesced) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_coalesced is not allowed on Tensor created from .data or .detach()");
+ coalesced_ = coalesced;
+ }
// NOTE: this function is only used internally and not exposed to Python frontend
void set_nnz_and_narrow(int64_t new_nnz) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_nnz_and_narrow is not allowed on Tensor created from .data or .detach()");
AT_ASSERT(new_nnz <= nnz());
indices_ = indices_.narrow(1, 0, new_nnz);
values_ = values_.narrow(0, 0, new_nnz);
// make it happen
void set_indices_and_values_unsafe(const Tensor& indices, const Tensor& values);
+ // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
+ // because it is unique for each Variable.
+ // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because some call sites
+ // to this function need to change the shallow copy's size or storage afterwards, and setting
+ // `allow_tensor_metadata_change_` to false would prevent those changes from happening, which
+ // is undesirable.
+ c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const override {
+ auto impl = c10::make_intrusive<SparseTensorImpl>(type_id(), dtype());
+ // TensorImpl general fields
+ // Note that these fields are not used in sparse tensor code, and we copy them here only for completeness.
+ impl->sizes_ = sizes_;
+ impl->strides_ = strides_;
+ impl->storage_offset_ = storage_offset_;
+ impl->is_contiguous_ = is_contiguous_;
+ impl->is_wrapped_number_ = is_wrapped_number_;
+ impl->reserved_ = reserved_;
+
+ // Sparse-specific fields
+ impl->sparse_dim_ = sparse_dim();
+ impl->dense_dim_ = dense_dim();
+ impl->indices_ = indices();
+ impl->values_ = values();
+ impl->coalesced_ = coalesced();
+ impl->refresh_numel();
+ return impl;
+ }
private:
int64_t get_device_slow() const override {
return values_.get_device();
// Caffe2 might have tensors whose storages are null, but we
// don't allow it in PyTorch.
AT_ASSERT(storage);
- tensor->storage_ = at::Storage(c10::intrusive_ptr<THStorage>::reclaim(storage));
+ tensor->set_storage(at::Storage(c10::intrusive_ptr<THStorage>::reclaim(storage)));
}
// for the first time (providing the necessary type). It is an ERROR to
// invoke any PyTorch operations on such a half-constructed storage,
// and this check tests for that case.
- AT_CHECK(tensor->storage_, "Cannot use PyTorch operations on a half-constructed "
+ AT_CHECK(tensor->storage(), "Cannot use PyTorch operations on a half-constructed "
"tensor. If this tensor came from Caffe2, please call GetMutableData on "
"it first; otherwise, this is a bug, please report it.");
- return tensor->storage_.unsafeGetStorageImpl();
+ return tensor->storage().unsafeGetStorageImpl();
}
inline void THTensor_maybe_zero_dim(THTensor *tensor, bool condition_when_zero_dim) {
device};
}
+AutogradMetaInterface::~AutogradMetaInterface() {}
+
} // namespace c10
}
};
+struct C10_API AutogradMetaInterface {
+ virtual ~AutogradMetaInterface();
+};
+
/**
* The low-level representation of a tensor, which contains a pointer
* to a storage (which contains the actual data) and metadata (e.g., sizes and
* which is harder to misuse.
*/
virtual void resize_dim(int64_t ndim) {
+ AT_CHECK(allow_tensor_metadata_change(), "resize_dim is not allowed on Tensor created from .data or .detach()");
sizes_.resize(ndim, 0);
strides_.resize(ndim, 0);
refresh_numel();
* which is harder to misuse.
*/
virtual void set_size(int64_t dim, int64_t new_size) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_size is not allowed on Tensor created from .data or .detach()");
sizes_.at(dim) = new_size;
refresh_numel();
refresh_contiguous();
* which is harder to misuse.
*/
virtual void set_stride(int64_t dim, int64_t new_stride) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_stride is not allowed on Tensor created from .data or .detach()");
strides_[dim] = new_stride;
refresh_numel();
refresh_contiguous();
* (and resizing if necessary.)
*/
virtual void set_storage_offset(int64_t storage_offset) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_storage_offset is not allowed on Tensor created from .data or .detach()");
storage_offset_ = storage_offset;
}
* See Note [We regret making Variable hold a Tensor]
*/
void set_sizes_contiguous(IntList new_size) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()");
AT_ASSERT(!is_variable());
auto old_dim = sizes_.size();
auto new_dim = new_size.size();
* See Note [We regret making Variable hold a Tensor]
*/
void set_sizes_and_strides(IntList new_size, IntList new_stride) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_sizes_and_strides is not allowed on Tensor created from .data or .detach()");
AT_ASSERT(!is_variable());
AT_CHECK(
new_size.size() == new_stride.size(),
*/
bool is_variable() const { return is_variable_; };
+ /**
+ * Set whether a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
+ */
+ virtual void set_allow_tensor_metadata_change(bool value) {
+ allow_tensor_metadata_change_ = value;
+ }
+
+ /**
+ * True if a tensor allows changes to its metadata (e.g. sizes / strides / storage / storage_offset).
+ */
+ virtual bool allow_tensor_metadata_change() const {
+ return allow_tensor_metadata_change_;
+ }
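+ // Mutators of tensor metadata are expected to guard themselves with this
+ // flag, e.g. (the pattern used by `set_size` / `set_storage` in this class):
+ //
+ //   AT_CHECK(allow_tensor_metadata_change(),
+ //            "set_size is not allowed on Tensor created from .data or .detach()");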
+
+ /**
+ * Set the pointer to autograd metadata.
+ */
+ void set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
+ autograd_meta_ = std::move(autograd_meta);
+ }
+
+ /**
+ * Return the pointer to autograd metadata.
+ */
+ c10::AutogradMetaInterface* autograd_meta() const {
+ return autograd_meta_.get();
+ }
+
+ /**
+ * Detach the autograd metadata unique_ptr from this tensor, and return it.
+ */
+ std::unique_ptr<c10::AutogradMetaInterface> detach_autograd_meta() {
+ return std::move(autograd_meta_);
+ }
+
+ // NOTE: `shallow_copy_and_detach()` does not copy the AutogradMeta pointer
+ // because it is unique for each Variable.
+ // NOTE: We don't set `allow_tensor_metadata_change_` to false here, because some call sites
+ // to this function need to change the shallow copy's size or storage afterwards, and setting
+ // `allow_tensor_metadata_change_` to false would prevent those changes from happening, which
+ // is undesirable.
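+ //
+ // A minimal usage sketch (illustrative; this is the pattern used by the
+ // Variable factory functions elsewhere in this PR, with `t` an arbitrary
+ // at::Tensor):
+ //
+ //   auto impl_copy = t.getIntrusivePtr()->shallow_copy_and_detach();
+ //   impl_copy->set_allow_tensor_metadata_change(false);
+ //   auto frozen = at::Tensor(impl_copy);
+ //   // `frozen` shares storage with `t` but has its own sizes / strides /
+ //   // storage_offset, carries no AutogradMeta, and rejects metadata changes.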
+ virtual c10::intrusive_ptr<TensorImpl> shallow_copy_and_detach() const {
+ auto impl = c10::make_intrusive<TensorImpl>(Storage(storage()), type_id(), is_variable());
+ impl->set_sizes_and_strides(sizes(), strides());
+ impl->storage_offset_ = storage_offset_;
+ impl->is_wrapped_number_ = is_wrapped_number_;
+ impl->reserved_ = reserved_;
+ impl->refresh_numel();
+ impl->refresh_contiguous();
+ return impl;
+ }
+
private:
// As an optimization, get_device handles the typical CUDA Tensor case and
// calls get_device_slow if the tensor stores its device somewhere else
}
void set_storage(at::Storage storage) {
+ AT_CHECK(allow_tensor_metadata_change(), "set_storage is not allowed on Tensor created from .data or .detach()");
storage_ = std::move(storage);
data_type_ = storage_.dtype();
}
is_contiguous_ = compute_contiguous();
}
-public:
- Storage storage_; // TODO: Fix visibility on me
-
protected:
+ Storage storage_;
+ // This pointer points to an AutogradMeta struct that stores autograd-specific fields
+ // (such as grad_ / grad_fn_ / grad_accumulator_).
+ // This pointer always has unique ownership (meaning only one TensorImpl can own it
+ // at a time).
+ std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
+
// We could save a word or two by combining the SmallVector structs,
// since their size is redundant, and if we need to overflow the buffer space
// we could keep the two pointers together. However, that would require
bool is_contiguous_ = true;
bool is_variable_ = false;
bool is_wrapped_number_ = false;
+
+ // Previously, if we changed the tensor metadata (e.g. sizes / strides / storage / storage_offset)
+ // of a derived tensor (i.e. a tensor created from Python `tensor.data` or Python/C++ `tensor.detach()`),
+ // the corresponding metadata of the original tensor was also updated. Under the new behavior,
+ // metadata changes to a derived tensor no longer update the original tensor, and we need this
+ // flag to make such changes explicitly illegal, to prevent users from changing metadata of the
+ // derived tensor and expecting the original tensor to be updated as well.
+ //
+ // NOTE: For a full list of tensor metadata fields, please see `shallow_copy_and_detach()` in TensorImpl
+ // and its subclasses to find which fields are copied by value.
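+ //
+ // For example (illustrative): in Python, with `t2 = t.detach()`, calling
+ // `t2.resize_(...)` now throws, whereas it previously resized both `t2` and `t`.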
+ bool allow_tensor_metadata_change_ = true;
+
// we decide to keep reserved_ and it will
// live in Tensor after the split
// The logic is that if Extend() or ReserveSpace() were ever called,
// storage offset
// numel
// data type pointer
+// autograd metadata pointer
// miscellaneous bitfield
//
static_assert(sizeof(void*) != sizeof(int64_t) || // if 64-bit...
- sizeof(TensorImpl) == sizeof(int64_t) * 24,
+ sizeof(TensorImpl) == sizeof(int64_t) * 25,
"You changed the size of TensorImpl on 64-bit arch."
"See Note [TensorImpl size constraints] on how to proceed.");
elif isinstance(arg, torch.Tensor):
if arg.dtype == torch.float:
arg = arg.double()
- v = maybe_non_contig(arg).detach()
+ # NOTE: We do clone() after detach() here because we need to be able to change size/storage of v afterwards
+ v = maybe_non_contig(arg).detach().clone()
v.requires_grad = requires_grad and v.is_floating_point()
return v
elif callable(arg):
# needed for inplace operations done on `x`, e.g., copy_().
# Remove after implementing something equivalent to CopySlice
# for sparse views.
- x = x.detach()
+ # NOTE: We do clone() after detach() here because we need to be able to change size/storage of x afterwards
+ x = x.detach().clone()
return x, x._indices().clone(), x._values().clone()
def safeToDense(self, t):
def __init__(self):
super(TestScript.DerivedStateModule, self).__init__()
self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
- self.register_buffer('derived', torch.neg(self.param).detach())
+ self.register_buffer('derived', torch.neg(self.param).detach().clone())
# This is a flag so we can test that the pack method was called
self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
def _zero_grad_parameters(self, module):
for p in module.parameters():
if p.grad is not None:
- p.grad.data.zero_()
+ with torch.no_grad():
+ p.grad.zero_()
p.grad.detach_()
def _get_parameters(self, module):
# Weights will no longer view onto the same chunk of memory
weight = all_vars[4]
weight_data = weight.data.clone()
- weight.data.set_(weight_data)
+ with torch.no_grad():
+ weight.set_(weight_data)
for i in range(2):
with warnings.catch_warnings(record=True) as w:
with self.assertRaisesRegex(RuntimeError, "bool value of Tensor with no values is ambiguous"):
torch.sparse_coo_tensor(([0, 1],), self.ValueTensor(2, 0), (4, 0)).is_nonzero()
+ def test_allow_tensor_metadata_change(self):
+ def do_test(t):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "raw_resize_ is not allowed on Tensor created from .data or .detach()"):
+ t.transpose_(0, 1)
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "resize_ is not allowed on Tensor created from .data or .detach()"):
+ t.resize_as_(self.SparseTensor(3, 3))
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "resize_and_clear_ is not allowed on Tensor created from .data or .detach()"):
+ t.mul_(t)
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "set_coalesced is not allowed on Tensor created from .data or .detach()"):
+ t._coalesced_(True)
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "set_indices_and_values_unsafe is not allowed on Tensor created from .data or .detach()"):
+ a = self.SparseTensor(torch.tensor([[0, 1, 1], [2, 0, 2]]), torch.tensor([3., 4., 5.])).data
+ a.add_(a)
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "resize_and_clear_ is not allowed on Tensor created from .data or .detach()"):
+ a.zero_()
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "resize_ is not allowed on Tensor created from .data or .detach()"):
+ a.copy_(self.SparseTensor(3, 3))
+
+ do_test(self.SparseTensor(3, 0).data)
+ do_test(self.SparseTensor(3, 0).detach())
+
class TestUncoalescedSparse(TestSparse):
def setUp(self):
with self.assertRaisesRegex(RuntimeError, "expected both inputs to be on same device"):
torch.tensor(2).to("cuda:1") // torch.tensor(3).to("cuda:0")
+ def test_allow_tensor_metadata_change(self):
+ def do_test(t):
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "set_sizes_contiguous is not allowed on Tensor created from .data or .detach()"):
+ t.resize_((2, 1))
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "set_storage is not allowed on Tensor created from .data or .detach()"):
+ t.set_()
+ with self.assertRaisesRegex(
+ RuntimeError,
+ "set_storage_offset is not allowed on Tensor created from .data or .detach()"):
+ t.set_(t.storage(), 0, t.size(), list(t.stride()))
+
+ do_test(torch.tensor([[1, 2]]).data)
+ do_test(torch.tensor([[1, 2]]).detach())
+
# Functions to test negative dimension wrapping
METHOD = 1
INPLACE_METHOD = 2
// a thing never promised and documented, but used in some hacks seen
// on the internet.
if (grad_variable.is_sparse() && !new_grad.is_sparse()) {
- grad_variable.data() = new_grad.data() + grad_variable.data();
+ grad_variable.set_data(new_grad.data() + grad_variable.data());
} else {
grad_variable.data() += new_grad.data();
}
static PyObject * THPVariable_get_data(THPVariable *self)
{
HANDLE_TH_ERRORS
- return THPVariable_Wrap(make_variable(self->cdata.data(), false));
+ /// NOTE: Previously, changing the tensor metadata (e.g. sizes / strides /
+ /// storage / storage_offset) of a tensor created from `.data` also updated
+ /// the original tensor. Under the new behavior, such metadata changes no
+ /// longer propagate back to the original tensor; here we set
+ /// `allow_tensor_metadata_change_` to false to make them explicitly illegal,
+ /// in order to prevent users from changing metadata of the `.data` tensor
+ /// and expecting the original tensor to be updated as well.
+ auto var = make_variable(self->cdata.data(), /*requires_grad=*/false, /*allow_tensor_metadata_change=*/false);
+ return THPVariable_Wrap(var);
END_HANDLE_TH_ERRORS
}
namespace torch {
namespace autograd {
-Variable::Impl::Impl(at::Tensor data, bool requires_grad, Edge gradient_edge)
+Variable::Impl::Impl(at::Tensor data, std::unique_ptr<Variable::AutogradMeta> autograd_meta, bool requires_grad, Edge gradient_edge)
: TensorImpl(data.type_id(), data.dtype(), /*allocator=*/nullptr, /* is variable */ true),
- data_(std::move(data)),
- grad_fn_(std::move(gradient_edge.function)),
- requires_grad_(false),
- is_view_(false),
- output_nr_(gradient_edge.input_nr),
- pyobj_(nullptr) {
+ data_(std::move(data)) {
+ autograd_meta->grad_fn_ = std::move(gradient_edge.function);
+ autograd_meta->requires_grad_ = false;
+ autograd_meta->is_view_ = false;
+ autograd_meta->output_nr_ = gradient_edge.input_nr;
+ autograd_meta->pyobj_ = nullptr;
+ set_autograd_meta(std::move(autograd_meta));
+
// set_requires_grad also checks error conditions.
set_requires_grad(requires_grad);
AT_CHECK(
- !grad_fn_ || !requires_grad_,
+ !get_autograd_meta()->grad_fn_ || !get_autograd_meta()->requires_grad_,
"requires_grad should be false if grad_fn is set");
if (!data_.defined()) {
throw std::runtime_error("data is undefined");
}
std::shared_ptr<Function> Variable::Impl::get_grad_accumulator() {
- if (grad_fn_) {
+ auto autograd_meta = get_autograd_meta();
+ if (autograd_meta->grad_fn_) {
throw std::logic_error(
"get_grad_accumulator() should be only called on leaf Variables");
}
- if (!requires_grad_) {
+ if (!autograd_meta->requires_grad_) {
return nullptr;
}
- std::lock_guard<std::mutex> lock(mutex_);
+ std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
- auto result = grad_accumulator_.lock();
+ auto result = autograd_meta->grad_accumulator_.lock();
if (result)
return result;
c10::raw::intrusive_ptr::incref(this);
auto intrusive_from_this = c10::intrusive_ptr<Variable::Impl>::reclaim(this);
result = std::make_shared<AccumulateGrad>(Variable(std::move(intrusive_from_this)));
- grad_accumulator_ = result;
+ autograd_meta->grad_accumulator_ = result;
return result;
}
void Variable::Impl::detach_() {
- if (is_view_) {
+ auto autograd_meta = get_autograd_meta();
+ if (autograd_meta->is_view_) {
AT_ERROR("Can't detach views in-place. Use detach() instead");
}
set_requires_grad(false);
- grad_fn_.reset();
- output_nr_ = 0;
+ autograd_meta->grad_fn_.reset();
+ autograd_meta->output_nr_ = 0;
}
void Variable::Impl::backward(
c10::optional<Tensor> gradient,
bool keep_graph,
bool create_graph) {
+ auto autograd_meta = get_autograd_meta();
std::vector<Edge> edges;
- edges.emplace_back(grad_fn_, output_nr_);
+ edges.emplace_back(autograd_meta->grad_fn_, autograd_meta->output_nr_);
std::vector<Variable> inputs;
if (!gradient.has_value()) {
void Variable::Impl::set_data(Tensor new_data) {
// Resets gradient accumulator if metadata is out of date
- std::lock_guard<std::mutex> lock(mutex_);
- auto prior_accumulator = grad_accumulator_.lock();
+ auto autograd_meta = get_autograd_meta();
+ std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
+ auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
if (prior_accumulator) {
const auto prior_device = prior_accumulator->input_metadata(0).device();
const auto new_device = new_data.is_cuda() ? new_data.get_device() : -1;
if (new_data.type() != data_.type() || prior_device != new_device) {
- grad_accumulator_.reset();
+ autograd_meta->grad_accumulator_.reset();
}
}
data_type_ = new_data.type().typeMeta();
type_id_ = new_data.type().type_id();
is_variable_ = true;
- data_ = std::move(new_data);
+
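+ // Store a detached shallow copy of `new_data` (shared storage, independent
+ // metadata, no AutogradMeta) so that this Variable does not share a TensorImpl
+ // with the caller's tensor.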
+ auto new_data_copy = at::Tensor(new_data.getIntrusivePtr()->shallow_copy_and_detach());
+ data_ = std::move(new_data_copy);
}
void Variable::Impl::release_resources() {
+ autograd_meta_.reset();
data_.reset();
- grad_.reset();
- grad_fn_.reset();
- hooks_.clear();
}
-Variable::DifferentiableViewImpl::DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge)
- : Variable::Impl(std::move(data), false, std::move(gradient_edge)),
- base_(std::move(base)) {
- AT_CHECK(base_.defined(), "base is undefined");
- if (base_.is_view()) {
- base_ = base_.base();
+Variable::DifferentiableViewImpl::DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge, std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta)
+ : Variable::Impl(std::move(data), std::move(autograd_meta), false, std::move(gradient_edge)) {
+ auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
+ diff_view_meta->base_ = std::move(base);
+ AT_CHECK(diff_view_meta->base_.defined(), "base is undefined");
+ if (diff_view_meta->base_.is_view()) {
+ diff_view_meta->base_ = diff_view_meta->base_.base();
}
- is_view_ = true;
- version_counter_ = base_.version_counter();
- attr_version = version_counter_.current_version();
+ diff_view_meta->is_view_ = true;
+ diff_view_meta->version_counter_ = diff_view_meta->base_.version_counter();
+ diff_view_meta->attr_version = diff_view_meta->version_counter_.current_version();
}
std::shared_ptr<Function>& Variable::DifferentiableViewImpl::get_grad_fn() {
- std::lock_guard<std::mutex> lock(mutex_);
- if (!grad_fn_ && !base_.requires_grad()) {
- return grad_fn_;
+ auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
+ std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
+ if (!diff_view_meta->grad_fn_ && !diff_view_meta->base_.requires_grad()) {
+ return diff_view_meta->grad_fn_;
}
- auto current_version = version_counter_.current_version();
- if (attr_version != current_version) {
- AT_ASSERT(output_nr_ == 0);
+ auto current_version = diff_view_meta->version_counter_.current_version();
+ if (diff_view_meta->attr_version != current_version) {
+ AT_ASSERT(diff_view_meta->output_nr_ == 0);
auto fn = std::make_shared<generated::AsStridedBackward>();
- fn->self_geometry = at::TensorGeometry(base_);
+ fn->self_geometry = at::TensorGeometry(diff_view_meta->base_);
fn->size = sizes().vec();
fn->stride = strides().vec();
fn->storage_offset = data_.storage_offset();
- fn->set_next_edges(collect_next_edges(base_));
+ fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
fn->add_input_metadata(
- base_.type()
+ diff_view_meta->base_.type()
, sizes() // Note: sizes(), not base_.sizes(), is intentional
- , base_.is_cuda() ? base_.get_device() : -1);
- grad_fn_ = std::move(fn);
- attr_version = current_version;
+ , diff_view_meta->base_.is_cuda() ? diff_view_meta->base_.get_device() : -1);
+ diff_view_meta->grad_fn_ = std::move(fn);
+ diff_view_meta->attr_version = current_version;
}
- return grad_fn_;
+ return diff_view_meta->grad_fn_;
}
void Variable::DifferentiableViewImpl::rebase_history(Edge gradient_edge) {
+ auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
AT_ASSERT(gradient_edge.input_nr == 0);
AT_ASSERT(gradient_edge.function);
AT_CHECK(
gradient_edge.function->num_inputs() == 1,
"Functions which modify views in-place must return a single Variable");
- this->output_nr_ = gradient_edge.input_nr;
+ diff_view_meta->output_nr_ = gradient_edge.input_nr;
auto copy_slices = std::make_shared<CopySlices>(
- base_, at::TensorGeometry(data_), std::move(gradient_edge.function));
- base_.set_gradient_edge({std::move(copy_slices), 0});
+ diff_view_meta->base_, at::TensorGeometry(data_), std::move(gradient_edge.function));
+ diff_view_meta->base_.set_gradient_edge({std::move(copy_slices), 0});
get_grad_fn(); // trigger an update to the view's grad_fn
}
void Variable::DifferentiableViewImpl::release_resources() {
+ auto diff_view_meta = static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta());
+ diff_view_meta->base_.reset();
Variable::Impl::release_resources();
- base_.reset();
}
void Variable::rebase_history(Edge gradient_edge) {
Variable base,
at::Tensor data,
bool is_differentiable,
+ bool allow_tensor_metadata_change,
Edge gradient_edge);
/// Creates a `Variable` from the given `Tensor`. `requires_grad` should be
/// set only for leaves, and determines whether the `Variable` will accumulate
/// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic
/// type *must* be `Tensor`.
- friend Variable make_variable(at::Tensor data, bool requires_grad);
+ friend Variable make_variable(
+ at::Tensor data,
+ bool requires_grad,
+ bool allow_tensor_metadata_change);
/// Creates a `Variable` from the given `Tensor` and specify a
/// `gradient_edge`, i.e. a (function, input_nr) pair specifying the function
/// in the autograd graph, and what particular input of that function, this
/// variable is connected to.
- friend Variable make_variable(at::Tensor data, Edge gradient_edge);
+ friend Variable make_variable(
+ at::Tensor data,
+ Edge gradient_edge,
+ bool allow_tensor_metadata_change);
// Tensor Conversions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Returns a copy of this `Variable` that is detached from its autograd graph
/// and has a blank version. This method is OK to call if the `Variable` is a
/// view.
+ /// NOTE: Previously, changing the tensor metadata (e.g. sizes / strides /
+ /// storage / storage_offset) of a tensor created from `detach()` also updated
+ /// the original tensor. Under the new behavior, such metadata changes no longer
+ /// propagate back to the original tensor; in the `detach()` function we set
+ /// `allow_tensor_metadata_change_` to false to make them explicitly illegal,
+ /// in order to prevent users from changing metadata of the detached tensor
+ /// and expecting the original tensor to be updated as well.
Variable detach() const;
/// Like `detach()`, but removes this `Variable` in-place. This method may
PyObject* pyobj() const noexcept;
void set_pyobj(PyObject* pyobj) noexcept;
+ struct AutogradMeta;
+ Variable::AutogradMeta* get_autograd_meta() const noexcept;
+
private:
/// Private implementation struct of the `Variable`. This struct declaration
/// and the `get()` method which exposes it shall forever remain private and
/// never be exposed to the public interface of this class.
struct Impl;
struct DifferentiableViewImpl;
+ struct DifferentiableViewMeta;
// Private Methods
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Variable::AutogradMeta
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
+/// metadata fields that are necessary for tracking the Variable's autograd history.
+
+struct TORCH_API Variable::AutogradMeta : public c10::AutogradMetaInterface {
+ std::string name;
+
+ Variable grad_;
+ std::shared_ptr<Function> grad_fn_;
+ std::weak_ptr<Function> grad_accumulator_;
+
+ VariableVersion version_counter_;
+ std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
+
+ // Only meaningful on leaf variables (must be false otherwise)
+ bool requires_grad_;
+
+ bool is_view_;
+
+ // The "output number" of this variable; e.g., if this variable
+ // was the second output of a function, then output_nr == 1.
+ // We use this to make sure we can setup the backwards trace
+ // correctly when this variable is passed to another function.
+ uint32_t output_nr_;
+ PyObject* pyobj_ = nullptr; // weak reference
+
+ // Mutex to ensure that concurrent read operations that modify internal
+ // state are still thread-safe. Used by get_grad_fn and
+ // get_grad_accumulator.
+ std::mutex mutex_;
+};
+
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// Variable::DifferentiableViewMeta
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+struct TORCH_API Variable::DifferentiableViewMeta : public Variable::AutogradMeta {
+ /// The base `Variable` (never a view).
+ Variable base_;
+
+ /// The value of the version_counter at the time grad_fn was created. The
+ /// grad_fn field is stale if attr_version !=
+ /// version_counter.current_version().
+ uint32_t attr_version;
+};
+
+//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Variable::Impl
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
struct TORCH_API Variable::Impl : public at::TensorImpl {
explicit Impl(
at::Tensor data,
+ std::unique_ptr<Variable::AutogradMeta> autograd_meta,
bool requires_grad = false,
Edge gradient_edge = Edge());
std::shared_ptr<Function> get_grad_accumulator();
virtual std::shared_ptr<Function>& get_grad_fn() {
- return grad_fn_;
+ return get_autograd_meta()->grad_fn_;
}
virtual const Variable& base() const {
AT_CHECK(
!requires_grad || at::isFloatingType(at::typeMetaToScalarType(dtype())),
"Only Tensors of floating point dtype can require gradients");
- requires_grad_ = requires_grad;
+ get_autograd_meta()->requires_grad_ = requires_grad;
}
bool requires_grad() const override {
- return requires_grad_ || grad_fn_ || (is_view_ && base().requires_grad());
+ return get_autograd_meta()->requires_grad_ || get_autograd_meta()->grad_fn_ || (get_autograd_meta()->is_view_ && base().requires_grad());
}
/// Accesses the gradient `Variable` of this `Variable`.
Variable& grad() override {
- return grad_;
+ return get_autograd_meta()->grad_;
}
const Variable& grad() const override {
- return grad_;
+ return get_autograd_meta()->grad_;
}
void detach_();
/// Reset all expensive fields to free up resources
void release_resources() override;
- std::string name;
- at::Tensor data_;
-
- Variable grad_;
- std::shared_ptr<Function> grad_fn_;
- std::weak_ptr<Function> grad_accumulator_;
-
- VariableVersion version_counter_;
- std::vector<std::shared_ptr<FunctionPreHook>> hooks_;
-
- // Only meaningful on leaf variables (must be false otherwise)
- bool requires_grad_;
-
- bool is_view_;
+ Variable::AutogradMeta* get_autograd_meta() const {
+ return static_cast<Variable::AutogradMeta*>(autograd_meta());
+ }
- // The "output number" of this variable; e.g., if this variable
- // was the second output of a function, then output_nr == 1.
- // We use this to make sure we can setup the backwards trace
- // correctly when this variable is passed to another function.
- uint32_t output_nr_;
- PyObject* pyobj_; // weak reference
+ int64_t storage_offset() const override;
- // Mutex to ensure that concurrent read operations that modify internal
- // state are still thread-safe. Used by get_grad_fn and
- // get_grad_accumulator.
- std::mutex mutex_;
+ /// The underlying data tensor for this Variable.
+ /// This field will be removed once VariableImpl and TensorImpl are merged.
+ at::Tensor data_;
- int64_t storage_offset() const override;
private:
int64_t get_device_slow() const override;
};
/// Relevant logic for non-differentiable views is implemented in
/// make_variable_view below, and wrap_output of gen_variable_type.py.
struct TORCH_API Variable::DifferentiableViewImpl : public Variable::Impl {
- DifferentiableViewImpl(Variable base, at::Tensor data, Edge gradient_edge);
+ DifferentiableViewImpl(
+ Variable base,
+ at::Tensor data,
+ Edge gradient_edge,
+ std::unique_ptr<Variable::DifferentiableViewMeta> autograd_meta);
/// Gets the up-to-date grad_fn. If the shared data or base was modified, we
/// re-create the grad_fn to express the up-to-date view relationship between
std::shared_ptr<Function>& get_grad_fn() override;
const Variable& base() const override {
- return base_;
+ return static_cast<Variable::DifferentiableViewMeta*>(get_autograd_meta())->base_;
}
/// Reset all expensive fields to free up resources
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
void rebase_history(Edge gradient_edge);
-
- /// The base `Variable` (never a view).
- Variable base_;
-
- /// The value of the version_counter at the time grad_fn was created. The
- /// grad_fn field is stale if attr_version !=
- /// version_counter.current_version().
- uint32_t attr_version;
};
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+/// NOTE: `allow_tensor_metadata_change` is set to true by default, because many
+/// call sites to these factory functions need to change the variable's size or
+/// storage afterwards, and they don't expect the original tensor (from which the
+/// variable is created) to be updated. Setting `allow_tensor_metadata_change_`
+/// to false by default would unnecessarily prevent those changes from happening,
+/// which is undesirable.
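+///
+/// For example (illustrative):
+///
+///   auto v = make_variable(t);  // metadata changes allowed (the default)
+///   auto frozen = make_variable(
+///       t, /*requires_grad=*/false, /*allow_tensor_metadata_change=*/false);
+///   // frozen.resize_(...) now throws, while v can still be resized.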
+
// See NOTE [ Autograd View Variables ] for details.
inline Variable make_variable_view(
Variable base,
at::Tensor data,
bool is_differentiable = true,
+ bool allow_tensor_metadata_change = true,
Edge gradient_edge = Edge()) {
if (data.defined()) {
if (is_differentiable) {
/// Differentiable view. Track history with DifferentiableViewImpl.
+ auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+ data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+ auto data_copy = at::Tensor(data_impl_copy);
+ auto diff_view_meta = c10::guts::make_unique<Variable::DifferentiableViewMeta>();
return Variable(c10::make_intrusive<Variable::DifferentiableViewImpl>(
- std::move(base), std::move(data), std::move(gradient_edge)));
+ std::move(base), std::move(data_copy), std::move(gradient_edge), std::move(diff_view_meta)));
} else {
/// Non-differentiable view. Just share version counter.
+ auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+ data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+ auto data_copy = at::Tensor(data_impl_copy);
+ auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
auto var = Variable(c10::make_intrusive<Variable::Impl>(
- std::move(data), false, std::move(gradient_edge)));
+ std::move(data_copy), std::move(autograd_meta), false, std::move(gradient_edge)));
var.set_version_counter(base.version_counter());
return var;
}
return Variable();
}
-inline Variable make_variable(at::Tensor data, bool requires_grad = false) {
+inline Variable make_variable(
+ at::Tensor data,
+ bool requires_grad = false,
+ bool allow_tensor_metadata_change = true) {
AT_CHECK(
!data.is_variable(),
"Must not create a new variable from a variable, use its .data()");
if (data.defined()) {
- return Variable(c10::make_intrusive<Variable::Impl>(data, requires_grad));
+ auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+ data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+ auto data_copy = at::Tensor(data_impl_copy);
+ auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
+ return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), requires_grad));
}
return Variable();
}
-inline Variable make_variable(at::Tensor data, Edge gradient_edge) {
+inline Variable make_variable(
+ at::Tensor data,
+ Edge gradient_edge,
+ bool allow_tensor_metadata_change = true) {
AT_CHECK(
!data.is_variable(),
"Must not create a new variable from a variable, use its .data()");
if (data.defined()) {
- return Variable(c10::make_intrusive<Variable::Impl>(data, false, std::move(gradient_edge)));
+ auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach();
+ data_impl_copy->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
+ auto data_copy = at::Tensor(data_impl_copy);
+ auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>();
+ return Variable(c10::make_intrusive<Variable::Impl>(data_copy, std::move(autograd_meta), false, std::move(gradient_edge)));
}
return Variable();
}
}
inline Function* Variable::grad_fn_unsafe() const {
- return get()->grad_fn_.get();
+ return get_autograd_meta()->grad_fn_.get();
}
inline void Variable::set_grad_accumulator(
std::weak_ptr<Function> grad_accumulator) {
- get()->grad_accumulator_ = std::move(grad_accumulator);
+ get_autograd_meta()->grad_accumulator_ = std::move(grad_accumulator);
}
inline std::shared_ptr<Function> Variable::try_get_grad_accumulator() const {
- return get()->grad_accumulator_.lock();
+ return get_autograd_meta()->grad_accumulator_.lock();
}
inline std::shared_ptr<Function> Variable::grad_accumulator() const {
}
inline Variable Variable::detach() const {
- return make_variable_view(*this, get()->data_, /*is_differentiable=*/false);
+ auto var = make_variable_view(*this, get()->data_, /*is_differentiable=*/false, /*allow_tensor_metadata_change=*/false, Edge());
+ return var;
}
inline void Variable::detach_() {
}
inline void Variable::set_gradient_edge(Edge edge) noexcept {
- get()->grad_fn_ = std::move(edge.function);
- get()->output_nr_ = edge.input_nr;
+ get_autograd_meta()->grad_fn_ = std::move(edge.function);
+ get_autograd_meta()->output_nr_ = edge.input_nr;
}
inline uint32_t Variable::output_nr() const noexcept {
- return get()->output_nr_;
+ return get_autograd_meta()->output_nr_;
}
inline bool Variable::is_leaf() const noexcept {
- return get()->grad_fn_ == nullptr;
+ return get_autograd_meta()->grad_fn_ == nullptr;
}
// Versions
inline void Variable::set_version_counter(
const VariableVersion& version_counter) noexcept {
- get()->version_counter_ = version_counter;
+ get_autograd_meta()->version_counter_ = version_counter;
}
inline void Variable::bump_version() noexcept {
- get()->version_counter_.bump();
+ get_autograd_meta()->version_counter_.bump();
}
inline uint32_t Variable::current_version() const noexcept {
- return get()->version_counter_.current_version();
+ return get_autograd_meta()->version_counter_.current_version();
}
inline const VariableVersion& Variable::version_counter() const noexcept {
- return get()->version_counter_;
+ return get_autograd_meta()->version_counter_;
}
// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::add_hook(std::shared_ptr<FunctionPreHook> hook) {
- get()->hooks_.push_back(std::move(hook));
+ get_autograd_meta()->hooks_.push_back(std::move(hook));
}
inline const std::vector<std::shared_ptr<FunctionPreHook>>& Variable::hooks()
const noexcept {
- return get()->hooks_;
+ return get_autograd_meta()->hooks_;
}
inline void Variable::clear_hooks() {
- get()->hooks_.clear();
+ get_autograd_meta()->hooks_.clear();
}
// View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline bool Variable::is_view() const noexcept {
- return get()->is_view_;
+ return get_autograd_meta()->is_view_;
}
inline const Variable& Variable::base() const {
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
inline void Variable::set_name(const std::string& name) {
- get()->name = name;
+ get_autograd_meta()->name = name;
}
inline const std::string& Variable::name() const noexcept {
- return get()->name;
+ return get_autograd_meta()->name;
}
inline void Variable::set_pyobj(PyObject* pyobj) noexcept {
- get()->pyobj_ = pyobj;
+ get_autograd_meta()->pyobj_ = pyobj;
}
inline PyObject* Variable::pyobj() const noexcept {
- return get()->pyobj_;
+ return get_autograd_meta()->pyobj_;
+}
+
+inline Variable::AutogradMeta* Variable::get_autograd_meta() const noexcept {
+ return get()->get_autograd_meta();
}
// Private Methods
.def(
"t_",
[](Node& n, const char* name, torch::autograd::Variable v) {
- return n.t_(Symbol::attr(name), std::move(v.data()));
+ return n.t_(Symbol::attr(name), v.data());
})
.def(
"t",
std::vector<at::Tensor> tensors;
tensors.reserve(vs.size());
for (auto& variable : vs) {
- tensors.push_back(std::move(variable.data()));
+ tensors.push_back(variable.data());
}
return n.ts_(Symbol::attr(name), std::move(tensors));
})
.. note::
- Returned Tensor uses the same data tensor as the original one.
+ Returned Tensor shares the same storage with the original one.
In-place modifications on either of them will be seen, and may trigger
errors in correctness checks.
+ IMPORTANT NOTE: Previously, in-place size / stride / storage changes
+ (such as `resize_` / `resize_as_` / `set_` / `transpose_`) to the returned tensor
+ also updated the original tensor. Now, these in-place changes no longer update
+ the original tensor, and instead trigger an error.
+ For sparse tensors:
+ In-place indices / values changes (such as `zero_` / `copy_` / `add_`) to the
+ returned tensor no longer update the original tensor, and instead trigger an error.
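+
+ Example (illustrative)::
+
+     >>> x = torch.tensor([1., 2.])
+     >>> y = x.detach()
+     >>> y.resize_(4)
+     RuntimeError: set_sizes_contiguous is not allowed on Tensor created from .data or .detach()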
""")
detach_ = _add_docstr(_C._TensorBase.detach_, r"""