From d701357d921ef167d42c125e65b6f7da6be3ad0f Mon Sep 17 00:00:00 2001
From: Peter Bell
Date: Wed, 8 Sep 2021 13:25:42 -0700
Subject: [PATCH] Factor out TensorBase that doesn't depend on native operators (#63612)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63612

This makes Tensor inherit from a new class TensorBase that provides a subset
of Tensor that doesn't directly depend on native_functions.yaml. Code that
only includes TensorBase.h will thus not need to be rebuilt every time someone
changes an operator signature.

Making `Tensor` inherit from this class means that functions taking
`const TensorBase&` parameters can be called with an ordinary `Tensor`. I've
also made `Tensor` constructible and assignable from `TensorBase` to minimize
friction in code mixing the two types.

To help enforce that `Tensor.h` and `Functions.h` aren't accidentally
included, I've added an error into `Operators.h` if `TORCH_ASSERT_NO_OPERATORS`
is defined. We can either set this in the build system for certain folders, or
just define it at the top of any file.

I've also included an example of manually special-casing the commonly used
`contiguous` operator. The inline function's slow path defers to
`TensorBase::__dispatch_contiguous`, which is defined in `Tensor.cpp`. I've
made it so `OptionalTensorRef` is constructible from `TensorBase`, so I can
materialize a `Tensor` for use in dispatch without actually increasing its
refcount.

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30728580

Pulled By: ezyang

fbshipit-source-id: 2cbc8eee08043382ee6904ea8e743b1286921c03
---
 aten/src/ATen/core/NamedTensor.cpp | 9 +-
 aten/src/ATen/core/NamedTensor.h | 8 +-
 aten/src/ATen/core/QuantizerBase.h | 4 +
 aten/src/ATen/core/Tensor.cpp | 47 +-
 aten/src/ATen/core/Tensor.h | 25 +-
 aten/src/ATen/core/TensorBase.h | 917 +++++++++++++++++++++
 aten/src/ATen/core/VariableHooksInterface.h | 34 +-
 aten/src/ATen/quantized/Quantizer.cpp | 4 +-
 aten/src/ATen/quantized/Quantizer.h | 2 +-
 aten/src/ATen/templates/Operators.h | 8 +
 aten/src/ATen/templates/TensorBody.h | 570 +------------
 aten/src/ATen/templates/TensorMethods.cpp | 4 +-
 c10/core/TensorImpl.cpp | 9 +-
 c10/core/TensorImpl.h | 13 +-
 caffe2/core/tensor.cc | 2 +-
 test/cpp/api/modules.cpp | 4 +-
 .../api/include/torch/nn/functional/activation.h | 4 +-
 torch/csrc/api/src/nn/modules/activation.cpp | 4 +-
 torch/csrc/autograd/autograd_meta.cpp | 16 +-
 torch/csrc/autograd/cpp_hook.cpp | 4 +-
 torch/csrc/autograd/cpp_hook.h | 2 +-
 torch/csrc/autograd/custom_function.cpp | 2 +-
 torch/csrc/autograd/custom_function.h | 4 +-
 torch/csrc/autograd/python_variable.cpp | 2 +-
 torch/csrc/autograd/python_variable.h | 2 +-
 torch/csrc/autograd/variable.cpp | 89 +-
 torch/csrc/autograd/variable.h | 16 +-
 27 files changed, 1150 insertions(+), 655 deletions(-)
 create mode 100644 aten/src/ATen/core/TensorBase.h

diff --git a/aten/src/ATen/core/NamedTensor.cpp b/aten/src/ATen/core/NamedTensor.cpp
index b394a8d..40fd58a 100644
--- a/aten/src/ATen/core/NamedTensor.cpp
+++ b/aten/src/ATen/core/NamedTensor.cpp
@@ -1,6 +1,7 @@
+#define TORCH_ASSERT_NO_OPERATORS
 #include
-#include
+#include
 #include

 namespace at {
@@ -16,12 +17,12 @@ void NamesMode::set_enabled(bool enabled) {
   c10::impl::tls_set_dispatch_key_excluded(DispatchKey::Named, !enabled);
 }

-const Tensor& internal_set_names_inplace(const Tensor& tensor, optional names) {
+const TensorBase& internal_set_names_inplace(const TensorBase& tensor, optional names) {
   impl::internal_set_names_inplace(tensor.unsafeGetTensorImpl(),
names, /*validate_names=*/true); return tensor; } -const Tensor& internal_set_names_inplace(const Tensor& tensor, std::vector&& names, bool validate_names) { +const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector&& names, bool validate_names) { impl::internal_set_names_inplace(tensor.unsafeGetTensorImpl(), std::move(names), validate_names); return tensor; } @@ -48,7 +49,7 @@ static void check_unique_names(DimnameList names) { } } -void check_names_valid_for(const Tensor& tensor, DimnameList names) { +void check_names_valid_for(const TensorBase& tensor, DimnameList names) { return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names); } diff --git a/aten/src/ATen/core/NamedTensor.h b/aten/src/ATen/core/NamedTensor.h index e343149..73a0d7d 100644 --- a/aten/src/ATen/core/NamedTensor.h +++ b/aten/src/ATen/core/NamedTensor.h @@ -6,6 +6,8 @@ namespace at { +class TensorBase; + // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen. // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname, // so we have a couple of workarounds. @@ -95,12 +97,12 @@ struct TORCH_API NoNamesGuard { bool initialized; }; -void check_names_valid_for(const Tensor& tensor, DimnameList names); +void check_names_valid_for(const TensorBase& tensor, DimnameList names); void check_names_valid_for(size_t tensor_dim, DimnameList names); // Sets the names of `tensor` to be `names`. -TORCH_API const Tensor& internal_set_names_inplace(const Tensor& tensor, c10::optional names); -TORCH_API const Tensor& internal_set_names_inplace(const Tensor& tensor, std::vector&& names, bool validate_names); +TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, c10::optional names); +TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector&& names, bool validate_names); constexpr size_t kMaxNamedTensorDim = 64; diff --git a/aten/src/ATen/core/QuantizerBase.h b/aten/src/ATen/core/QuantizerBase.h index cdc5b6f..2648223 100644 --- a/aten/src/ATen/core/QuantizerBase.h +++ b/aten/src/ATen/core/QuantizerBase.h @@ -1,5 +1,9 @@ #pragma once +#include +#include +#include + namespace at { class Tensor; diff --git a/aten/src/ATen/core/Tensor.cpp b/aten/src/ATen/core/Tensor.cpp index 4c7643c..98ab4af 100644 --- a/aten/src/ATen/core/Tensor.cpp +++ b/aten/src/ATen/core/Tensor.cpp @@ -6,7 +6,16 @@ namespace at { -void Tensor::enforce_invariants() { +const TensorBase& get_tensor_base(const Tensor &t) { + return t; +} + +TensorBase TensorBase::__dispatch_contiguous(c10::MemoryFormat memory_format) const { + OptionalTensorRef self(*this); + return at::_ops::contiguous::call(*self, memory_format); +} + +void TensorBase::enforce_invariants() { if (impl_.get() == nullptr) { throw std::runtime_error("TensorImpl with nullptr is not supported"); } @@ -26,7 +35,7 @@ void Tensor::enforce_invariants() { } } -void Tensor::print() const { +void TensorBase::print() const { if (defined()) { std::cerr << "[" << toString() << " " << sizes() << "]" << std::endl; } else { @@ -34,7 +43,7 @@ void Tensor::print() const { } } -std::string Tensor::toString() const { +std::string TensorBase::toString() const { std::string base_str; if (scalar_type() == ScalarType::Undefined) { base_str = "UndefinedType"; @@ -44,39 +53,39 @@ std::string Tensor::toString() const { return base_str; } -Tensor Tensor::variable_data() const { +TensorBase TensorBase::variable_data() const { return impl::GetVariableHooks()->variable_data(*this); } -Tensor 
Tensor::tensor_data() const { +TensorBase TensorBase::tensor_data() const { return impl::GetVariableHooks()->tensor_data(*this); } -bool Tensor::is_leaf() const { +bool TensorBase::is_leaf() const { return impl::GetVariableHooks()->is_leaf(*this); } -int64_t Tensor::output_nr() const { +int64_t TensorBase::output_nr() const { return impl::GetVariableHooks()->output_nr(*this); } -void Tensor::set_data(const Tensor & new_data) const { +void TensorBase::set_data(const TensorBase & new_data) const { impl::GetVariableHooks()->set_data(*this, new_data); } -Tensor Tensor::data() const { +TensorBase TensorBase::data() const { return impl::GetVariableHooks()->data(*this); } -int64_t Tensor::_version() const { +int64_t TensorBase::_version() const { return impl::GetVariableHooks()->_version(*this); } -void Tensor::retain_grad() const { +void TensorBase::retain_grad() const { impl::GetVariableHooks()->retain_grad(*this); } -bool Tensor::retains_grad() const { +bool TensorBase::retains_grad() const { return impl::GetVariableHooks()->retains_grad(*this); } @@ -87,7 +96,7 @@ void Tensor::_backward(TensorList inputs, return impl::GetVariableHooks()->_backward(*this, inputs, gradient, keep_graph, create_graph); } -const Tensor& Tensor::requires_grad_(bool _requires_grad) const { +const TensorBase& TensorBase::requires_grad_(bool _requires_grad) const { impl::GetVariableHooks()->requires_grad_(*this, _requires_grad); return *this; } @@ -95,27 +104,27 @@ const Tensor& Tensor::requires_grad_(bool _requires_grad) const { // View Variables //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -bool Tensor::is_view() const { +bool TensorBase::is_view() const { return impl::GetVariableHooks()->is_view(*this); } -const Tensor& Tensor::_base() const { +const TensorBase& TensorBase::_base() const { return impl::GetVariableHooks()->base(*this); } -const std::string& Tensor::name() const { +const std::string& TensorBase::name() const { return impl::GetVariableHooks()->name(*this); } -const std::shared_ptr& Tensor::grad_fn() const { +const std::shared_ptr& TensorBase::grad_fn() const { return impl::GetVariableHooks()->grad_fn(*this); } -void Tensor::remove_hook(unsigned pos) const { +void TensorBase::remove_hook(unsigned pos) const { impl::GetVariableHooks()->remove_hook(*this, pos); } -unsigned Tensor::_register_hook(std::function hook) const { +unsigned TensorBase::_register_hook(std::function hook) const { return impl::GetVariableHooks()->_register_hook(*this, std::move(hook)); } diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index fa2479c..5293928 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -12,7 +12,7 @@ class TORCH_API OptionalTensorRef { ref_.unsafeReleaseTensorImpl(); } - OptionalTensorRef(const Tensor& src) + OptionalTensorRef(const TensorBase& src) : ref_(Tensor::unsafe_borrow_t{}, src) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined()); } @@ -48,4 +48,27 @@ class TORCH_API OptionalTensorRef { private: Tensor ref_; }; + +template +auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t { + // Return the grad argument in case of a hook with void return type to have an + // std::function with Tensor return type + static_assert(std::is_same::value, + "Expected hook to return void"); + return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { + OptionalTensorRef grad(grad_base); + fn(*grad); + return Tensor(); + }); +} + +template +auto Tensor::register_hook(T&& hook) const -> 
Tensor::hook_return_var_t { + return _register_hook([fn=std::forward(hook)](const TensorBase& grad_base) { + OptionalTensorRef grad(grad_base); + Tensor ret = fn(*grad); + return TensorBase(std::move(ret)); + }); +} + } // namespace at diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h new file mode 100644 index 0000000..e91e9e1 --- /dev/null +++ b/aten/src/ATen/core/TensorBase.h @@ -0,0 +1,917 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace c10 { +struct TensorOptions; +} + +namespace torch { namespace autograd { + +struct Node; + +}} // namespace torch::autograd + +namespace at { + +class Tensor; +class TensorBase; + +// Convert Tensor to TensorBase without any need to include Tensor.h +TORCH_API const TensorBase& get_tensor_base(const Tensor& t); + +namespace impl { +inline bool variable_excluded_from_dispatch() { +#ifdef C10_MOBILE + // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change. + return true; +#else + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::Autograd)); + return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset); +#endif +} +} + +// NOTE: [Tensor vs. TensorBase] +// +// Tensor, being the central data structure in PyTorch, gets used and +// it's header included almost everywhere. Unfortunately this means +// every time an operator signature is updated or changed in +// native_functions.yaml, you (and every other PyTorch developer) need +// to recompile all of ATen and it's dependencies. +// +// TensorBase aims to break up these header dependencies, and improve +// incremental build times for all PyTorch developers. TensorBase +// represents a reference counted handle to TensorImpl, exactly the +// same as Tensor. However, TensorBase doesn't have code generated +// methods in it's API and thus no dependence on native_functions.yaml. +// +// Usage tips +// ---------- +// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp +// or .cu file to ensure it has no header dependencies on +// native_functions.yaml (direct or indirect). +// - Tensor inherits from TensorBase, so functions taking +// `const TensorBase &` are callable with Tensor as well. +// - TensorBase can be converted to tensor with `Tensor(tensor_base)`, +// but this requires a reference-count bump. OptionalTensorRef on +// the other hand can materialize a `const Tensor &` without +// touching the reference-count. +class TORCH_API TensorBase { + public: + struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; }; + + protected: + // Create a Tensor with a +0 reference count. Special care must be + // taken to avoid decrementing this reference count at destruction + // time. Intended to support MaybeOwnedTraits. + explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs) + : impl_(c10::intrusive_ptr::reclaim(rhs.impl_.get())) {} + friend MaybeOwnedTraits; + + public: + TensorBase() = default; + // This constructor should not be used by end users and is an implementation + // detail invoked by autogenerated code. 
+ explicit TensorBase( + c10::intrusive_ptr tensor_impl) + : impl_(std::move(tensor_impl)) { + if (impl_.get() == nullptr) { + throw std::runtime_error("TensorImpl with nullptr is not supported"); + } + } + TensorBase(const TensorBase&) = default; + TensorBase(TensorBase&&) = default; + + public: + // Creates a new wrapper from TensorImpl. Intentionally a free method because + // it should be used with care. Checks necessary invariants + static TensorBase wrap_tensor_impl( + c10::intrusive_ptr tensor_impl) { + TensorBase r(std::move(tensor_impl)); + r.enforce_invariants(); + return r; + } + + int64_t dim() const { + return impl_->dim(); + } + int64_t storage_offset() const { + return impl_->storage_offset(); + } + + TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { + if (is_contiguous(memory_format)) { + return *this; + } else { + return __dispatch_contiguous(memory_format); + } + } + + /// Should be used if *this can reasonably be expected to be contiguous and + /// performance is important. + /// Compared to contiguous, it saves a reference count + /// increment/decrement if *this is already contiguous, at the cost + /// in all cases of an extra pointer of stack usage, an extra branch + /// to access, and an extra branch at destruction time. + c10::MaybeOwned expect_contiguous( + MemoryFormat memory_format=MemoryFormat::Contiguous) const &; + + // Use .contiguous() instead. Trying to borrow from a prvalue + // will only lead to trouble and dangling references. + c10::MaybeOwned expect_contiguous( + MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; + + bool is_complex() const { + return at::isComplexType(this->scalar_type()); + } + + bool is_floating_point() const { + return at::isFloatingType(this->scalar_type()); + } + + bool is_signed() const { + return at::isSignedType(this->scalar_type()); + } + + int64_t size(int64_t dim) const { + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + dim = c10::maybe_wrap_dim(dim, this->dim(), false); + return sizes()[dim]; + } + + int64_t stride(int64_t dim) const { + // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) + dim = c10::maybe_wrap_dim(dim, this->dim(), false); + return strides()[dim]; + } + + TensorImpl * unsafeGetTensorImpl() const { + return impl_.get(); + } + TensorImpl * unsafeReleaseTensorImpl() { + return impl_.release(); + } + const c10::intrusive_ptr& getIntrusivePtr() const { + return impl_; + } + + c10::intrusive_ptr unsafeReleaseIntrusivePtr() { + return std::move(impl_); + } + + bool defined() const { + return impl_; + } + + void reset() { + impl_.reset(); + } + + TensorBase& operator=(const TensorBase& x) & { + impl_ = x.impl_; + return *this; + }; + TensorBase& operator=(TensorBase&& x) & { + impl_ = std::move(x.impl_); + return *this; + } + + // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here + TensorBase& operator=(const TensorBase&) && = delete; + TensorBase& operator=(TensorBase&&) && = delete; + + bool is_same(const TensorBase& other) const noexcept { + return impl_ == other.impl_; + } + size_t use_count() const noexcept { + return impl_.use_count(); + } + size_t weak_use_count() const noexcept { + return impl_.weak_use_count(); + } + + std::string toString() const; + + IntArrayRef sizes() const { + return impl_->sizes(); + } + IntArrayRef strides() const { + return impl_->strides(); + } + // See impl::get_opt_names in ATen/NamedTensor.h for docs. 
+ c10::optional opt_names() const { + return impl::get_opt_names(unsafeGetTensorImpl()); + } + // See impl::get_names in ATen/NamedTensor.h for docs. + DimnameList names() const { + return impl::get_names(unsafeGetTensorImpl()); + } + int64_t ndimension() const { + return dim(); + } + + bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { + return impl_->is_contiguous(memory_format); + } + + bool is_non_overlapping_and_dense() const { + return impl_->is_non_overlapping_and_dense(); + } + + at::MemoryFormat suggest_memory_format( + bool channels_last_strides_exact_match = false) const { + // Setting channels_last_strides_exact_match to true forces function to + // check 0,1 - sized dimension strides. + if (!is_mkldnn() && !is_sparse()) { + if (impl_->is_strides_like_channels_last()) { + if (!channels_last_strides_exact_match || + get_channels_last_strides_2d(sizes()) == strides()) { + return at::MemoryFormat::ChannelsLast; + } + } + else if (impl_->is_strides_like_channels_last_3d()) { + if (!channels_last_strides_exact_match || + get_channels_last_strides_3d(sizes()) == strides()) { + return at::MemoryFormat::ChannelsLast3d; + } + } + } + return at::MemoryFormat::Contiguous; + } + + // Total bytes consumed by the "view" of elements of the array. Does not + // include size of metadata. The number reported here does not necessarily + // correspond to the true physical memory consumed by a tensor; instead, + // it reports the memory the tensor would take *if* it were contiguous. + // Defined to be numel() * itemsize() + size_t nbytes() const { + TORCH_CHECK(layout () != at::kSparse, + "nbytes is not defined for sparse tensors. If you want the size of the constituent " \ + "tensors, add the nbytes of the indices and values. If you want the size of the " \ + "equivalent dense tensor, multiply numel() by element_size()"); + return impl_->numel() * impl_->itemsize(); + } + + int64_t numel() const { + return impl_->numel(); + } + + // Length of one array element in bytes. This is the traditional + // Numpy naming. + size_t itemsize() const { + return impl_->itemsize(); + } + + // Same as itemsize(). This is the PyTorch naming. + int64_t element_size() const { + return static_cast(impl_->itemsize()); + } + + DispatchKeySet key_set() const { + return impl_->key_set(); + } + ScalarType scalar_type() const { + return typeMetaToScalarType(impl_->dtype()); + } + bool has_storage() const { + return defined() && impl_->has_storage(); + } + const Storage& storage() const { + return impl_->storage(); + } + bool is_alias_of(const at::TensorBase& other) const{ + return impl_->storage().is_alias_of(other.storage()); + } + + inline bool is_conj() const { + return impl_->is_conj(); + } + + // sets the conjugate bit of a tensor. + // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure + // that's what you want. Changing this might lead to incorrect behavior since conjugation is + // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized. + inline void _set_conj(bool conjugate) const { + impl_->_set_conj(conjugate); + } + + inline bool is_neg() const { + return impl_->is_neg(); + } + + // sets the negative bit of a tensor. + // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure + // that's what you want. Changing this might lead to incorrect behavior since we rely on this + // bit to determine if a negation needs to be materialized. 
+ inline void _set_neg(bool negative) const { + impl_->_set_neg(negative); + } + + /// Returns a `Tensor`'s layout. + Layout layout() const noexcept { + return impl_->layout(); + } + + /// Returns a `Tensor`'s dtype (`TypeMeta`). + caffe2::TypeMeta dtype() const noexcept { + return impl_->dtype(); + } + + /// Returns a `Tensor`'s device. + inline Device device() const { + return impl_->device(); + } + + /// Returns a `Tensor`'s device index. + int64_t get_device() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->get_device(); + } + + /// Returns if a `Tensor` has CPU backend. + bool is_cpu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_cpu(); + } + + /// Returns if a `Tensor` has CUDA backend. + bool is_cuda() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_cuda(); + } + + /// Returns if a `Tensor` has XPU backend. + bool is_xpu() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_xpu(); + } + + /// Returns if a `Tensor` has XLA backend. + bool is_xla() const { + return impl_->is_xla(); + } + + /// Returns if a `Tensor` has Lazy backend. + bool is_lazy() const { + return impl_->is_lazy(); + } + + /// Returns if a `Tensor` has HIP backend. + bool is_hip() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_hip(); + } + + /// Returns if a `Tensor` has VE backend. + bool is_ve() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_ve(); + } + + /// Returns if a `Tensor` has sparse backend. + bool is_sparse() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_sparse(); + } + + /// Returns is a `Tensor` has a sparse CSR backend. + bool is_sparse_csr() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_sparse_csr(); + } + + /// Returns if a `Tensor` is mkldnn tensor. + bool is_mkldnn() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_mkldnn(); + } + + /// Returns if a `Tensor` is mlc tensor. + bool is_mlc() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_mlc(); + } + + /// Returns if a `Tensor` is ort tensor. + bool is_ort() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_ort(); + } + + /// Returns if a `Tensor` is vulkan tensor. + bool is_vulkan() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_vulkan(); + } + + /// Returns if a `Tensor` is metal tensor. + bool is_metal() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_metal(); + } + + /// Returns if a `Tensor` has quantized backend. + bool is_quantized() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_quantized(); + } + + /// Returns if a `Tensor` is a meta tensor. Meta tensors can + /// also have other designations. + bool is_meta() const { + return impl_->is_meta(); + } + + /// Returns if a `Tensor` is an inference tensor. 
+ bool is_inference() const { + return impl_->is_inference(); + } + + /// If a tensor is a quantized tensor, returns its quantizer + /// TODO: it's not in native_functions.yaml yet as it's not exposed to python + QuantizerPtr quantizer() const; + + /// Returns if a `Tensor` has any dimension names + bool has_names() const { + // If a user is using unnamed tensors, then we can short-circuit right here. + // Otherwise, impl::has_names attempts to retrieve names. + if (!impl_->has_named_tensor_meta()) { + return false; + } + return impl::has_names(unsafeGetTensorImpl()); + } + + /// Returns a `Tensor`'s dimension names data structure + const NamedTensorMeta* get_named_tensor_meta() const { + return static_cast(impl_->named_tensor_meta()); + } + + NamedTensorMeta* get_named_tensor_meta() { + return static_cast(impl_->named_tensor_meta()); + } + + /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in + /// TensorOptions.h. + TensorOptions options() const { + return TensorOptions().dtype(dtype()) + .device(device()) + .layout(layout()); + } + + void* data_ptr() const { + return this->unsafeGetTensorImpl()->data(); + } + + template + T * data_ptr() const; + + // Purposely not defined here to avoid inlining + void print() const; + + // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and + // dimension. + template + TensorAccessor accessor() const& { + static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); + return TensorAccessor(data_ptr(),sizes().data(),strides().data()); + } + template + TensorAccessor accessor() && = delete; + + // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and + // dimension. You can optionally specify RestrictPtrTraits as a template parameter to + // cast the data pointer to a __restrict__ pointer. + // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor + // as an argument. + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() const& { + static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); + TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); + return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); + } + template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> + GenericPackedTensorAccessor generic_packed_accessor() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor32 packed_accessor32() && = delete; + + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() const& { + return generic_packed_accessor(); + } + template class PtrTraits = DefaultPtrTraits> + PackedTensorAccessor64 packed_accessor64() && = delete; + + // ~~~~~ Autograd API ~~~~~ + + /// \fn bool is_leaf() const; + /// + /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention. + /// + /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were + /// created by the user. 
This means that they are not the result of an operation and so + /// `grad_fn()` is `nullptr`. + /// + /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`. + /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`. + /// + /// Example: + /// @code + /// auto a = torch::rand(10, torch::requires_grad()); + /// std::cout << a.is_leaf() << std::endl; // prints `true` + /// + /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA); + /// std::cout << b.is_leaf() << std::endl; // prints `false` + /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor + /// + /// auto c = torch::rand(10, torch::requires_grad()) + 2; + /// std::cout << c.is_leaf() << std::endl; // prints `false` + /// // c was created by the addition operation + /// + /// auto d = torch::rand(10).cuda(); + /// std::cout << d.is_leaf() << std::endl; // prints `true` + /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine) + /// + /// auto e = torch::rand(10).cuda().requires_grad_(); + /// std::cout << e.is_leaf() << std::endl; // prints `true` + /// // e requires gradients and has no operations creating it + /// + /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true)); + /// std::cout << f.is_leaf() << std::endl; // prints `true` + /// // f requires grad, has no operation creating it + /// @endcode + + /// \fn void backward(const Tensor & gradient={}, c10::optional retain_graph=c10::nullopt, bool create_graph=false, c10::optional inputs=c10::nullopt) const; + /// + /// Computes the gradient of current tensor with respect to graph leaves. + /// + /// The graph is differentiated using the chain rule. If the tensor is + /// non-scalar (i.e. its data has more than one element) and requires + /// gradient, the function additionally requires specifying ``gradient``. + /// It should be a tensor of matching type and location, that contains + /// the gradient of the differentiated function w.r.t. this Tensor. + /// + /// This function accumulates gradients in the leaves - you might need to + /// zero them before calling it. + /// + /// \param gradient Gradient w.r.t. the + /// tensor. If it is a tensor, it will be automatically converted + /// to a Tensor that does not require grad unless ``create_graph`` is True. + /// None values can be specified for scalar Tensors or ones that + /// don't require grad. If a None value would be acceptable then + /// this argument is optional. + /// \param retain_graph If ``false``, the graph used to compute + /// the grads will be freed. Note that in nearly all cases setting + /// this option to True is not needed and often can be worked around + /// in a much more efficient way. Defaults to the value of + /// ``create_graph``. + /// \param create_graph If ``true``, graph of the derivative will + /// be constructed, allowing to compute higher order derivative + /// products. Defaults to ``false``. + /// \param inputs Inputs w.r.t. which the gradient will be accumulated into + /// ``at::Tensor::grad``. All other Tensors will be ignored. If not + /// provided, the gradient is accumulated into all the leaf Tensors + /// that were used to compute the current tensor. + /// When inputs are provided and a given input is not a leaf, + /// the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients). + /// It is an implementation detail on which the user should not rely. 
+ /// See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. + + /// \fn Tensor detach() const; + /// + /// Returns a new Tensor, detached from the current graph. + /// The result will never require gradient. + + /// \fn Tensor & detach_() const; + /// + /// Detaches the Tensor from the graph that created it, making it a leaf. + /// Views cannot be detached in-place. + + /// \fn void retain_grad() const; + /// + /// Enables this Tensor to have their :attr:`grad` populated during + /// :func:`backward`. This is a no-op for leaf tensors. + + /// \fn bool retains_grad() const; + /// + /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be + /// populated during :func:`backward`, ``false`` otherwise. + + const TensorBase& set_requires_grad(bool requires_grad) const { + impl_->set_requires_grad(requires_grad); + return *this; + } + bool requires_grad() const { + return impl_->requires_grad(); + } + + /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended + /// to be used from functions that need to access the `Variable`'s equivalent `Tensor` + /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`). + /// + /// One notable difference with the legacy `.data()` function is that changes to the + /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset) + /// will not update the original `Variable`, due to the fact that this function + /// shallow-copies the `Variable`'s underlying TensorImpl. + at::TensorBase tensor_data() const; + + /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data` + /// in Python, which create a new `Variable` that shares the same storage and + /// tensor metadata with the original `Variable`, but with a completely new + /// autograd history. + /// + /// NOTE: If we change the tensor metadata (e.g. sizes / strides / + /// storage / storage_offset) of a variable created from `var.variable_data()`, those + /// changes will not update the original variable `var`. In `.variable_data()`, we set + /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal, + /// in order to prevent users from changing metadata of `var.variable_data()` + /// and expecting the original variable `var` to also be updated. + at::TensorBase variable_data() const; + + // Gradient Node and Edges + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// Gets the gradient function of the `Variable`. If this is a leaf variable, + /// the pointer returned will be null. + /// + /// For View Variables: + /// Gets the up-to-date grad_fn. If the shared data or base was modified, we + /// re-create the grad_fn to express the up-to-date view relationship between + /// this and the base Variable. + const std::shared_ptr& grad_fn() const; + + // Hooks + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + template + using hook_return_void_t = std::enable_if_t::type>::value, unsigned>; + template + using hook_return_var_t = std::enable_if_t::type, TensorBase>::value, unsigned>; + + /// Registers a backward hook. + /// + /// The hook will be called every time a gradient with respect to the Tensor is computed. 
+ /// The hook should have one of the following signature: + /// ``` + /// hook(TensorBase grad) -> TensorBase + /// ``` + /// ``` + /// hook(TensorBase grad) -> void + /// ``` + /// The hook should not modify its argument, but it can optionally return a new gradient + /// which will be used in place of `grad`. + /// + /// This function returns the index of the hook in the list which can be used to remove hook. + /// + /// Example: + /// @code + /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad()); + /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient + /// v.backward(torch::tensor({1., 2., 3.})); + /// // This prints: + /// // ``` + /// // 2 + /// // 4 + /// // 6 + /// // [ CPUFloatType{3} ] + /// // ``` + /// std::cout << v.grad() << std::endl; + /// v.remove_hook(h); // removes the hook + /// @endcode + template + hook_return_void_t register_hook(T&& hook) const; + template + hook_return_var_t register_hook(T&& hook) const; + +protected: + unsigned _register_hook(std::function hook) const; + +public: + + /// Remove hook at given position + void remove_hook(unsigned pos) const; + + // Variable methods + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + bool is_leaf() const; + + int64_t output_nr() const; + + void set_data(const TensorBase & new_data) const; + + TensorBase data() const; + + int64_t _version() const; + + void retain_grad() const; + + bool retains_grad() const; + + const TensorBase& requires_grad_(bool _requires_grad=true) const; + + // View Variables + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// Returns true if this `Variable` is a view of another `Variable`. + bool is_view() const; + + /// Returns the `Variable` that this `Variable` is a view of. If this + /// `Variable` is not a view, throw a `std::runtime_error`. + const TensorBase& _base() const; + + // Miscellaneous + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + const std::string& name() const; + +protected: + void enforce_invariants(); + c10::intrusive_ptr impl_; + +private: + TensorBase __dispatch_contiguous(c10::MemoryFormat) const; +}; + +// For "multiple ... operators specified" warnings, closing brace of class +// declaration must be included between pragma push & pop +#ifdef _MSC_VER +#pragma warning( pop ) +#endif + +inline int64_t get_device(const TensorBase& self) { + return self.get_device(); +} + +template +auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t { + // Return the grad argument in case of a hook with void return type to have an + // std::function with Tensor return type + static_assert(std::is_same::value, + "Expected hook to return void"); + return _register_hook([fn=std::forward(hook)](const TensorBase& grad) { + fn(grad); + return TensorBase(); + }); +} + +template +auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t { + return _register_hook(std::move(hook)); +} + +namespace detail { +// Helper creator for Tensor class which doesn't requires the users to pass +// in an intrusive_ptr instead it just converts the argument passed to +// requested intrusive_ptr type. +template +TensorBase make_tensor_base(Args&&... 
args) { + return TensorBase(c10::make_intrusive(std::forward(args)...)); +} + +} // namespace detail + +static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) { + return legacyExtractDispatchKey(t.key_set()); +} + +} // namespace at + +namespace c10 { +template <> +struct MaybeOwnedTraits { + using owned_type = at::TensorBase; + using borrow_type = at::TensorBase; + + static borrow_type createBorrow(const owned_type& from) { + // NOTE: this can be implemented without the special + // unsafe_borrow_t Tensor constructor as + // + // return borrow_type(c10::intrusive_ptr::reclaim(from.unsafeGetTensorImpl())); + // + // but that hurts inlining due to the nullptr check in the + // Tensor(c10::intrusive_ptr<...>) constructor. We already know + // that from.impl_ isn't null because from is a valid Tensor, so + // we needn't do the check again. (using __builtin_assume can + // avoid this, but wouldn't be portable to MSVC.) + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) { + lhs.unsafeReleaseTensorImpl(); + // See above note: this can be implemented with public API + // similarly to createBorrow(), but that would hurt inlining. + lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + static void destroyBorrow(borrow_type& toDestroy) { + toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0. + } + + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + static bool debugBorrowIsValid(const borrow_type& borrow) { + return true; + } +}; + +template <> +struct ExclusivelyOwnedTraits { + using repr_type = at::TensorBase; + using pointer_type = at::TensorBase*; + using const_pointer_type = const at::TensorBase*; + + static repr_type nullRepr() { + return at::TensorBase(); + } + + template + static repr_type createInPlace(Args&&... args) { + return at::TensorBase(std::forward(args)...); + } + + static repr_type moveToRepr(at::TensorBase&& x) { + return std::move(x); + } + + static void destroyOwned(at::TensorBase& x) { + TensorImpl*const toDestroy = x.unsafeReleaseTensorImpl(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); + // May be 0 because UndefinedTensorImpl doesn't get its refcount + // incremented. + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()), + "ExclusivelyOwned destroyed with refcount ", toDestroy->refcount_, ", expected 1!"); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + toDestroy->weakcount_ == 1 || (toDestroy->weakcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()), + "ExclusivelyOwned destroyed with weakcount ", toDestroy->weakcount_, ", expected 1!"); + if (toDestroy != UndefinedTensorImpl::singleton()) { +#ifndef NDEBUG + // Needed to pass the debug assertions in ~intrusive_ptr_target. 
+ toDestroy->refcount_ = 0; + toDestroy->weakcount_ = 0; +#endif + toDestroy->release_resources(); + delete toDestroy; + } + } + + static at::TensorBase take(at::TensorBase& x) { + return std::move(x); + } + + static pointer_type getImpl(repr_type& x) { + return &x; + } + + static const_pointer_type getImpl(const repr_type& x) { + return &x; + } +}; +} // namespace c10 + +namespace at { + +inline c10::MaybeOwned borrow_from_optional_tensor( + const c10::optional& opt) { + return opt.has_value() + ? c10::MaybeOwned::borrowed(*opt) + : c10::MaybeOwned::owned(c10::in_place); +} + +inline c10::MaybeOwned TensorBase::expect_contiguous(MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); + } +} +} // namespace at diff --git a/aten/src/ATen/core/VariableHooksInterface.h b/aten/src/ATen/core/VariableHooksInterface.h index f365c8d..9b0067b 100644 --- a/aten/src/ATen/core/VariableHooksInterface.h +++ b/aten/src/ATen/core/VariableHooksInterface.h @@ -40,23 +40,25 @@ namespace impl { struct TORCH_API VariableHooksInterface { virtual ~VariableHooksInterface() = default; - virtual Tensor tensor_data(const Tensor&) const = 0; - virtual Tensor variable_data(const Tensor&) const = 0; - virtual const std::shared_ptr& grad_fn(const Tensor&) const = 0; - virtual unsigned _register_hook(const Tensor&, std::function hook) const = 0; - virtual void remove_hook(const Tensor&, unsigned pos) const = 0; - virtual bool is_view(const Tensor&) const = 0; - virtual const Tensor& base(const Tensor&) const = 0; - virtual const std::string& name(const Tensor&) const = 0; - virtual bool is_leaf(const Tensor&) const = 0; - virtual int64_t output_nr(const Tensor&) const = 0; - virtual void set_data(const Tensor&, const Tensor&) const = 0; - virtual Tensor data(const Tensor&) const = 0; - virtual int64_t _version(const Tensor&) const = 0; - virtual void retain_grad(const Tensor&) const = 0; - virtual bool retains_grad(const Tensor&) const = 0; + virtual TensorBase tensor_data(const TensorBase&) const = 0; + virtual TensorBase variable_data(const TensorBase&) const = 0; + virtual const std::shared_ptr& grad_fn(const TensorBase&) const = 0; + virtual unsigned _register_hook( + const TensorBase&, + std::function hook) const = 0; + virtual void remove_hook(const TensorBase&, unsigned pos) const = 0; + virtual bool is_view(const TensorBase&) const = 0; + virtual const TensorBase& base(const TensorBase&) const = 0; + virtual const std::string& name(const TensorBase&) const = 0; + virtual bool is_leaf(const TensorBase&) const = 0; + virtual int64_t output_nr(const TensorBase&) const = 0; + virtual void set_data(const TensorBase&, const TensorBase&) const = 0; + virtual TensorBase data(const TensorBase&) const = 0; + virtual int64_t _version(const TensorBase&) const = 0; + virtual void retain_grad(const TensorBase&) const = 0; + virtual bool retains_grad(const TensorBase&) const = 0; virtual void _backward(const Tensor&, TensorList, const c10::optional&, c10::optional, bool) const = 0; - virtual void requires_grad_(const Tensor&, bool) const = 0; + virtual void requires_grad_(const TensorBase&, bool) const = 0; }; TORCH_API void SetVariableHooks(VariableHooksInterface* hooks); diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index b4a4f3a..b574dcf 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ 
-29,7 +29,7 @@ namespace { } // anonymous namespace // Note: this is not a native function as Quantizer is not exposed to python yet -QuantizerPtr Tensor::quantizer() const { +QuantizerPtr TensorBase::quantizer() const { // This is a terrible hack to emulate what VariableType is doing at::AutoDispatchBelowAutograd mode; return get_qtensorimpl(*this)->quantizer(); @@ -71,7 +71,7 @@ QuantizerPtr make_per_channel_affine_quantizer( } } -QTensorImpl* get_qtensorimpl(const Tensor& self) { +QTensorImpl* get_qtensorimpl(const TensorBase& self) { TORCH_CHECK( !self.requires_grad(), "quantized tensors do not support autograd"); diff --git a/aten/src/ATen/quantized/Quantizer.h b/aten/src/ATen/quantized/Quantizer.h index c0c119c..34ec274 100644 --- a/aten/src/ATen/quantized/Quantizer.h +++ b/aten/src/ATen/quantized/Quantizer.h @@ -205,7 +205,7 @@ struct TORCH_API PerChannelAffineFloatQParamsQuantizer : public PerChannelAffine // setters/getters for QTensorImpl fields; otherwise, you should use // the low level setters/getters that were implemented using this. // This may be called repeatedly, so make sure it's pretty cheap. -TORCH_API QTensorImpl* get_qtensorimpl(const Tensor& self); +TORCH_API QTensorImpl* get_qtensorimpl(const TensorBase& self); // double and int64_t are because of the native function API, we only have these // argument types right now in native functions diff --git a/aten/src/ATen/templates/Operators.h b/aten/src/ATen/templates/Operators.h index a92b750..6897c32 100644 --- a/aten/src/ATen/templates/Operators.h +++ b/aten/src/ATen/templates/Operators.h @@ -2,6 +2,14 @@ // ${generated_comment} +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + #include #include #include diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 95312ff..eb6402f 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -23,13 +23,10 @@ #include #include #include +#include #include -namespace caffe2 { -class Tensor; -} namespace c10{ -struct TensorOptions; template class List; } namespace at { @@ -58,18 +55,6 @@ using TensorList = ArrayRef; using Stream = c10::Stream; -namespace impl { -inline bool variable_excluded_from_dispatch() { -#ifdef C10_MOBILE - // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change. - return true; -#else - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::Autograd)); - return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset); -#endif -} -} - // Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which // has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr. // @@ -87,56 +72,38 @@ inline bool variable_excluded_from_dispatch() { // // Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and // special care must be taken to handle this. -class TORCH_API Tensor { - private: - struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; }; - +class TORCH_API Tensor: public TensorBase { + protected: // Create a Tensor with a +0 reference count. 
Special care must be // taken to avoid decrementing this reference count at destruction // time. Intended to support MaybeOwnedTraits. - explicit Tensor(unsafe_borrow_t, const Tensor& rhs) - : impl_(c10::intrusive_ptr::reclaim(rhs.impl_.get())) {} + explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {} friend MaybeOwnedTraits; friend OptionalTensorRef; public: - Tensor(){}; + Tensor() = default; // This constructor should not be used by end users and is an implementation // detail invoked by autogenerated code. explicit Tensor( c10::intrusive_ptr tensor_impl) - : impl_(std::move(tensor_impl)) { - if (impl_.get() == nullptr) { - throw std::runtime_error("TensorImpl with nullptr is not supported"); - } - } - Tensor(const Tensor&) = default; - Tensor(Tensor&&) = default; + : TensorBase(std::move(tensor_impl)) {} + Tensor(const Tensor &tensor) = default; + Tensor(Tensor &&tensor) = default; + // Implicitly move-constructible from TensorBase, but must be explicit to increase refcount + explicit Tensor(const TensorBase &base): TensorBase(base) {} + /*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {} - public: // Creates a new wrapper from TensorImpl. Intentionally a free method because // it should be used with care. Checks necessary invariants static Tensor wrap_tensor_impl( c10::intrusive_ptr tensor_impl) { - Tensor r(std::move(tensor_impl)); - r.enforce_invariants(); - return r; - } - - int64_t dim() const { - return impl_->dim(); - } - int64_t storage_offset() const { - return impl_->storage_offset(); + return TensorBase::wrap_tensor_impl(std::move(tensor_impl)); } Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const { - if (is_contiguous(memory_format)) { - return *this; - } else { - return __dispatch_contiguous(memory_format); - } + return TensorBase::contiguous(memory_format); } Tensor conj() const { @@ -150,6 +117,10 @@ class TORCH_API Tensor { } } + // Aliased by Dimname overloads, so need explicit using + using TensorBase::size; + using TensorBase::stride; + /// Should be used if *this can reasonably be expected to be contiguous and /// performance is important. /// Compared to contiguous, it saves a reference count @@ -162,52 +133,6 @@ class TORCH_API Tensor { // will only lead to trouble and dangling references. 
c10::MaybeOwned expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete; - bool is_complex() const { - return at::isComplexType(this->scalar_type()); - } - - bool is_floating_point() const { - return at::isFloatingType(this->scalar_type()); - } - - bool is_signed() const { - return at::isSignedType(this->scalar_type()); - } - - int64_t size(int64_t dim) const { - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - dim = c10::maybe_wrap_dim(dim, this->dim(), false); - return sizes()[dim]; - } - - int64_t stride(int64_t dim) const { - // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping) - dim = c10::maybe_wrap_dim(dim, this->dim(), false); - return strides()[dim]; - } - - TensorImpl * unsafeGetTensorImpl() const { - return impl_.get(); - } - TensorImpl * unsafeReleaseTensorImpl() { - return impl_.release(); - } - const c10::intrusive_ptr& getIntrusivePtr() const { - return impl_; - } - - c10::intrusive_ptr unsafeReleaseIntrusivePtr() { - return std::move(impl_); - } - - bool defined() const { - return impl_; - } - - void reset() { - impl_.reset(); - } - // The following overloads are very intruiging. Consider the following // program: // @@ -248,105 +173,25 @@ class TORCH_API Tensor { #pragma warning( disable : 4522 ) #endif - Tensor& operator=(const Tensor& x) & { - impl_ = x.impl_; + Tensor& operator=(const TensorBase& x) & { + impl_ = x.getIntrusivePtr(); return *this; } - Tensor& operator=(Tensor&& x) & { - impl_ = std::move(x.impl_); + Tensor& operator=(TensorBase&& x) & { + impl_ = x.unsafeReleaseIntrusivePtr(); return *this; } - Tensor& operator=(Scalar v) &&; - Tensor& operator=(const Tensor&) &&; - Tensor& operator=(Tensor&&) &&; - - bool is_same(const Tensor& other) const noexcept { - return impl_ == other.impl_; - } - size_t use_count() const noexcept { - return impl_.use_count(); + Tensor& operator=(const Tensor &x) & { + return operator=(static_cast(x)); } - size_t weak_use_count() const noexcept { - return impl_.weak_use_count(); + Tensor& operator=(Tensor &&x) & { + return operator=(static_cast(x)); } - std::string toString() const; - - IntArrayRef sizes() const { - return impl_->sizes(); - } - IntArrayRef strides() const { - return impl_->strides(); - } - // See impl::get_opt_names in ATen/NamedTensor.h for docs. - c10::optional opt_names() const { - return impl::get_opt_names(unsafeGetTensorImpl()); - } - // See impl::get_names in ATen/NamedTensor.h for docs. - DimnameList names() const { - return impl::get_names(unsafeGetTensorImpl()); - } - int64_t ndimension() const { - return dim(); - } - - bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const { - return impl_->is_contiguous(memory_format); - } - - bool is_non_overlapping_and_dense() const { - return impl_->is_non_overlapping_and_dense(); - } - - at::MemoryFormat suggest_memory_format( - bool channels_last_strides_exact_match = false) const { - // Setting channels_last_strides_exact_match to true forces function to - // check 0,1 - sized dimension strides. 
- if (!is_mkldnn() && !is_sparse()) { - if (impl_->is_strides_like_channels_last()) { - if (!channels_last_strides_exact_match || - get_channels_last_strides_2d(sizes()) == strides()) { - return at::MemoryFormat::ChannelsLast; - } - } - else if (impl_->is_strides_like_channels_last_3d()) { - if (!channels_last_strides_exact_match || - get_channels_last_strides_3d(sizes()) == strides()) { - return at::MemoryFormat::ChannelsLast3d; - } - } - } - return at::MemoryFormat::Contiguous; - } - - // Total bytes consumed by the "view" of elements of the array. Does not - // include size of metadata. The number reported here does not necessarily - // correspond to the true physical memory consumed by a tensor; instead, - // it reports the memory the tensor would take *if* it were contiguous. - // Defined to be numel() * itemsize() - size_t nbytes() const { - TORCH_CHECK(layout () != at::kSparse, - "nbytes is not defined for sparse tensors. If you want the size of the constituent " \ - "tensors, add the nbytes of the indices and values. If you want the size of the " \ - "equivalent dense tensor, multiply numel() by element_size()"); - return impl_->numel() * impl_->itemsize(); - } - - int64_t numel() const { - return impl_->numel(); - } - - // Length of one array element in bytes. This is the traditional - // Numpy naming. - size_t itemsize() const { - return impl_->itemsize(); - } - - // Same as itemsize(). This is the PyTorch naming. - int64_t element_size() const { - return static_cast(impl_->itemsize()); - } + Tensor& operator=(Scalar v) &&; + Tensor& operator=(const Tensor&) &&; + Tensor& operator=(Tensor&&) &&; C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().") DeprecatedTypeProperties & type() const { @@ -354,21 +199,6 @@ class TORCH_API Tensor { dispatchKeyToBackend(legacyExtractDispatchKey(key_set())), scalar_type()); } - DispatchKeySet key_set() const { - return impl_->key_set(); - } - ScalarType scalar_type() const { - return typeMetaToScalarType(impl_->dtype()); - } - bool has_storage() const { - return defined() && impl_->has_storage(); - } - const Storage& storage() const { - return impl_->storage(); - } - bool is_alias_of(const at::Tensor& other) const{ - return impl_->storage().is_alias_of(other.storage()); - } Tensor toType(ScalarType t) const { return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false); @@ -384,189 +214,6 @@ class TORCH_API Tensor { return !at::impl::variable_excluded_from_dispatch(); } - inline bool is_conj() const { - return impl_->is_conj(); - } - - // sets the conjugate bit of a tensor. - // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure - // that's what you want. Changing this might lead to incorrect behavior since conjugation is - // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized. - inline void _set_conj(bool conjugate) const { - impl_->_set_conj(conjugate); - } - - inline bool is_neg() const { - return impl_->is_neg(); - } - - // sets the negative bit of a tensor. - // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure - // that's what you want. 
-  // bit to determine if a negation needs to be materialized.
-  inline void _set_neg(bool negative) const {
-    impl_->_set_neg(negative);
-  }
-
-  /// Returns a `Tensor`'s layout.
-  Layout layout() const noexcept {
-    return impl_->layout();
-  }
-
-  /// Returns a `Tensor`'s dtype (`TypeMeta`).
-  caffe2::TypeMeta dtype() const noexcept {
-    return impl_->dtype();
-  }
-
-  /// Returns a `Tensor`'s device.
-  inline Device device() const {
-    return impl_->device();
-  }
-
-  /// Returns a `Tensor`'s device index.
-  int64_t get_device() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->get_device();
-  }
-
-
-  /// Returns if a `Tensor` has CPU backend.
-  bool is_cpu() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_cpu();
-  }
-
-  /// Returns if a `Tensor` has CUDA backend.
-  bool is_cuda() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_cuda();
-  }
-
-  /// Returns if a `Tensor` has XPU backend.
-  bool is_xpu() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_xpu();
-  }
-
-  /// Returns if a `Tensor` has XLA backend.
-  bool is_xla() const {
-    return impl_->is_xla();
-  }
-
-  /// Returns if a `Tensor` has Lazy backend.
-  bool is_lazy() const {
-    return impl_->is_lazy();
-  }
-
-  /// Returns if a `Tensor` has HIP backend.
-  bool is_hip() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_hip();
-  }
-
-  /// Returns if a `Tensor` has VE backend.
-  bool is_ve() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_ve();
-  }
-
-  /// Returns if a `Tensor` has sparse backend.
-  bool is_sparse() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_sparse();
-  }
-
-  /// Returns if a `Tensor` has a sparse CSR backend.
-  bool is_sparse_csr() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_sparse_csr();
-  }
-
-  /// Returns if a `Tensor` is mkldnn tensor.
-  bool is_mkldnn() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_mkldnn();
-  }
-
-  /// Returns if a `Tensor` is mlc tensor.
-  bool is_mlc() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_mlc();
-  }
-
-  /// Returns if a `Tensor` is ort tensor.
-  bool is_ort() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_ort();
-  }
-
-  /// Returns if a `Tensor` is vulkan tensor.
-  bool is_vulkan() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_vulkan();
-  }
-
-  /// Returns if a `Tensor` is metal tensor.
-  bool is_metal() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_metal();
-  }
-
-  /// Returns if a `Tensor` has quantized backend.
-  bool is_quantized() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_quantized();
-  }
-
-  /// Returns if a `Tensor` is a meta tensor. Meta tensors can
-  /// also have other designations.
-  bool is_meta() const {
-    return impl_->is_meta();
-  }
-
-  /// Returns if a `Tensor` is an inference tensor.
- bool is_inference() const { - return impl_->is_inference(); - } - - /// If a tensor is a quantized tensor, returns its quantizer - /// TODO: it's not in native_functions.yaml yet as it's not exposed to python - QuantizerPtr quantizer() const; - - /// Returns if a `Tensor` has any dimension names - bool has_names() const { - // If a user is using unnamed tensors, then we can short-circuit right here. - // Otherwise, impl::has_names attempts to retrieve names. - if (!impl_->has_named_tensor_meta()) { - return false; - } - return impl::has_names(unsafeGetTensorImpl()); - } - - /// Returns a `Tensor`'s dimension names data structure - const NamedTensorMeta* get_named_tensor_meta() const { - return static_cast(impl_->named_tensor_meta()); - } - - NamedTensorMeta* get_named_tensor_meta() { - return static_cast(impl_->named_tensor_meta()); - } - - /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in - /// TensorOptions.h. - TensorOptions options() const { - return TensorOptions().dtype(dtype()) - .device(device()) - .layout(layout()); - } - - void* data_ptr() const { - return this->unsafeGetTensorImpl()->data(); - } - - template - T * data_ptr() const; - template C10_DEPRECATED_MESSAGE("Tensor.data() is deprecated. Please use Tensor.data_ptr() instead.") T * data() const { @@ -579,45 +226,6 @@ class TORCH_API Tensor { // Purposely not defined here to avoid inlining void print() const; - // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and - // dimension. - template - TensorAccessor accessor() const& { - static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); - TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); - return TensorAccessor(data_ptr(),sizes().data(),strides().data()); - } - template - TensorAccessor accessor() && = delete; - - // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and - // dimension. You can optionally specify RestrictPtrTraits as a template parameter to - // cast the data pointer to a __restrict__ pointer. - // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor - // as an argument. 
- template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - GenericPackedTensorAccessor generic_packed_accessor() const& { - static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr()"); - TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim()); - return GenericPackedTensorAccessor(static_cast::PtrType>(data_ptr()),sizes().data(),strides().data()); - } - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> - GenericPackedTensorAccessor generic_packed_accessor() && = delete; - - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor32 packed_accessor32() const& { - return generic_packed_accessor(); - } - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor32 packed_accessor32() && = delete; - - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor64 packed_accessor64() const& { - return generic_packed_accessor(); - } - template class PtrTraits = DefaultPtrTraits> - PackedTensorAccessor64 packed_accessor64() && = delete; - template class PtrTraits = DefaultPtrTraits, typename index_t = int64_t> C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead") GenericPackedTensorAccessor packed_accessor() const & { @@ -785,12 +393,9 @@ class TORCH_API Tensor { /// populated during :func:`backward`, ``false`` otherwise. const Tensor& set_requires_grad(bool requires_grad) const { - impl_->set_requires_grad(requires_grad); + TensorBase::set_requires_grad(requires_grad); return *this; } - bool requires_grad() const { - return impl_->requires_grad(); - } /// Return a mutable reference to the gradient. This is conventionally /// used as `t.grad() = x` to set a gradient to a completely new tensor. @@ -829,7 +434,7 @@ class TORCH_API Tensor { /// Note that the given new_grad might not be used directly if it has different /// metadata (size/stride/storage offset) compared to this Tensor. In that case, /// new_grad content will be copied into a new Tensor - void _set_fw_grad(const Tensor& new_grad, uint64_t level, bool is_inplace_op) const { + void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const { impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op); } @@ -878,7 +483,9 @@ class TORCH_API Tensor { /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset) /// will not update the original `Variable`, due to the fact that this function /// shallow-copies the `Variable`'s underlying TensorImpl. - at::Tensor tensor_data() const; + at::Tensor tensor_data() const { + return TensorBase::tensor_data(); + } /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data` /// in Python, which create a new `Variable` that shares the same storage and @@ -891,19 +498,9 @@ class TORCH_API Tensor { /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal, /// in order to prevent users from changing metadata of `var.variable_data()` /// and expecting the original variable `var` to also be updated. - at::Tensor variable_data() const; - - // Gradient Node and Edges - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - /// Gets the gradient function of the `Variable`. If this is a leaf variable, - /// the pointer returned will be null. - /// - /// For View Variables: - /// Gets the up-to-date grad_fn. 
If the shared data or base was modified, we - /// re-create the grad_fn to express the up-to-date view relationship between - /// this and the base Variable. - const std::shared_ptr& grad_fn() const; + at::Tensor variable_data() const { + return TensorBase::variable_data(); + } // Hooks //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -948,55 +545,19 @@ class TORCH_API Tensor { template hook_return_var_t register_hook(T&& hook) const; -private: - unsigned _register_hook(std::function hook) const; - -public: - - /// Remove hook at given position - void remove_hook(unsigned pos) const; - // Variable methods //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - bool is_leaf() const; - - int64_t output_nr() const; - - void set_data(const Tensor & new_data) const; - - Tensor data() const; - - int64_t _version() const; - - void retain_grad() const; - - bool retains_grad() const; + Tensor data() const { + return TensorBase::data(); + } void _backward(TensorList inputs, const c10::optional& gradient, c10::optional keep_graph, bool create_graph) const; - const Tensor& requires_grad_(bool _requires_grad=true) const; - - // View Variables - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - /// Returns true if this `Variable` is a view of another `Variable`. - bool is_view() const; - - /// Returns the `Variable` that this `Variable` is a view of. If this - /// `Variable` is not a view, throw a `std::runtime_error`. - const Tensor& _base() const; - - // Miscellaneous - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - const std::string& name() const; - -protected: - friend class ::caffe2::Tensor; - - void enforce_invariants(); - c10::intrusive_ptr impl_; + const Tensor& requires_grad_(bool _requires_grad=true) const { + TensorBase::requires_grad_(_requires_grad); + return *this; + } }; // For "multiple ... operators specified" warnings, closing brace of class @@ -1005,27 +566,6 @@ protected: #pragma warning( pop ) #endif -inline int64_t get_device(const Tensor& self) { - return self.get_device(); -} - -template -auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t { - // Return the grad argument in case of a hook with void return type to have an - // std::function with Tensor return type - static_assert(std::is_same::value, - "Expected hook to return void"); - return _register_hook([fn=std::forward(hook)](const Tensor& grad) { - fn(grad); - return Tensor(); - }); -} - -template -auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t { - return _register_hook(std::forward(hook)); -} - namespace detail { // Helper creator for Tensor class which doesn't requires the users to pass // in an intrusive_ptr instead it just converts the argument passed to @@ -1037,10 +577,6 @@ Tensor make_tensor(Args&&... args) { } // namespace detail -static inline DispatchKey legacyExtractDispatchKey(const Tensor& t) { - return legacyExtractDispatchKey(t.key_set()); -} - } // namespace at // See Note [Avoiding Include Cycles In Static Dispatch] @@ -1114,25 +650,7 @@ struct ExclusivelyOwnedTraits { } static void destroyOwned(at::Tensor& x) { - TensorImpl*const toDestroy = x.unsafeReleaseTensorImpl(); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(toDestroy != nullptr, "Tensor somehow got null TensorImpl?"); - // May be 0 because UndefinedTensorImpl doesn't get its refcount - // incremented. 
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()), - "ExclusivelyOwned destroyed with refcount ", toDestroy->refcount_, ", expected 1!"); - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - toDestroy->weakcount_ == 1 || (toDestroy->weakcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()), - "ExclusivelyOwned destroyed with weakcount ", toDestroy->weakcount_, ", expected 1!"); - if (toDestroy != UndefinedTensorImpl::singleton()) { -#ifndef NDEBUG - // Needed to pass the debug assertions in ~intrusive_ptr_target. - toDestroy->refcount_ = 0; - toDestroy->weakcount_ = 0; -#endif - toDestroy->release_resources(); - delete toDestroy; - } + return ExclusivelyOwnedTraits::destroyOwned(x); } static at::Tensor take(at::Tensor& x) { diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp index 84b48bd..29a43a6 100644 --- a/aten/src/ATen/templates/TensorMethods.cpp +++ b/aten/src/ATen/templates/TensorMethods.cpp @@ -3,9 +3,9 @@ namespace at { -#define DEFINE_CAST(T, name) \ +#define DEFINE_CAST(T, name) \ template <> \ - TORCH_API T* Tensor::data_ptr() const { \ + TORCH_API T* TensorBase::data_ptr() const { \ TORCH_CHECK( \ scalar_type() == ScalarType::name, \ "expected scalar type " \ diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index de829c4..2fd5c64 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -83,8 +83,9 @@ const at::Tensor& TensorImpl::grad() const { return autograd_meta_->grad(); } -const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self) - const { +const at::Tensor& TensorImpl::_fw_grad( + uint64_t level, + const at::TensorBase& self) const { // See TensorImpl::grad() above for explanation about the line below if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor(); @@ -92,8 +93,8 @@ const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self) } void TensorImpl::_set_fw_grad( - const at::Tensor& new_grad, - const at::Tensor& self, + const at::TensorBase& new_grad, + const at::TensorBase& self, uint64_t level, bool is_inplace_op) { if (!autograd_meta_) diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 7051e36..cee3600 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -38,6 +38,7 @@ C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory); namespace at { class Tensor; +class TensorBase; } namespace c10 { @@ -151,11 +152,11 @@ struct C10_API AutogradMetaInterface { virtual bool requires_grad() const = 0; virtual at::Tensor& mutable_grad() = 0; virtual const at::Tensor& grad() const = 0; - virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self) + virtual const at::Tensor& fw_grad(uint64_t level, const at::TensorBase& self) const = 0; virtual void set_fw_grad( - const at::Tensor& new_grad, - const at::Tensor& self, + const at::TensorBase& new_grad, + const at::TensorBase& self, uint64_t level, bool is_inplace_op) = 0; virtual ~AutogradMetaInterface(); @@ -1055,7 +1056,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * - "self" should represent the Tensor whose forward grad is accessed. It * is required when dealing with view. */ - const at::Tensor& _fw_grad(uint64_t level, const at::Tensor& self) const; + const at::Tensor& _fw_grad(uint64_t level, const at::TensorBase& self) const; /** * Sets the forward gradient for this Tensor. 
@@ -1078,8 +1079,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { * better error checking. */ void _set_fw_grad( - const at::Tensor& new_grad, - const at::Tensor& self, + const at::TensorBase& new_grad, + const at::TensorBase& self, uint64_t level, bool is_inplace_op); diff --git a/caffe2/core/tensor.cc b/caffe2/core/tensor.cc index 646c5cc..f5fb7d1 100644 --- a/caffe2/core/tensor.cc +++ b/caffe2/core/tensor.cc @@ -299,7 +299,7 @@ void Tensor::CopyFrom(const Tensor& src, bool async) { #if defined(EXPOSE_C2_OPS) || \ !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) -Tensor::Tensor(at::Tensor tensor) : impl_(std::move(tensor.impl_)) { +Tensor::Tensor(at::Tensor tensor) : impl_(tensor.unsafeReleaseIntrusivePtr()) { enforce_invariants(); } diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index 927d884..58206b7 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -3558,8 +3558,8 @@ namespace detail { bias_k = multihead_attn_module->bias_k.detach(); bias_v = multihead_attn_module->bias_v.detach(); } else { - bias_k = {}; - bias_v = {}; + bias_k.reset(); + bias_v.reset(); } torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1); diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 42ade6d..cddad59 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -645,8 +645,8 @@ inline std::tuple multi_head_attention_forward( if (!key.defined()) { TORCH_INTERNAL_ASSERT(!value.defined()); - k = {}; - v = {}; + k.reset(); + v.reset(); } else { // This is inline in_proj function with in_proj_weight and in_proj_bias _b = in_proj_bias; diff --git a/torch/csrc/api/src/nn/modules/activation.cpp b/torch/csrc/api/src/nn/modules/activation.cpp index e724a75..a770c15 100644 --- a/torch/csrc/api/src/nn/modules/activation.cpp +++ b/torch/csrc/api/src/nn/modules/activation.cpp @@ -511,8 +511,8 @@ void MultiheadAttentionImpl::reset() { bias_k = register_parameter("bias_k", torch::empty({1, 1, options.embed_dim()})); bias_v = register_parameter("bias_v", torch::empty({1, 1, options.embed_dim()})); } else { - bias_k = {}; - bias_v = {}; + bias_k.reset(); + bias_v.reset(); } _reset_parameters(); } diff --git a/torch/csrc/autograd/autograd_meta.cpp b/torch/csrc/autograd/autograd_meta.cpp index f35c122..2ad2778 100644 --- a/torch/csrc/autograd/autograd_meta.cpp +++ b/torch/csrc/autograd/autograd_meta.cpp @@ -101,7 +101,7 @@ namespace { // This function is will ensure that the fw_grad_ is properly a view of the base for inplace ops on // Tensors that do not have forward grad originally. -void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, uint64_t level, bool is_inplace_op) { +void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::TensorBase& self_base, uint64_t level, bool is_inplace_op) { // Lazy initialization { std::lock_guard lock(mutex_); @@ -112,21 +112,23 @@ void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, if (fw_grad_->contains(level)) { // Setting the forward grad again is only allowed if it is a no-op. // We do allow this case to simplify writing codegen for inplace ops. - TORCH_INTERNAL_ASSERT(new_grad_.defined(), "Cannot set a forward grad that is an undefined Tensor. Use " + TORCH_INTERNAL_ASSERT(new_grad_base.defined(), "Cannot set a forward grad that is an undefined Tensor. 
Use " "_fw_primal(level) to get a new Tensor with this forward grad unset."); TORCH_INTERNAL_ASSERT(is_inplace_op, "Only inplace operations can re-set the forward grad of a Tensor that " "already has one."); - TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_), "Cannot set a value of a forward grad if it " + TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_base), "Cannot set a value of a forward grad if it " "already exists. Inplace operations should modify it inplace."); } else { // TODO(alband) remove this spurious version counter bump - auto new_grad = new_grad_; + Tensor new_grad(new_grad_base); + at::OptionalTensorRef self_ref(self_base); + const Tensor &self = *self_ref; - TORCH_CHECK(self.is_same_size(new_grad_), "Trying to set a forward gradient that has a different size than that " + TORCH_CHECK(self.is_same_size(new_grad), "Trying to set a forward gradient that has a different size than that " "of the original Tensor, this is not supported. Tensor is of size ", self.sizes(), " while the given " - "forward gradient is of size ", new_grad_.sizes(), "."); + "forward gradient is of size ", new_grad.sizes(), "."); if (is_inplace_op && is_view_) { auto this_view_meta = static_cast(this); @@ -182,7 +184,7 @@ void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, } } -const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const { +const Variable& AutogradMeta::fw_grad(uint64_t level, const at::TensorBase& self) const { // TLS that disables forward AD // This is only used for custom Function implementation if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) { diff --git a/torch/csrc/autograd/cpp_hook.cpp b/torch/csrc/autograd/cpp_hook.cpp index d7bfd04..e593354 100644 --- a/torch/csrc/autograd/cpp_hook.cpp +++ b/torch/csrc/autograd/cpp_hook.cpp @@ -4,7 +4,7 @@ namespace { using torch::autograd::Variable; -void check_single_result (Variable value, Variable result, std::string hook_name) { +void check_single_result (const at::TensorBase &value, const at::TensorBase &result, std::string hook_name) { if (!value.defined()) { throw std::runtime_error("can't replace a empty gradient with a non-empty value"); } @@ -34,7 +34,7 @@ variable_list CppFunctionPreHook::operator()(const variable_list& values) { continue; } check_single_result(value, res, c10::to_string(i)); - value = res; + value = std::move(res); } variable_list results(values); results[value_idx_] = value; diff --git a/torch/csrc/autograd/cpp_hook.h b/torch/csrc/autograd/cpp_hook.h index a5f18291..c3d9135 100644 --- a/torch/csrc/autograd/cpp_hook.h +++ b/torch/csrc/autograd/cpp_hook.h @@ -5,7 +5,7 @@ namespace torch { namespace autograd { -using hooks_list = std::vector>; +using hooks_list = std::vector>; struct CppFunctionPreHook : public FunctionPreHook { CppFunctionPreHook(const std::shared_ptr &hooks, int value_idx); diff --git a/torch/csrc/autograd/custom_function.cpp b/torch/csrc/autograd/custom_function.cpp index 1bb4cb8..be02659 100644 --- a/torch/csrc/autograd/custom_function.cpp +++ b/torch/csrc/autograd/custom_function.cpp @@ -364,7 +364,7 @@ optional_variable_list _wrap_outputs(const variable_list &input_vars, return outputs; } -void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) { +void check_variable_result(const at::TensorBase& original, const at::TensorBase& result, std::string hook_name) { if (!original.options().type_equal(result.options())) { std::stringstream ss; ss << "hook '" << 
hook_name << "' has changed the type of value ("; diff --git a/torch/csrc/autograd/custom_function.h b/torch/csrc/autograd/custom_function.h index 94e62bf..7e67a10 100644 --- a/torch/csrc/autograd/custom_function.h +++ b/torch/csrc/autograd/custom_function.h @@ -20,8 +20,8 @@ TORCH_API std::vector> _wrap_outputs( const std::shared_ptr &cdata, _jvp_fn_t jvp_user_function); -TORCH_API void check_variable_result(const Variable& original, - const Variable& result, std::string hook_name); +TORCH_API void check_variable_result(const at::TensorBase& original, + const at::TensorBase& result, std::string hook_name); // Get the return type of the forward function of the custom Function class X template diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 50d6eb9..5f115ee 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -158,7 +158,7 @@ static PyObject* THPVariable_NewWithVar( } // TODO: Make this take Variable by const reference -PyObject * THPVariable_Wrap(Variable var) +PyObject * THPVariable_Wrap(at::TensorBase var) { if (!var.defined()) { Py_RETURN_NONE; diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 4078235..407d5dc 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -22,7 +22,7 @@ THP_API PyObject *THPVariableClass; THP_API PyObject *ParameterClass; bool THPVariable_initModule(PyObject *module); -THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var); +THP_API PyObject * THPVariable_Wrap(at::TensorBase var); static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Check that a python object is a `Tensor`, but not a `Tensor` subclass. diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 7ae1ac0..b123b7c 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -130,7 +130,7 @@ static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(&meta_fa namespace impl { - AutogradMeta* materialize_autograd_meta(const Variable& self) { + AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) { TORCH_CHECK(self.defined(), "cannot call materialize_autograd_meta() on undefined tensor"); auto p = self.unsafeGetTensorImpl(); if (!p->autograd_meta()) { @@ -165,7 +165,7 @@ namespace impl { set_gradient_edge(self, std::move(gradient_edge)); } - void create_cpp_hook(const Variable& self) { + void create_cpp_hook(const at::TensorBase& self) { auto &list = materialize_autograd_meta(self)->cpp_hooks_list_; // NOLINTNEXTLINE(modernize-make-shared) list.reset(new hooks_list()); @@ -277,7 +277,7 @@ namespace impl { // Hooks //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - void add_hook(const Variable& self, std::shared_ptr hook) { + void add_hook(const at::TensorBase& self, std::shared_ptr hook) { materialize_autograd_meta(self)->hooks_.push_back(std::move(hook)); } @@ -296,7 +296,7 @@ namespace impl { } } - void clear_hooks(const Variable& self) { + void clear_hooks(const at::TensorBase& self) { // This is a little goofy, but usually this should be a no oop materialize_autograd_meta(self)->hooks_.clear(); } @@ -308,13 +308,13 @@ namespace impl { // Miscellaneous //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - AutogradMeta* get_autograd_meta(const Variable& self) { + AutogradMeta* get_autograd_meta(const at::TensorBase& self) { // NB: could return 
nullptr TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor"); return static_cast(self.unsafeGetTensorImpl()->autograd_meta()); } - DifferentiableViewMeta* get_view_autograd_meta(const Variable& self) { + DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) { // NB: return nullptr if self is not a view AutogradMeta* meta = get_autograd_meta(self); if (meta && meta->is_view_) { @@ -329,31 +329,32 @@ namespace impl { using at::Tensor; struct VariableHooks final : at::impl::VariableHooksInterface { - Tensor tensor_data(const Tensor&) const override; - Tensor variable_data(const Tensor&) const override; - const std::shared_ptr& grad_fn(const Tensor&) const override; - unsigned _register_hook(const Tensor&, std::function hook) const override; - void remove_hook(const Tensor&, unsigned pos) const override; - bool is_view(const Tensor&) const override; - const Tensor& base(const Tensor&) const override; - const std::string& name(const Tensor&) const override; - bool is_leaf(const Tensor&) const override; - int64_t output_nr(const Tensor&) const override; - void set_data(const Tensor & self, const Tensor & new_data) const override; - Tensor data(const Tensor & self) const override; - int64_t _version(const Tensor & self) const override; - void retain_grad(const Tensor& self) const override; - bool retains_grad(const Tensor& self) const override; + at::TensorBase tensor_data(const at::TensorBase&) const override; + at::TensorBase variable_data(const at::TensorBase&) const override; + const std::shared_ptr& grad_fn(const at::TensorBase&) const override; + unsigned _register_hook( + const at::TensorBase&, std::function hook) const override; + void remove_hook(const at::TensorBase&, unsigned pos) const override; + bool is_view(const at::TensorBase&) const override; + const at::TensorBase& base(const at::TensorBase&) const override; + const std::string& name(const at::TensorBase&) const override; + bool is_leaf(const at::TensorBase&) const override; + int64_t output_nr(const at::TensorBase&) const override; + void set_data(const at::TensorBase & self, const at::TensorBase & new_data) const override; + at::TensorBase data(const at::TensorBase & self) const override; + int64_t _version(const at::TensorBase & self) const override; + void retain_grad(const at::TensorBase& self) const override; + bool retains_grad(const at::TensorBase& self) const override; void _backward(const Tensor& self, at::TensorList inputs, const c10::optional& gradient, c10::optional keep_graph, bool create_graph) const override; - void requires_grad_(const Tensor& self, bool _requires_grad) const override; + void requires_grad_(const at::TensorBase& self, bool _requires_grad) const override; }; VariableHooks variableHooks; at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks); -Tensor VariableHooks::variable_data(const Tensor& self) const { +at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const { TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor"); auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach( /*version_counter=*/0, @@ -362,7 +363,7 @@ Tensor VariableHooks::variable_data(const Tensor& self) const { return at::Tensor(self_impl_copy); } -Tensor VariableHooks::tensor_data(const Tensor& self) const { +at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const { TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor"); auto self_impl_copy = 
self.unsafeGetTensorImpl()->shallow_copy_and_detach( /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(), @@ -370,7 +371,7 @@ Tensor VariableHooks::tensor_data(const Tensor& self) const { return at::Tensor(self_impl_copy); } -bool VariableHooks::is_leaf(const Tensor & self) const { +bool VariableHooks::is_leaf(const at::TensorBase & self) const { if (impl::get_autograd_meta(self)) { return impl::get_autograd_meta(self)->grad_fn_ == nullptr; } else { @@ -378,7 +379,7 @@ bool VariableHooks::is_leaf(const Tensor & self) const { } } -int64_t VariableHooks::output_nr(const Tensor & self) const { +int64_t VariableHooks::output_nr(const at::TensorBase & self) const { if (impl::get_autograd_meta(self)) { return impl::get_autograd_meta(self)->output_nr_; } else { @@ -386,7 +387,12 @@ int64_t VariableHooks::output_nr(const Tensor & self) const { } } -void VariableHooks::set_data(const Tensor & self, const Tensor & new_data) const { +void VariableHooks::set_data(const at::TensorBase & self_base, const at::TensorBase & new_data_base) const { + at::OptionalTensorRef self_ref(self_base); + const Tensor &self = *self_ref; + at::OptionalTensorRef new_data_ref(new_data_base); + const Tensor &new_data = *new_data_ref; + // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields // from `new_data` to `var`. It requires that `new_data` and `var` have compatible // tensor type. @@ -420,15 +426,15 @@ void VariableHooks::set_data(const Tensor & self, const Tensor & new_data) const self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr()); } -Tensor VariableHooks::data(const Tensor & self) const { +at::TensorBase VariableHooks::data(const at::TensorBase & self) const { return self.variable_data(); } -int64_t VariableHooks::_version(const Tensor & self) const { +int64_t VariableHooks::_version(const at::TensorBase & self) const { return self.unsafeGetTensorImpl()->version_counter().current_version(); } -void VariableHooks::retain_grad(const Tensor& self) const { +void VariableHooks::retain_grad(const at::TensorBase& self) const { TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False"); if (self.is_leaf()) { // no-op for leaves return; @@ -438,7 +444,7 @@ void VariableHooks::retain_grad(const Tensor& self) const { } c10::weak_intrusive_ptr weak_self(self.getIntrusivePtr()); - std::function retain_grad_hook([weak_self](const Tensor& grad) { + auto retain_grad_hook = [weak_self](const at::Tensor& grad) { if (weak_self.expired()) { return; } else { @@ -453,13 +459,13 @@ void VariableHooks::retain_grad(const Tensor& self) const { var->mutable_grad() = var->grad() + grad; } } - }); + }; - self.register_hook(retain_grad_hook); + at::OptionalTensorRef(self)->register_hook(retain_grad_hook); impl::get_autograd_meta(self)->retains_grad_ = true; } -bool VariableHooks::retains_grad(const Tensor& self) const { +bool VariableHooks::retains_grad(const at::TensorBase& self) const { if (impl::get_autograd_meta(self)) { return impl::get_autograd_meta(self)->retains_grad_; } else { @@ -480,7 +486,7 @@ void VariableHooks::_backward( torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars); } -void VariableHooks::requires_grad_(const Tensor& self, bool _requires_grad) const { +void VariableHooks::requires_grad_(const at::TensorBase& self, bool _requires_grad) const { if (!self.is_leaf() && !_requires_grad) { throw std::runtime_error( autograd::utils::requires_grad_leaf_error(_requires_grad) @@ -492,7 +498,7 @@ void 
VariableHooks::requires_grad_(const Tensor& self, bool _requires_grad) cons // Backward View Variables //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -bool VariableHooks::is_view(const Tensor& self) const { +bool VariableHooks::is_view(const at::TensorBase& self) const { auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); if (diff_view_meta) { return diff_view_meta->has_bw_view(); @@ -501,7 +507,7 @@ bool VariableHooks::is_view(const Tensor& self) const { } } -const Tensor& VariableHooks::base(const Tensor& self) const { +const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const { auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); if (diff_view_meta) { TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor"); @@ -515,7 +521,7 @@ namespace { std::string singleton_string; } -const std::string& VariableHooks::name(const Tensor& self) const { +const std::string& VariableHooks::name(const at::TensorBase& self) const { TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor"); if (torch::autograd::impl::get_autograd_meta(self)) { return torch::autograd::impl::get_autograd_meta(self)->name_; @@ -528,7 +534,7 @@ namespace { std::shared_ptr singleton_shared_ptr; } -const std::shared_ptr& VariableHooks::grad_fn(const Tensor& self) const { +const std::shared_ptr& VariableHooks::grad_fn(const at::TensorBase& self) const { auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); if (diff_view_meta && diff_view_meta->has_bw_view()) { // See NOTE [ View + Inplace detection ] @@ -596,14 +602,15 @@ const std::shared_ptr& VariableHooks::grad_fn(const Tenso } } -void VariableHooks::remove_hook(const Tensor& self, unsigned pos) const { +void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos) const { auto &list = torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_; TORCH_CHECK(list && pos < list->size() , "Invalid index, no hook at position ", pos); // Hook will be ignored (*list)[pos] = nullptr; } -unsigned VariableHooks::_register_hook(const Tensor& self, std::function hook) const { +unsigned VariableHooks::_register_hook( + const at::TensorBase& self, std::function hook) const { TORCH_CHECK(self.requires_grad(), "cannot register a hook on a variable that " "doesn't require gradient"); // NB: materialize_autograd_meta unnecessary due to requires grad check diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 5f0846a..3a81cfb 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -107,15 +107,15 @@ namespace impl { // WARNING: This may return a nullptr. If you require AutogradMeta to return // a materialized structure, use materialize_autograd_meta instead. - TORCH_API AutogradMeta* get_autograd_meta(const Variable&); + TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&); // WARNING: This will return a nullptr if the Tensor is not a view. - TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const Variable&); + TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&); // Returns the current autograd meta, materializing it if it was previously // none. 
This counts as a *mutating* operation, so do not call it on // "read-only" operators; in particular, this is NOT thread safe - TORCH_API AutogradMeta* materialize_autograd_meta(const Variable&); + TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&); /// Set the gradient accumulator of the `Variable`. This is only applicable to /// leaf variables. Interior variables should call `set_gradient_edge()`. @@ -171,11 +171,11 @@ namespace impl { TORCH_API void set_name(const Variable&, const std::string& name); - TORCH_API void add_hook(const Variable&, std::shared_ptr hook); + TORCH_API void add_hook(const at::TensorBase&, std::shared_ptr hook); TORCH_API const std::vector>& hooks(const Variable&); - TORCH_API void clear_hooks(const Variable&); + TORCH_API void clear_hooks(const at::TensorBase&); - TORCH_API void create_cpp_hook(const Variable&); + TORCH_API void create_cpp_hook(const at::TensorBase&); } //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -252,9 +252,9 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { return grad_; } - const Variable& fw_grad(uint64_t level, const Variable& self) const override; + const Variable& fw_grad(uint64_t level, const at::TensorBase& self) const override; - void set_fw_grad(const Variable& new_grad, const Variable& self, uint64_t level, bool is_inplace_op) override; + void set_fw_grad(const at::TensorBase& new_grad, const at::TensorBase& self, uint64_t level, bool is_inplace_op) override; // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AutogradMeta(at::TensorImpl* self_impl = nullptr, bool requires_grad = false, Edge gradient_edge = Edge() ) { -- 2.7.4
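
Illustrative usage sketch (not part of the patch above). It shows the calling convention this change moves toward: helpers that only need tensor metadata can take `const at::TensorBase&` and still be called with an ordinary `at::Tensor`, and the new `Tensor::operator=(const TensorBase&)` / `operator=(TensorBase&&)` overloads let the two types be mixed freely. The function names `inspect` and `interop_example` are invented for illustration.

#include <utility>

#include <ATen/ATen.h>
#include <ATen/core/TensorBase.h>

// Only touches metadata, so TensorBase is enough; a translation unit holding
// this helper does not need the generated operator headers.
static int64_t inspect(const at::TensorBase& t) {
  return t.defined() ? t.dim() : -1;
}

static void interop_example() {
  at::Tensor t = at::ones({2, 3});
  (void)inspect(t);         // an ordinary Tensor binds directly to const TensorBase&

  at::TensorBase base = t;  // copies the base subobject; shares the same TensorImpl
  at::Tensor u;
  u = base;                 // Tensor& operator=(const TensorBase&) &
  TORCH_INTERNAL_ASSERT(u.unsafeGetTensorImpl() == t.unsafeGetTensorImpl());

  at::Tensor v;
  v = std::move(base);      // Tensor& operator=(TensorBase&&) &, transfers the intrusive_ptr
}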
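
Where code that only holds a `TensorBase` needs to call a regular operator, the diff borrows a full `at::Tensor` view of the same `TensorImpl` through `at::OptionalTensorRef`, without touching the reference count (see `AutogradMeta::set_fw_grad` and `VariableHooks::set_data` / `retain_grad` above). A minimal sketch of that pattern follows; the helper name `double_it` is invented.

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>

// Borrow a Tensor interface from a TensorBase to call an operator (add),
// mirroring the OptionalTensorRef pattern used in the patch.
static at::Tensor double_it(const at::TensorBase& base) {
  TORCH_CHECK(base.defined(), "double_it expects a defined tensor");
  at::OptionalTensorRef ref(base);  // materializes a const Tensor& view, no refcount bump
  const at::Tensor& t = *ref;
  return t.add(t);                  // the full operator API is available on Tensor
}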
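
The hook plumbing above now stores `std::function`s over `at::TensorBase` internally, but the user-facing `register_hook` / `remove_hook` API is unchanged. A short usage sketch (illustrative only):

#include <torch/torch.h>

static void hook_example() {
  torch::Tensor x = torch::ones({2, 2}, torch::requires_grad());

  // The hook receives the incoming gradient and may return a replacement.
  unsigned pos = x.register_hook([](const torch::Tensor& grad) {
    return grad * 2;
  });

  (x * x).sum().backward();
  // d(sum(x*x))/dx = 2*x, doubled by the hook, so x.grad() is 4*x here.
  x.remove_hook(pos);
}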