Factor out TensorBase that doesn't depend on native operators (#63612)
authorPeter Bell <peterbell10@live.co.uk>
Wed, 8 Sep 2021 20:25:42 +0000 (13:25 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 20:28:54 +0000 (13:28 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63612

This makes Tensor inherit from a new class TensorBase, that provides a subset of Tensor that doesn't
directly depend on native_functions.yaml. Code that only includes TensorBase.h with thus not need to
be rebuilt every time someone changes an operator signature.

Making `Tensor` inherit from this class means that `const TensorBase&` parameters will be callable
with an ordinary `Tensor`. I've also made `Tensor` constructible and assignable from `TensorBase` to
minimize friction in code mixing the two types.

To help enforce that `Tensor.h` and `Functions.h` aren't accidentally included, I've added an error
into `Operators.h` if `TORCH_ASSERT_NO_OPERATORS` is defined. We can either set this in the build
system for certain folders, or just define it at the top of any file.

I've also included an example of manually special-casing the commonly used `contiguous` operator.
The inline function's slow path defers to `TensorBase::__dispatch_contiguous` which is defined in
`Tensor.cpp`. I've made it so `OptionalTensorRef` is constructible from `TensorBase`, so I can
materialize a `Tensor` for use in dispatch without actually increasing its refcount.

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30728580

Pulled By: ezyang

fbshipit-source-id: 2cbc8eee08043382ee6904ea8e743b1286921c03

27 files changed:
aten/src/ATen/core/NamedTensor.cpp
aten/src/ATen/core/NamedTensor.h
aten/src/ATen/core/QuantizerBase.h
aten/src/ATen/core/Tensor.cpp
aten/src/ATen/core/Tensor.h
aten/src/ATen/core/TensorBase.h [new file with mode: 0644]
aten/src/ATen/core/VariableHooksInterface.h
aten/src/ATen/quantized/Quantizer.cpp
aten/src/ATen/quantized/Quantizer.h
aten/src/ATen/templates/Operators.h
aten/src/ATen/templates/TensorBody.h
aten/src/ATen/templates/TensorMethods.cpp
c10/core/TensorImpl.cpp
c10/core/TensorImpl.h
caffe2/core/tensor.cc
test/cpp/api/modules.cpp
torch/csrc/api/include/torch/nn/functional/activation.h
torch/csrc/api/src/nn/modules/activation.cpp
torch/csrc/autograd/autograd_meta.cpp
torch/csrc/autograd/cpp_hook.cpp
torch/csrc/autograd/cpp_hook.h
torch/csrc/autograd/custom_function.cpp
torch/csrc/autograd/custom_function.h
torch/csrc/autograd/python_variable.cpp
torch/csrc/autograd/python_variable.h
torch/csrc/autograd/variable.cpp
torch/csrc/autograd/variable.h

index b394a8d..40fd58a 100644 (file)
@@ -1,6 +1,7 @@
+#define TORCH_ASSERT_NO_OPERATORS
 #include <ATen/core/NamedTensor.h>
 
-#include <ATen/core/Tensor.h>
+#include <ATen/core/TensorBase.h>
 #include <c10/util/C++17.h>
 
 namespace at {
@@ -16,12 +17,12 @@ void NamesMode::set_enabled(bool enabled) {
   c10::impl::tls_set_dispatch_key_excluded(DispatchKey::Named, !enabled);
 }
 
-const Tensor& internal_set_names_inplace(const Tensor& tensor, optional<DimnameList> names) {
+const TensorBase& internal_set_names_inplace(const TensorBase& tensor, optional<DimnameList> names) {
   impl::internal_set_names_inplace(tensor.unsafeGetTensorImpl(), names, /*validate_names=*/true);
   return tensor;
 }
 
-const Tensor& internal_set_names_inplace(const Tensor& tensor, std::vector<Dimname>&& names, bool validate_names) {
+const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names) {
   impl::internal_set_names_inplace(tensor.unsafeGetTensorImpl(), std::move(names), validate_names);
   return tensor;
 }
@@ -48,7 +49,7 @@ static void check_unique_names(DimnameList names) {
   }
 }
 
-void check_names_valid_for(const Tensor& tensor, DimnameList names) {
+void check_names_valid_for(const TensorBase& tensor, DimnameList names) {
   return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
 }
 
index e343149..73a0d7d 100644 (file)
@@ -6,6 +6,8 @@
 
 namespace at {
 
+class TensorBase;
+
 // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
 // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
 // so we have a couple of workarounds.
@@ -95,12 +97,12 @@ struct TORCH_API NoNamesGuard {
   bool initialized;
 };
 
-void check_names_valid_for(const Tensor& tensor, DimnameList names);
+void check_names_valid_for(const TensorBase& tensor, DimnameList names);
 void check_names_valid_for(size_t tensor_dim, DimnameList names);
 
 // Sets the names of `tensor` to be `names`.
-TORCH_API const Tensor& internal_set_names_inplace(const Tensor& tensor, c10::optional<DimnameList> names);
-TORCH_API const Tensor& internal_set_names_inplace(const Tensor& tensor, std::vector<Dimname>&& names, bool validate_names);
+TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, c10::optional<DimnameList> names);
+TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names);
 
 constexpr size_t kMaxNamedTensorDim = 64;
 
index cdc5b6f..2648223 100644 (file)
@@ -1,5 +1,9 @@
 #pragma once
 
+#include <c10/core/ScalarType.h>
+#include <c10/core/QScheme.h>
+#include <c10/util/intrusive_ptr.h>
+
 namespace at {
 
 class Tensor;
index 4c7643c..98ab4af 100644 (file)
@@ -6,7 +6,16 @@
 
 namespace at {
 
-void Tensor::enforce_invariants() {
+const TensorBase& get_tensor_base(const Tensor &t) {
+  return t;
+}
+
+TensorBase TensorBase::__dispatch_contiguous(c10::MemoryFormat memory_format) const {
+  OptionalTensorRef self(*this);
+  return at::_ops::contiguous::call(*self, memory_format);
+}
+
+void TensorBase::enforce_invariants() {
   if (impl_.get() == nullptr) {
     throw std::runtime_error("TensorImpl with nullptr is not supported");
   }
@@ -26,7 +35,7 @@ void Tensor::enforce_invariants() {
   }
 }
 
-void Tensor::print() const {
+void TensorBase::print() const {
   if (defined()) {
     std::cerr << "[" << toString() << " " << sizes() << "]" << std::endl;
   } else {
@@ -34,7 +43,7 @@ void Tensor::print() const {
   }
 }
 
-std::string Tensor::toString() const {
+std::string TensorBase::toString() const {
   std::string base_str;
   if (scalar_type() == ScalarType::Undefined) {
     base_str = "UndefinedType";
@@ -44,39 +53,39 @@ std::string Tensor::toString() const {
   return base_str;
 }
 
-Tensor Tensor::variable_data() const {
+TensorBase TensorBase::variable_data() const {
   return impl::GetVariableHooks()->variable_data(*this);
 }
 
-Tensor Tensor::tensor_data() const {
+TensorBase TensorBase::tensor_data() const {
   return impl::GetVariableHooks()->tensor_data(*this);
 }
 
-bool Tensor::is_leaf() const {
+bool TensorBase::is_leaf() const {
   return impl::GetVariableHooks()->is_leaf(*this);
 }
 
-int64_t Tensor::output_nr() const {
+int64_t TensorBase::output_nr() const {
   return impl::GetVariableHooks()->output_nr(*this);
 }
 
-void Tensor::set_data(const Tensor & new_data) const {
+void TensorBase::set_data(const TensorBase & new_data) const {
   impl::GetVariableHooks()->set_data(*this, new_data);
 }
 
-Tensor Tensor::data() const {
+TensorBase TensorBase::data() const {
   return impl::GetVariableHooks()->data(*this);
 }
 
-int64_t Tensor::_version() const {
+int64_t TensorBase::_version() const {
   return impl::GetVariableHooks()->_version(*this);
 }
 
-void Tensor::retain_grad() const {
+void TensorBase::retain_grad() const {
   impl::GetVariableHooks()->retain_grad(*this);
 }
 
-bool Tensor::retains_grad() const {
+bool TensorBase::retains_grad() const {
   return impl::GetVariableHooks()->retains_grad(*this);
 }
 
@@ -87,7 +96,7 @@ void Tensor::_backward(TensorList inputs,
   return impl::GetVariableHooks()->_backward(*this, inputs, gradient, keep_graph, create_graph);
 }
 
-const Tensor& Tensor::requires_grad_(bool _requires_grad) const {
+const TensorBase& TensorBase::requires_grad_(bool _requires_grad) const {
   impl::GetVariableHooks()->requires_grad_(*this, _requires_grad);
   return *this;
 }
@@ -95,27 +104,27 @@ const Tensor& Tensor::requires_grad_(bool _requires_grad) const {
 // View Variables
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-bool Tensor::is_view() const {
+bool TensorBase::is_view() const {
   return impl::GetVariableHooks()->is_view(*this);
 }
 
-const Tensor& Tensor::_base() const {
+const TensorBase& TensorBase::_base() const {
   return impl::GetVariableHooks()->base(*this);
 }
 
-const std::string& Tensor::name() const {
+const std::string& TensorBase::name() const {
   return impl::GetVariableHooks()->name(*this);
 }
 
-const std::shared_ptr<torch::autograd::Node>& Tensor::grad_fn() const {
+const std::shared_ptr<torch::autograd::Node>& TensorBase::grad_fn() const {
   return impl::GetVariableHooks()->grad_fn(*this);
 }
 
-void Tensor::remove_hook(unsigned pos) const {
+void TensorBase::remove_hook(unsigned pos) const {
   impl::GetVariableHooks()->remove_hook(*this, pos);
 }
 
-unsigned Tensor::_register_hook(std::function<Tensor(const Tensor&)> hook) const {
+unsigned TensorBase::_register_hook(std::function<TensorBase(const TensorBase&)> hook) const {
   return impl::GetVariableHooks()->_register_hook(*this, std::move(hook));
 }
 
index fa2479c..5293928 100644 (file)
@@ -12,7 +12,7 @@ class TORCH_API OptionalTensorRef {
     ref_.unsafeReleaseTensorImpl();
   }
 
-  OptionalTensorRef(const Tensor& src)
+  OptionalTensorRef(const TensorBase& src)
       : ref_(Tensor::unsafe_borrow_t{}, src) {
     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(src.defined());
   }
@@ -48,4 +48,27 @@ class TORCH_API OptionalTensorRef {
  private:
   Tensor ref_;
 };
+
+template <typename T>
+auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
+  // Return the grad argument in case of a hook with void return type to have an
+  // std::function with Tensor return type
+  static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
+                "Expected hook to return void");
+  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
+    OptionalTensorRef grad(grad_base);
+    fn(*grad);
+    return Tensor();
+  });
+}
+
+template <typename T>
+auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
+  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad_base) {
+    OptionalTensorRef grad(grad_base);
+    Tensor ret = fn(*grad);
+    return TensorBase(std::move(ret));
+  });
+}
+
 } // namespace at
diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h
new file mode 100644 (file)
index 0000000..e91e9e1
--- /dev/null
@@ -0,0 +1,917 @@
+#pragma once
+
+#include <c10/core/Device.h>
+#include <c10/core/Layout.h>
+#include <c10/core/MemoryFormat.h>
+#include <c10/core/ScalarType.h>
+#include <c10/core/ScalarTypeToTypeMeta.h>
+#include <c10/core/Storage.h>
+#include <c10/core/TensorImpl.h>
+#include <c10/core/UndefinedTensorImpl.h>
+#include <c10/core/WrapDimMinimal.h>
+#include <c10/util/Exception.h>
+#include <c10/util/MaybeOwned.h>
+#include <c10/util/Optional.h>
+#include <c10/util/intrusive_ptr.h>
+
+#include <ATen/core/NamedTensor.h>
+#include <ATen/core/QuantizerBase.h>
+#include <ATen/core/TensorAccessor.h>
+
+namespace c10 {
+struct TensorOptions;
+}
+
+namespace torch { namespace autograd {
+
+struct Node;
+
+}} // namespace torch::autograd
+
+namespace at {
+
+class Tensor;
+class TensorBase;
+
+// Convert Tensor to TensorBase without any need to include Tensor.h
+TORCH_API const TensorBase& get_tensor_base(const Tensor& t);
+
+namespace impl {
+inline bool variable_excluded_from_dispatch() {
+#ifdef C10_MOBILE
+  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
+  return true;
+#else
+  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::Autograd));
+  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
+#endif
+}
+}
+
+// NOTE: [Tensor vs. TensorBase]
+//
+// Tensor, being the central data structure in PyTorch, gets used and
+// it's header included almost everywhere. Unfortunately this means
+// every time an operator signature is updated or changed in
+// native_functions.yaml, you (and every other PyTorch developer) need
+// to recompile all of ATen and it's dependencies.
+//
+// TensorBase aims to break up these header dependencies, and improve
+// incremental build times for all PyTorch developers. TensorBase
+// represents a reference counted handle to TensorImpl, exactly the
+// same as Tensor. However, TensorBase doesn't have code generated
+// methods in it's API and thus no dependence on native_functions.yaml.
+//
+// Usage tips
+// ----------
+// - You can `#define TORCH_ASSERT_NO_OPERATORS` at the top of a .cpp
+//   or .cu file to ensure it has no header dependencies on
+//   native_functions.yaml (direct or indirect).
+// - Tensor inherits from TensorBase, so functions taking
+//   `const TensorBase &` are callable with Tensor as well.
+// - TensorBase can be converted to tensor with `Tensor(tensor_base)`,
+//   but this requires a reference-count bump. OptionalTensorRef on
+//   the other hand can materialize a `const Tensor &` without
+//   touching the reference-count.
+class TORCH_API TensorBase {
+ public:
+  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };
+
+ protected:
+  // Create a Tensor with a +0 reference count. Special care must be
+  // taken to avoid decrementing this reference count at destruction
+  // time. Intended to support MaybeOwnedTraits<Tensor>.
+  explicit TensorBase(unsafe_borrow_t, const TensorBase& rhs)
+      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
+  friend MaybeOwnedTraits<TensorBase>;
+
+ public:
+  TensorBase() = default;
+  // This constructor should not be used by end users and is an implementation
+  // detail invoked by autogenerated code.
+  explicit TensorBase(
+      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
+      : impl_(std::move(tensor_impl)) {
+    if (impl_.get() == nullptr) {
+      throw std::runtime_error("TensorImpl with nullptr is not supported");
+    }
+  }
+  TensorBase(const TensorBase&) = default;
+  TensorBase(TensorBase&&) = default;
+
+ public:
+  // Creates a new wrapper from TensorImpl. Intentionally a free method because
+  // it should be used with care. Checks necessary invariants
+  static TensorBase wrap_tensor_impl(
+      c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
+    TensorBase r(std::move(tensor_impl));
+    r.enforce_invariants();
+    return r;
+  }
+
+  int64_t dim() const {
+    return impl_->dim();
+  }
+  int64_t storage_offset() const {
+    return impl_->storage_offset();
+  }
+
+  TensorBase contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
+    if (is_contiguous(memory_format)) {
+      return *this;
+    } else {
+      return __dispatch_contiguous(memory_format);
+    }
+  }
+
+  /// Should be used if *this can reasonably be expected to be contiguous and
+  /// performance is important.
+  /// Compared to contiguous, it saves a reference count
+  /// increment/decrement if *this is already contiguous, at the cost
+  /// in all cases of an extra pointer of stack usage, an extra branch
+  /// to access, and an extra branch at destruction time.
+  c10::MaybeOwned<TensorBase> expect_contiguous(
+      MemoryFormat memory_format=MemoryFormat::Contiguous) const &;
+
+  // Use .contiguous() instead. Trying to borrow from a prvalue
+  // will only lead to trouble and dangling references.
+  c10::MaybeOwned<TensorBase> expect_contiguous(
+      MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
+
+  bool is_complex() const {
+    return at::isComplexType(this->scalar_type());
+  }
+
+  bool is_floating_point() const {
+    return at::isFloatingType(this->scalar_type());
+  }
+
+  bool is_signed() const {
+    return at::isSignedType(this->scalar_type());
+  }
+
+  int64_t size(int64_t dim) const {
+    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
+    dim = c10::maybe_wrap_dim(dim, this->dim(), false);
+    return sizes()[dim];
+  }
+
+  int64_t stride(int64_t dim) const {
+    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
+    dim = c10::maybe_wrap_dim(dim, this->dim(), false);
+    return strides()[dim];
+  }
+
+  TensorImpl * unsafeGetTensorImpl() const {
+    return impl_.get();
+  }
+  TensorImpl * unsafeReleaseTensorImpl() {
+    return impl_.release();
+  }
+  const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
+    return impl_;
+  }
+
+  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr() {
+    return std::move(impl_);
+  }
+
+  bool defined() const {
+    return impl_;
+  }
+
+  void reset() {
+    impl_.reset();
+  }
+
+  TensorBase& operator=(const TensorBase& x) & {
+    impl_ = x.impl_;
+    return *this;
+  };
+  TensorBase& operator=(TensorBase&& x) & {
+    impl_ = std::move(x.impl_);
+    return *this;
+  }
+
+  // Ban assignment to rvalues, since at::Tensor (weirdly) performs a deep copy here
+  TensorBase& operator=(const TensorBase&) && = delete;
+  TensorBase& operator=(TensorBase&&) && = delete;
+
+  bool is_same(const TensorBase& other) const noexcept {
+    return impl_ == other.impl_;
+  }
+  size_t use_count() const noexcept {
+    return impl_.use_count();
+  }
+  size_t weak_use_count() const noexcept {
+    return impl_.weak_use_count();
+  }
+
+  std::string toString() const;
+
+  IntArrayRef sizes() const {
+    return impl_->sizes();
+  }
+  IntArrayRef strides() const {
+    return impl_->strides();
+  }
+  // See impl::get_opt_names in ATen/NamedTensor.h for docs.
+  c10::optional<DimnameList> opt_names() const {
+    return impl::get_opt_names(unsafeGetTensorImpl());
+  }
+  // See impl::get_names in ATen/NamedTensor.h for docs.
+  DimnameList names() const {
+    return impl::get_names(unsafeGetTensorImpl());
+  }
+  int64_t ndimension() const {
+    return dim();
+  }
+
+  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
+    return impl_->is_contiguous(memory_format);
+  }
+
+  bool is_non_overlapping_and_dense() const {
+    return impl_->is_non_overlapping_and_dense();
+  }
+
+  at::MemoryFormat suggest_memory_format(
+      bool channels_last_strides_exact_match = false) const {
+    // Setting channels_last_strides_exact_match to true forces function to
+    // check 0,1 - sized dimension strides.
+    if (!is_mkldnn() && !is_sparse()) {
+      if (impl_->is_strides_like_channels_last()) {
+        if (!channels_last_strides_exact_match ||
+            get_channels_last_strides_2d(sizes()) == strides()) {
+          return at::MemoryFormat::ChannelsLast;
+        }
+      }
+      else if (impl_->is_strides_like_channels_last_3d()) {
+        if (!channels_last_strides_exact_match ||
+            get_channels_last_strides_3d(sizes()) == strides()) {
+          return at::MemoryFormat::ChannelsLast3d;
+        }
+      }
+    }
+    return at::MemoryFormat::Contiguous;
+  }
+
+  // Total bytes consumed by the "view" of elements of the array.  Does not
+  // include size of metadata.  The number reported here does not necessarily
+  // correspond to the true physical memory consumed by a tensor; instead,
+  // it reports the memory the tensor would take *if* it were contiguous.
+  // Defined to be numel() * itemsize()
+  size_t nbytes() const {
+    TORCH_CHECK(layout () != at::kSparse,
+                "nbytes is not defined for sparse tensors.  If you want the size of the constituent " \
+                "tensors, add the nbytes of the indices and values.  If you want the size of the  " \
+                "equivalent dense tensor, multiply numel() by element_size()");
+    return impl_->numel() * impl_->itemsize();
+  }
+
+  int64_t numel() const {
+    return impl_->numel();
+  }
+
+  // Length of one array element in bytes.  This is the traditional
+  // Numpy naming.
+  size_t itemsize() const {
+    return impl_->itemsize();
+  }
+
+  // Same as itemsize().  This is the PyTorch naming.
+  int64_t element_size() const {
+    return static_cast<int64_t>(impl_->itemsize());
+  }
+
+  DispatchKeySet key_set() const {
+    return impl_->key_set();
+  }
+  ScalarType scalar_type() const {
+    return typeMetaToScalarType(impl_->dtype());
+  }
+  bool has_storage() const {
+    return defined() && impl_->has_storage();
+  }
+  const Storage& storage() const {
+    return impl_->storage();
+  }
+  bool is_alias_of(const at::TensorBase& other) const{
+    return impl_->storage().is_alias_of(other.storage());
+  }
+
+  inline bool is_conj() const {
+    return impl_->is_conj();
+  }
+
+  // sets the conjugate bit of a tensor.
+  // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure
+  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
+  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
+  inline void _set_conj(bool conjugate) const {
+    impl_->_set_conj(conjugate);
+  }
+
+  inline bool is_neg() const {
+    return impl_->is_neg();
+  }
+
+  // sets the negative bit of a tensor.
+  // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure
+  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
+  // bit to determine if a negation needs to be materialized.
+  inline void _set_neg(bool negative) const {
+    impl_->_set_neg(negative);
+  }
+
+  /// Returns a `Tensor`'s layout.
+  Layout layout() const noexcept {
+    return impl_->layout();
+  }
+
+  /// Returns a `Tensor`'s dtype (`TypeMeta`).
+  caffe2::TypeMeta dtype() const noexcept {
+    return impl_->dtype();
+  }
+
+  /// Returns a `Tensor`'s device.
+  inline Device device() const {
+    return impl_->device();
+  }
+
+  /// Returns a `Tensor`'s device index.
+  int64_t get_device() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->get_device();
+  }
+
+  /// Returns if a `Tensor` has CPU backend.
+  bool is_cpu() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_cpu();
+  }
+
+  /// Returns if a `Tensor` has CUDA backend.
+  bool is_cuda() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_cuda();
+  }
+
+  /// Returns if a `Tensor` has XPU backend.
+  bool is_xpu() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_xpu();
+  }
+
+  /// Returns if a `Tensor` has XLA backend.
+  bool is_xla() const {
+    return impl_->is_xla();
+  }
+
+  /// Returns if a `Tensor` has Lazy backend.
+  bool is_lazy() const {
+    return impl_->is_lazy();
+  }
+
+  /// Returns if a `Tensor` has HIP backend.
+  bool is_hip() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_hip();
+  }
+
+  /// Returns if a `Tensor` has VE backend.
+  bool is_ve() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ve();
+  }
+
+  /// Returns if a `Tensor` has sparse backend.
+  bool is_sparse() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_sparse();
+  }
+
+  /// Returns is a `Tensor` has a sparse CSR backend.
+  bool is_sparse_csr() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_sparse_csr();
+  }
+
+  /// Returns if a `Tensor` is mkldnn tensor.
+  bool is_mkldnn() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_mkldnn();
+  }
+
+  /// Returns if a `Tensor` is mlc tensor.
+  bool is_mlc() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_mlc();
+  }
+
+  /// Returns if a `Tensor` is ort tensor.
+  bool is_ort() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ort();
+  }
+
+  /// Returns if a `Tensor` is vulkan tensor.
+  bool is_vulkan() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_vulkan();
+  }
+
+  /// Returns if a `Tensor` is metal tensor.
+  bool is_metal() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_metal();
+  }
+
+  /// Returns if a `Tensor` has quantized backend.
+  bool is_quantized() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_quantized();
+  }
+
+  /// Returns if a `Tensor` is a meta tensor.  Meta tensors can
+  /// also have other designations.
+  bool is_meta() const {
+    return impl_->is_meta();
+  }
+
+  /// Returns if a `Tensor` is an inference tensor.
+  bool is_inference() const {
+    return impl_->is_inference();
+  }
+
+  /// If a tensor is a quantized tensor, returns its quantizer
+  /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
+  QuantizerPtr quantizer() const;
+
+  /// Returns if a `Tensor` has any dimension names
+  bool has_names() const {
+    // If a user is using unnamed tensors, then we can short-circuit right here.
+    // Otherwise, impl::has_names attempts to retrieve names.
+    if (!impl_->has_named_tensor_meta()) {
+      return false;
+    }
+    return impl::has_names(unsafeGetTensorImpl());
+  }
+
+  /// Returns a `Tensor`'s dimension names data structure
+  const NamedTensorMeta* get_named_tensor_meta() const {
+    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
+  }
+
+  NamedTensorMeta* get_named_tensor_meta() {
+    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
+  }
+
+  /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
+  /// TensorOptions.h.
+  TensorOptions options() const {
+    return TensorOptions().dtype(dtype())
+                          .device(device())
+                          .layout(layout());
+  }
+
+  void* data_ptr() const {
+    return this->unsafeGetTensorImpl()->data();
+  }
+
+  template <typename T>
+  T * data_ptr() const;
+
+  // Purposely not defined here to avoid inlining
+  void print() const;
+
+  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
+  // dimension.
+  template<typename T, size_t N>
+  TensorAccessor<T,N> accessor() const& {
+    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
+    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
+    return TensorAccessor<T,N>(data_ptr<T>(),sizes().data(),strides().data());
+  }
+  template<typename T, size_t N>
+  TensorAccessor<T,N> accessor() && = delete;
+
+  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
+  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
+  // cast the data pointer to a __restrict__ pointer.
+  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
+  // as an argument.
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
+    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
+    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
+    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data_ptr<T>()),sizes().data(),strides().data());
+  }
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
+  GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;
+
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
+    return generic_packed_accessor<T,N,PtrTraits,int32_t>();
+  }
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;
+
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
+    return generic_packed_accessor<T,N,PtrTraits,int64_t>();
+  }
+  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
+  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
+
+  // ~~~~~ Autograd API ~~~~~
+
+  /// \fn bool is_leaf() const;
+  ///
+  /// All Tensors that have `requires_grad()` which is ``false`` will be leaf Tensors by convention.
+  ///
+  /// For Tensors that have `requires_grad()` which is ``true``, they will be leaf Tensors if they were
+  /// created by the user. This means that they are not the result of an operation and so
+  /// `grad_fn()` is `nullptr`.
+  ///
+  /// Only leaf Tensors will have their `grad()` populated during a call to `backward()`.
+  /// To get `grad()` populated for non-leaf Tensors, you can use `retain_grad()`.
+  ///
+  /// Example:
+  /// @code
+  /// auto a = torch::rand(10, torch::requires_grad());
+  /// std::cout << a.is_leaf() << std::endl; // prints `true`
+  ///
+  /// auto b = torch::rand(10, torch::requires_grad()).to(torch::kCUDA);
+  /// std::cout << b.is_leaf() << std::endl; // prints `false`
+  /// // b was created by the operation that cast a cpu Tensor into a cuda Tensor
+  ///
+  /// auto c = torch::rand(10, torch::requires_grad()) + 2;
+  /// std::cout << c.is_leaf() << std::endl; // prints `false`
+  /// // c was created by the addition operation
+  ///
+  /// auto d = torch::rand(10).cuda();
+  /// std::cout << d.is_leaf() << std::endl; // prints `true`
+  /// // d does not require gradients and so has no operation creating it (that is tracked by the autograd engine)
+  ///
+  /// auto e = torch::rand(10).cuda().requires_grad_();
+  /// std::cout << e.is_leaf() << std::endl; // prints `true`
+  /// // e requires gradients and has no operations creating it
+  ///
+  /// auto f = torch::rand(10, torch::device(torch::kCUDA).requires_grad(true));
+  /// std::cout << f.is_leaf() << std::endl; // prints `true`
+  /// // f requires grad, has no operation creating it
+  /// @endcode
+
+  /// \fn void backward(const Tensor & gradient={}, c10::optional<bool> retain_graph=c10::nullopt, bool create_graph=false, c10::optional<TensorList> inputs=c10::nullopt) const;
+  ///
+  /// Computes the gradient of current tensor with respect to graph leaves.
+  ///
+  /// The graph is differentiated using the chain rule. If the tensor is
+  /// non-scalar (i.e. its data has more than one element) and requires
+  /// gradient, the function additionally requires specifying ``gradient``.
+  /// It should be a tensor of matching type and location, that contains
+  /// the gradient of the differentiated function w.r.t. this Tensor.
+  ///
+  /// This function accumulates gradients in the leaves - you might need to
+  /// zero them before calling it.
+  ///
+  /// \param gradient Gradient w.r.t. the
+  ///     tensor. If it is a tensor, it will be automatically converted
+  ///     to a Tensor that does not require grad unless ``create_graph`` is True.
+  ///     None values can be specified for scalar Tensors or ones that
+  ///     don't require grad. If a None value would be acceptable then
+  ///     this argument is optional.
+  /// \param retain_graph If ``false``, the graph used to compute
+  ///     the grads will be freed. Note that in nearly all cases setting
+  ///     this option to True is not needed and often can be worked around
+  ///     in a much more efficient way. Defaults to the value of
+  ///     ``create_graph``.
+  /// \param create_graph If ``true``, graph of the derivative will
+  ///     be constructed, allowing to compute higher order derivative
+  ///     products. Defaults to ``false``.
+  /// \param inputs Inputs w.r.t. which the gradient will be accumulated into
+  ///     ``at::Tensor::grad``. All other Tensors will be ignored. If not
+  ///     provided, the gradient is accumulated into all the leaf Tensors
+  ///     that were used to compute the current tensor.
+  ///     When inputs are provided and a given input is not a leaf,
+  ///     the current implementation will call its grad_fn (even though it is not strictly needed to get this gradients).
+  ///     It is an implementation detail on which the user should not rely.
+  ///     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
+
+  /// \fn Tensor detach() const;
+  ///
+  /// Returns a new Tensor, detached from the current graph.
+  /// The result will never require gradient.
+
+  /// \fn Tensor & detach_() const;
+  ///
+  /// Detaches the Tensor from the graph that created it, making it a leaf.
+  /// Views cannot be detached in-place.
+
+  /// \fn void retain_grad() const;
+  ///
+  /// Enables this Tensor to have their :attr:`grad` populated during
+  /// :func:`backward`. This is a no-op for leaf tensors.
+
+  /// \fn bool retains_grad() const;
+  ///
+  /// Is ``true`` if this Tensor is non-leaf and its :attr:`grad` is enabled to be
+  /// populated during :func:`backward`, ``false`` otherwise.
+
+  const TensorBase& set_requires_grad(bool requires_grad) const {
+    impl_->set_requires_grad(requires_grad);
+    return *this;
+  }
+  bool requires_grad() const {
+    return impl_->requires_grad();
+  }
+
+  /// NOTE: This is similar to the legacy `.data()` function on `Variable`, and is intended
+  /// to be used from functions that need to access the `Variable`'s equivalent `Tensor`
+  /// (i.e. `Tensor` that shares the same storage and tensor metadata with the `Variable`).
+  ///
+  /// One notable difference with the legacy `.data()` function is that changes to the
+  /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
+  /// will not update the original `Variable`, due to the fact that this function
+  /// shallow-copies the `Variable`'s underlying TensorImpl.
+  at::TensorBase tensor_data() const;
+
+  /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
+  /// in Python, which create a new `Variable` that shares the same storage and
+  /// tensor metadata with the original `Variable`, but with a completely new
+  /// autograd history.
+  ///
+  /// NOTE: If we change the tensor metadata (e.g. sizes / strides /
+  /// storage / storage_offset) of a variable created from `var.variable_data()`, those
+  /// changes will not update the original variable `var`. In `.variable_data()`, we set
+  /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
+  /// in order to prevent users from changing metadata of `var.variable_data()`
+  /// and expecting the original variable `var` to also be updated.
+  at::TensorBase variable_data() const;
+
+  // Gradient Node and Edges
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
+  /// the pointer returned will be null.
+  ///
+  /// For View Variables:
+  /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
+  /// re-create the grad_fn to express the up-to-date view relationship between
+  /// this and the base Variable.
+  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;
+
+  // Hooks
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  template <typename T>
+  using hook_return_void_t = std::enable_if_t<std::is_void<typename std::result_of<T&(TensorBase)>::type>::value, unsigned>;
+  template <typename T>
+  using hook_return_var_t = std::enable_if_t<std::is_same<typename std::result_of<T&(TensorBase)>::type, TensorBase>::value, unsigned>;
+
+  /// Registers a backward hook.
+  ///
+  /// The hook will be called every time a gradient with respect to the Tensor is computed.
+  /// The hook should have one of the following signature:
+  /// ```
+  /// hook(TensorBase grad) -> TensorBase
+  /// ```
+  /// ```
+  /// hook(TensorBase grad) -> void
+  /// ```
+  /// The hook should not modify its argument, but it can optionally return a new gradient
+  /// which will be used in place of `grad`.
+  ///
+  /// This function returns the index of the hook in the list which can be used to remove hook.
+  ///
+  /// Example:
+  /// @code
+  /// auto v = torch::tensor({0., 0., 0.}, torch::requires_grad());
+  /// auto h = v.register_hook([](torch::Tensor grad){ return grad * 2; }); // double the gradient
+  /// v.backward(torch::tensor({1., 2., 3.}));
+  /// // This prints:
+  /// // ```
+  /// //  2
+  /// //  4
+  /// //  6
+  /// // [ CPUFloatType{3} ]
+  /// // ```
+  /// std::cout << v.grad() << std::endl;
+  /// v.remove_hook(h);  // removes the hook
+  /// @endcode
+  template <typename T>
+  hook_return_void_t<T> register_hook(T&& hook) const;
+  template <typename T>
+  hook_return_var_t<T> register_hook(T&& hook) const;
+
+protected:
+  unsigned _register_hook(std::function<TensorBase(const TensorBase&)> hook) const;
+
+public:
+
+  /// Remove hook at given position
+  void remove_hook(unsigned pos) const;
+
+  // Variable methods
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  bool is_leaf() const;
+
+  int64_t output_nr() const;
+
+  void set_data(const TensorBase & new_data) const;
+
+  TensorBase data() const;
+
+  int64_t _version() const;
+
+  void retain_grad() const;
+
+  bool retains_grad() const;
+
+  const TensorBase& requires_grad_(bool _requires_grad=true) const;
+
+  // View Variables
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  /// Returns true if this `Variable` is a view of another `Variable`.
+  bool is_view() const;
+
+  /// Returns the `Variable` that this `Variable` is a view of. If this
+  /// `Variable` is not a view, throw a `std::runtime_error`.
+  const TensorBase& _base() const;
+
+  // Miscellaneous
+  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+  const std::string& name() const;
+
+protected:
+  void enforce_invariants();
+  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
+
+private:
+  TensorBase __dispatch_contiguous(c10::MemoryFormat) const;
+};
+
+// For "multiple ... operators specified" warnings, closing brace of class
+// declaration must be included between pragma push & pop
+#ifdef _MSC_VER
+#pragma warning( pop )
+#endif
+
+inline int64_t get_device(const TensorBase& self) {
+  return self.get_device();
+}
+
+template <typename T>
+auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_void_t<T> {
+  // Return the grad argument in case of a hook with void return type to have an
+  // std::function with Tensor return type
+  static_assert(std::is_same<decltype(hook(TensorBase())), void>::value,
+                "Expected hook to return void");
+  return _register_hook([fn=std::forward<T>(hook)](const TensorBase& grad) {
+    fn(grad);
+    return TensorBase();
+  });
+}
+
+template <typename T>
+auto TensorBase::register_hook(T&& hook) const -> TensorBase::hook_return_var_t<T> {
+  return _register_hook(std::move(hook));
+}
+
+namespace detail {
+// Helper creator for Tensor class which doesn't requires the users to pass
+// in an intrusive_ptr instead it just converts the argument passed to
+// requested intrusive_ptr type.
+template <typename T, typename... Args>
+TensorBase make_tensor_base(Args&&... args) {
+  return TensorBase(c10::make_intrusive<T>(std::forward<Args>(args)...));
+}
+
+} // namespace detail
+
+static inline DispatchKey legacyExtractDispatchKey(const TensorBase& t) {
+  return legacyExtractDispatchKey(t.key_set());
+}
+
+} // namespace at
+
+namespace c10 {
+template <>
+struct MaybeOwnedTraits<at::TensorBase> {
+  using owned_type = at::TensorBase;
+  using borrow_type = at::TensorBase;
+
+  static borrow_type createBorrow(const owned_type& from) {
+    // NOTE: this can be implemented without the special
+    // unsafe_borrow_t Tensor constructor as
+    //
+    // return borrow_type(c10::intrusive_ptr<at::TensorImpl, at::UndefinedTensorImpl>::reclaim(from.unsafeGetTensorImpl()));
+    //
+    // but that hurts inlining due to the nullptr check in the
+    // Tensor(c10::intrusive_ptr<...>) constructor. We already know
+    // that from.impl_ isn't null because from is a valid Tensor, so
+    // we needn't do the check again. (using __builtin_assume can
+    // avoid this, but wouldn't be portable to MSVC.)
+    return borrow_type(borrow_type::unsafe_borrow_t{}, from);
+  }
+
+  static void assignBorrow(borrow_type& lhs, const borrow_type& rhs) {
+    lhs.unsafeReleaseTensorImpl();
+    // See above note: this can be implemented with public API
+    // similarly to createBorrow(), but that would hurt inlining.
+    lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs);
+  }
+
+  static void destroyBorrow(borrow_type& toDestroy) {
+    toDestroy.unsafeReleaseTensorImpl(); // "leak" it, but it was already +0.
+  }
+
+  static const owned_type& referenceFromBorrow(const borrow_type& borrow) {
+    return borrow;
+  }
+
+  static const owned_type* pointerFromBorrow(const borrow_type& borrow) {
+    return &borrow;
+  }
+
+  static bool debugBorrowIsValid(const borrow_type& borrow) {
+    return true;
+  }
+};
+
+template <>
+struct ExclusivelyOwnedTraits<at::TensorBase> {
+  using repr_type = at::TensorBase;
+  using pointer_type = at::TensorBase*;
+  using const_pointer_type = const at::TensorBase*;
+
+  static repr_type nullRepr() {
+    return at::TensorBase();
+  }
+
+  template <class... Args>
+  static repr_type createInPlace(Args&&... args) {
+    return at::TensorBase(std::forward<Args>(args)...);
+  }
+
+  static repr_type moveToRepr(at::TensorBase&& x) {
+    return std::move(x);
+  }
+
+  static void destroyOwned(at::TensorBase& x) {
+    TensorImpl*const toDestroy = x.unsafeReleaseTensorImpl();
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(toDestroy != nullptr, "Tensor somehow got null TensorImpl?");
+    // May be 0 because UndefinedTensorImpl doesn't get its refcount
+    // incremented.
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()),
+        "ExclusivelyOwned<Tensor> destroyed with refcount ", toDestroy->refcount_, ", expected 1!");
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+        toDestroy->weakcount_ == 1 || (toDestroy->weakcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()),
+        "ExclusivelyOwned<Tensor> destroyed with weakcount ", toDestroy->weakcount_, ", expected 1!");
+    if (toDestroy != UndefinedTensorImpl::singleton()) {
+#ifndef NDEBUG
+      // Needed to pass the debug assertions in ~intrusive_ptr_target.
+      toDestroy->refcount_ = 0;
+      toDestroy->weakcount_ = 0;
+#endif
+      toDestroy->release_resources();
+      delete toDestroy;
+    }
+  }
+
+  static at::TensorBase take(at::TensorBase& x) {
+    return std::move(x);
+  }
+
+  static pointer_type getImpl(repr_type& x) {
+    return &x;
+  }
+
+  static const_pointer_type getImpl(const repr_type& x) {
+    return &x;
+  }
+};
+} // namespace c10
+
+namespace at {
+
+inline c10::MaybeOwned<TensorBase> borrow_from_optional_tensor(
+    const c10::optional<TensorBase>& opt) {
+  return opt.has_value()
+    ? c10::MaybeOwned<TensorBase>::borrowed(*opt)
+    : c10::MaybeOwned<TensorBase>::owned(c10::in_place);
+}
+
+inline c10::MaybeOwned<TensorBase> TensorBase::expect_contiguous(MemoryFormat memory_format) const & {
+  if (is_contiguous(memory_format)) {
+    return c10::MaybeOwned<TensorBase>::borrowed(*this);
+  } else {
+    return c10::MaybeOwned<TensorBase>::owned(__dispatch_contiguous(memory_format));
+  }
+}
+} // namespace at
index f365c8d..9b0067b 100644 (file)
@@ -40,23 +40,25 @@ namespace impl {
 
 struct TORCH_API VariableHooksInterface {
   virtual ~VariableHooksInterface() = default;
-  virtual Tensor tensor_data(const Tensor&) const = 0;
-  virtual Tensor variable_data(const Tensor&) const = 0;
-  virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(const Tensor&) const = 0;
-  virtual unsigned _register_hook(const Tensor&, std::function<Tensor(const Tensor&)> hook) const = 0;
-  virtual void remove_hook(const Tensor&, unsigned pos) const = 0;
-  virtual bool is_view(const Tensor&) const = 0;
-  virtual const Tensor& base(const Tensor&) const = 0;
-  virtual const std::string& name(const Tensor&) const = 0;
-  virtual bool is_leaf(const Tensor&) const = 0;
-  virtual int64_t output_nr(const Tensor&) const = 0;
-  virtual void set_data(const Tensor&, const Tensor&) const = 0;
-  virtual Tensor data(const Tensor&) const = 0;
-  virtual int64_t _version(const Tensor&) const = 0;
-  virtual void retain_grad(const Tensor&) const = 0;
-  virtual bool retains_grad(const Tensor&) const = 0;
+  virtual TensorBase tensor_data(const TensorBase&) const = 0;
+  virtual TensorBase variable_data(const TensorBase&) const = 0;
+  virtual const std::shared_ptr<torch::autograd::Node>& grad_fn(const TensorBase&) const = 0;
+  virtual unsigned _register_hook(
+      const TensorBase&,
+      std::function<TensorBase(const TensorBase&)> hook) const = 0;
+  virtual void remove_hook(const TensorBase&, unsigned pos) const = 0;
+  virtual bool is_view(const TensorBase&) const = 0;
+  virtual const TensorBase& base(const TensorBase&) const = 0;
+  virtual const std::string& name(const TensorBase&) const = 0;
+  virtual bool is_leaf(const TensorBase&) const = 0;
+  virtual int64_t output_nr(const TensorBase&) const = 0;
+  virtual void set_data(const TensorBase&, const TensorBase&) const = 0;
+  virtual TensorBase data(const TensorBase&) const = 0;
+  virtual int64_t _version(const TensorBase&) const = 0;
+  virtual void retain_grad(const TensorBase&) const = 0;
+  virtual bool retains_grad(const TensorBase&) const = 0;
   virtual void _backward(const Tensor&, TensorList, const c10::optional<Tensor>&, c10::optional<bool>, bool) const = 0;
-  virtual void requires_grad_(const Tensor&, bool) const = 0;
+  virtual void requires_grad_(const TensorBase&, bool) const = 0;
 };
 
 TORCH_API void SetVariableHooks(VariableHooksInterface* hooks);
index b4a4f3a..b574dcf 100644 (file)
@@ -29,7 +29,7 @@ namespace {
 } // anonymous namespace
 
 // Note: this is not a native function as Quantizer is not exposed to python yet
-QuantizerPtr Tensor::quantizer() const {
+QuantizerPtr TensorBase::quantizer() const {
   // This is a terrible hack to emulate what VariableType is doing
   at::AutoDispatchBelowAutograd mode;
   return get_qtensorimpl(*this)->quantizer();
@@ -71,7 +71,7 @@ QuantizerPtr make_per_channel_affine_quantizer(
   }
 }
 
-QTensorImpl* get_qtensorimpl(const Tensor& self) {
+QTensorImpl* get_qtensorimpl(const TensorBase& self) {
   TORCH_CHECK(
       !self.requires_grad(),
       "quantized tensors do not support autograd");
index c0c119c..34ec274 100644 (file)
@@ -205,7 +205,7 @@ struct TORCH_API PerChannelAffineFloatQParamsQuantizer : public PerChannelAffine
 // setters/getters for QTensorImpl fields; otherwise, you should use
 // the low level setters/getters that were implemented using this.
 // This may be called repeatedly, so make sure it's pretty cheap.
-TORCH_API QTensorImpl* get_qtensorimpl(const Tensor& self);
+TORCH_API QTensorImpl* get_qtensorimpl(const TensorBase& self);
 
 // double and int64_t are because of the native function API, we only have these
 // argument types right now in native functions
index a92b750..6897c32 100644 (file)
@@ -2,6 +2,14 @@
 
 // ${generated_comment}
 
+#ifdef TORCH_ASSERT_NO_OPERATORS
+#error This change adds a dependency on native_functions.yaml,             \
+  meaning the file will need to be re-compiled every time an operator      \
+  is changed or added. Consider if your change would be better placed in   \
+  another file, or if a more specific header might achieve the same goal.  \
+  See NOTE: [Tensor vs. TensorBase]
+#endif
+
 #include <c10/core/Scalar.h>
 #include <c10/core/TensorOptions.h>
 #include <c10/core/QScheme.h>
index 95312ff..eb6402f 100644 (file)
 #include <ATen/core/DeprecatedTypeProperties.h>
 #include <ATen/core/NamedTensor.h>
 #include <ATen/core/QuantizerBase.h>
+#include <ATen/core/TensorBase.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 
-namespace caffe2 {
-class Tensor;
-}
 namespace c10{
-struct TensorOptions;
 template<class T> class List;
 }
 namespace at {
@@ -58,18 +55,6 @@ using TensorList = ArrayRef<Tensor>;
 
 using Stream = c10::Stream;
 
-namespace impl {
-inline bool variable_excluded_from_dispatch() {
-#ifdef C10_MOBILE
-  // Please read the comment in `VariableFallbackKernel.cpp` about the background of this change.
-  return true;
-#else
-  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::Autograd));
-  return c10::impl::tls_local_dispatch_key_set().excluded_.isSupersetOf(c10::autograd_dispatch_keyset);
-#endif
-}
-}
-
 // Tensor is a "generic" object holding a pointer to the underlying TensorImpl object, which
 // has an embedded reference count. In this way, Tensor is similar to boost::intrusive_ptr.
 //
@@ -87,56 +72,38 @@ inline bool variable_excluded_from_dispatch() {
 //
 // Note that Tensor can also be NULL, i.e. it is not associated with any underlying TensorImpl, and
 // special care must be taken to handle this.
-class TORCH_API Tensor {
- private:
-  struct unsafe_borrow_t { explicit unsafe_borrow_t() = default; };
-
+class TORCH_API Tensor: public TensorBase {
+ protected:
   // Create a Tensor with a +0 reference count. Special care must be
   // taken to avoid decrementing this reference count at destruction
   // time. Intended to support MaybeOwnedTraits<Tensor>.
-  explicit Tensor(unsafe_borrow_t, const Tensor& rhs)
-      : impl_(c10::intrusive_ptr<at::TensorImpl, UndefinedTensorImpl>::reclaim(rhs.impl_.get())) {}
+  explicit Tensor(unsafe_borrow_t, const TensorBase& rhs): TensorBase(unsafe_borrow_t{}, rhs) {}
   friend MaybeOwnedTraits<Tensor>;
   friend OptionalTensorRef;
 
  public:
-  Tensor(){};
+  Tensor() = default;
   // This constructor should not be used by end users and is an implementation
   // detail invoked by autogenerated code.
   explicit Tensor(
       c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl)
-      : impl_(std::move(tensor_impl)) {
-    if (impl_.get() == nullptr) {
-      throw std::runtime_error("TensorImpl with nullptr is not supported");
-    }
-  }
-  Tensor(const Tensor&) = default;
-  Tensor(Tensor&&) = default;
+      : TensorBase(std::move(tensor_impl)) {}
+  Tensor(const Tensor &tensor) = default;
+  Tensor(Tensor &&tensor) = default;
 
+  // Implicitly move-constructible from TensorBase, but must be explicit to increase refcount
+  explicit Tensor(const TensorBase &base): TensorBase(base) {}
+  /*implicit*/ Tensor(TensorBase &&base): TensorBase(std::move(base)) {}
 
- public:
   // Creates a new wrapper from TensorImpl. Intentionally a free method because
   // it should be used with care. Checks necessary invariants
   static Tensor wrap_tensor_impl(
       c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> tensor_impl) {
-    Tensor r(std::move(tensor_impl));
-    r.enforce_invariants();
-    return r;
-  }
-
-  int64_t dim() const {
-    return impl_->dim();
-  }
-  int64_t storage_offset() const {
-    return impl_->storage_offset();
+    return TensorBase::wrap_tensor_impl(std::move(tensor_impl));
   }
 
   Tensor contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) const {
-    if (is_contiguous(memory_format)) {
-      return *this;
-    } else {
-      return __dispatch_contiguous(memory_format);
-    }
+    return TensorBase::contiguous(memory_format);
   }
 
   Tensor conj() const {
@@ -150,6 +117,10 @@ class TORCH_API Tensor {
     }
   }
 
+  // Aliased by Dimname overloads, so need explicit using
+  using TensorBase::size;
+  using TensorBase::stride;
+
   /// Should be used if *this can reasonably be expected to be contiguous and
   /// performance is important.
   /// Compared to contiguous, it saves a reference count
@@ -162,52 +133,6 @@ class TORCH_API Tensor {
   // will only lead to trouble and dangling references.
   c10::MaybeOwned<Tensor> expect_contiguous(MemoryFormat memory_format=MemoryFormat::Contiguous) && = delete;
 
-  bool is_complex() const {
-    return at::isComplexType(this->scalar_type());
-  }
-
-  bool is_floating_point() const {
-    return at::isFloatingType(this->scalar_type());
-  }
-
-  bool is_signed() const {
-    return at::isSignedType(this->scalar_type());
-  }
-
-  int64_t size(int64_t dim) const {
-    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
-    dim = c10::maybe_wrap_dim(dim, this->dim(), false);
-    return sizes()[dim];
-  }
-
-  int64_t stride(int64_t dim) const {
-    // false is passed to maybe_wrap_dim so behavior is identical to array access (but with wrapping)
-    dim = c10::maybe_wrap_dim(dim, this->dim(), false);
-    return strides()[dim];
-  }
-
-  TensorImpl * unsafeGetTensorImpl() const {
-    return impl_.get();
-  }
-  TensorImpl * unsafeReleaseTensorImpl() {
-    return impl_.release();
-  }
-  const c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>& getIntrusivePtr() const {
-    return impl_;
-  }
-
-  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> unsafeReleaseIntrusivePtr()  {
-    return std::move(impl_);
-  }
-
-  bool defined() const {
-    return impl_;
-  }
-
-  void reset() {
-    impl_.reset();
-  }
-
   // The following overloads are very intruiging.  Consider the following
   // program:
   //
@@ -248,105 +173,25 @@ class TORCH_API Tensor {
   #pragma warning( disable : 4522 )
   #endif
 
-  Tensor& operator=(const Tensor& x) & {
-    impl_ = x.impl_;
+  Tensor& operator=(const TensorBase& x) & {
+    impl_ = x.getIntrusivePtr();
     return *this;
   }
-  Tensor& operator=(Tensor&& x) & {
-    impl_ = std::move(x.impl_);
+  Tensor& operator=(TensorBase&& x) & {
+    impl_ = x.unsafeReleaseIntrusivePtr();
     return *this;
   }
 
-  Tensor& operator=(Scalar v) &&;
-  Tensor& operator=(const Tensor&) &&;
-  Tensor& operator=(Tensor&&) &&;
-
-  bool is_same(const Tensor& other) const noexcept {
-    return impl_ == other.impl_;
-  }
-  size_t use_count() const noexcept {
-    return impl_.use_count();
+  Tensor& operator=(const Tensor &x) & {
+    return operator=(static_cast<const TensorBase&>(x));
   }
-  size_t weak_use_count() const noexcept {
-    return impl_.weak_use_count();
+  Tensor& operator=(Tensor &&x) & {
+    return operator=(static_cast<TensorBase&&>(x));
   }
 
-  std::string toString() const;
-
-  IntArrayRef sizes() const {
-    return impl_->sizes();
-  }
-  IntArrayRef strides() const {
-    return impl_->strides();
-  }
-  // See impl::get_opt_names in ATen/NamedTensor.h for docs.
-  c10::optional<DimnameList> opt_names() const {
-    return impl::get_opt_names(unsafeGetTensorImpl());
-  }
-  // See impl::get_names in ATen/NamedTensor.h for docs.
-  DimnameList names() const {
-    return impl::get_names(unsafeGetTensorImpl());
-  }
-  int64_t ndimension() const {
-    return dim();
-  }
-
-  bool is_contiguous(at::MemoryFormat memory_format=at::MemoryFormat::Contiguous) const {
-    return impl_->is_contiguous(memory_format);
-  }
-
-  bool is_non_overlapping_and_dense() const {
-    return impl_->is_non_overlapping_and_dense();
-  }
-
-  at::MemoryFormat suggest_memory_format(
-      bool channels_last_strides_exact_match = false) const {
-    // Setting channels_last_strides_exact_match to true forces function to
-    // check 0,1 - sized dimension strides.
-    if (!is_mkldnn() && !is_sparse()) {
-      if (impl_->is_strides_like_channels_last()) {
-        if (!channels_last_strides_exact_match ||
-            get_channels_last_strides_2d(sizes()) == strides()) {
-          return at::MemoryFormat::ChannelsLast;
-        }
-      }
-      else if (impl_->is_strides_like_channels_last_3d()) {
-        if (!channels_last_strides_exact_match ||
-            get_channels_last_strides_3d(sizes()) == strides()) {
-          return at::MemoryFormat::ChannelsLast3d;
-        }
-      }
-    }
-    return at::MemoryFormat::Contiguous;
-  }
-
-  // Total bytes consumed by the "view" of elements of the array.  Does not
-  // include size of metadata.  The number reported here does not necessarily
-  // correspond to the true physical memory consumed by a tensor; instead,
-  // it reports the memory the tensor would take *if* it were contiguous.
-  // Defined to be numel() * itemsize()
-  size_t nbytes() const {
-    TORCH_CHECK(layout () != at::kSparse,
-                "nbytes is not defined for sparse tensors.  If you want the size of the constituent " \
-                "tensors, add the nbytes of the indices and values.  If you want the size of the  " \
-                "equivalent dense tensor, multiply numel() by element_size()");
-    return impl_->numel() * impl_->itemsize();
-  }
-
-  int64_t numel() const {
-    return impl_->numel();
-  }
-
-  // Length of one array element in bytes.  This is the traditional
-  // Numpy naming.
-  size_t itemsize() const {
-    return impl_->itemsize();
-  }
-
-  // Same as itemsize().  This is the PyTorch naming.
-  int64_t element_size() const {
-    return static_cast<int64_t>(impl_->itemsize());
-  }
+  Tensor& operator=(Scalar v) &&;
+  Tensor& operator=(const Tensor&) &&;
+  Tensor& operator=(Tensor&&) &&;
 
   C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().")
   DeprecatedTypeProperties & type() const {
@@ -354,21 +199,6 @@ class TORCH_API Tensor {
         dispatchKeyToBackend(legacyExtractDispatchKey(key_set())),
         scalar_type());
   }
-  DispatchKeySet key_set() const {
-    return impl_->key_set();
-  }
-  ScalarType scalar_type() const {
-    return typeMetaToScalarType(impl_->dtype());
-  }
-  bool has_storage() const {
-    return defined() && impl_->has_storage();
-  }
-  const Storage& storage() const {
-    return impl_->storage();
-  }
-  bool is_alias_of(const at::Tensor& other) const{
-    return impl_->storage().is_alias_of(other.storage());
-  }
 
   Tensor toType(ScalarType t) const {
     return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false);
@@ -384,189 +214,6 @@ class TORCH_API Tensor {
     return !at::impl::variable_excluded_from_dispatch();
   }
 
-  inline bool is_conj() const {
-    return impl_->is_conj();
-  }
-
-  // sets the conjugate bit of a tensor.
-  // NOTE: Conjugate bit is supposed to be a read-only field. Only change this, if you are sure
-  // that's what you want. Changing this might lead to incorrect behavior since conjugation is
-  // a lazy operation and we rely on this bit to determine if a conjugation needs to be materialized.
-  inline void _set_conj(bool conjugate) const {
-    impl_->_set_conj(conjugate);
-  }
-
-  inline bool is_neg() const {
-    return impl_->is_neg();
-  }
-
-  // sets the negative bit of a tensor.
-  // NOTE: Negative bit is supposed to be a read-only field. Only change this, if you are sure
-  // that's what you want. Changing this might lead to incorrect behavior since we rely on this
-  // bit to determine if a negation needs to be materialized.
-  inline void _set_neg(bool negative) const {
-    impl_->_set_neg(negative);
-  }
-
-  /// Returns a `Tensor`'s layout.
-  Layout layout() const noexcept {
-    return impl_->layout();
-  }
-
-  /// Returns a `Tensor`'s dtype (`TypeMeta`).
-  caffe2::TypeMeta dtype() const noexcept {
-    return impl_->dtype();
-  }
-
-  /// Returns a `Tensor`'s device.
-  inline Device device() const {
-    return impl_->device();
-  }
-
-  /// Returns a `Tensor`'s device index.
-  int64_t get_device() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->get_device();
-  }
-
-
-  /// Returns if a `Tensor` has CPU backend.
-  bool is_cpu() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_cpu();
-  }
-
-  /// Returns if a `Tensor` has CUDA backend.
-  bool is_cuda() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_cuda();
-  }
-
-  /// Returns if a `Tensor` has XPU backend.
-  bool is_xpu() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_xpu();
-  }
-
-  /// Returns if a `Tensor` has XLA backend.
-  bool is_xla() const {
-    return impl_->is_xla();
-  }
-
-  /// Returns if a `Tensor` has Lazy backend.
-  bool is_lazy() const {
-    return impl_->is_lazy();
-  }
-
-  /// Returns if a `Tensor` has HIP backend.
-  bool is_hip() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_hip();
-  }
-
-  /// Returns if a `Tensor` has VE backend.
-  bool is_ve() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_ve();
-  }
-
-  /// Returns if a `Tensor` has sparse backend.
-  bool is_sparse() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_sparse();
-  }
-
-  /// Returns is a `Tensor` has a sparse CSR backend.
-  bool is_sparse_csr() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_sparse_csr();
-  }
-
-  /// Returns if a `Tensor` is mkldnn tensor.
-  bool is_mkldnn() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_mkldnn();
-  }
-
-  /// Returns if a `Tensor` is mlc tensor.
-  bool is_mlc() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_mlc();
-  }
-
-  /// Returns if a `Tensor` is ort tensor.
-  bool is_ort() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_ort();
-  }
-
-  /// Returns if a `Tensor` is vulkan tensor.
-  bool is_vulkan() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_vulkan();
-  }
-
-  /// Returns if a `Tensor` is metal tensor.
-  bool is_metal() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_metal();
-  }
-
-  /// Returns if a `Tensor` has quantized backend.
-  bool is_quantized() const {
-    // NB: this is not a native function to avoid dispatching overhead.
-    return impl_->is_quantized();
-  }
-
-  /// Returns if a `Tensor` is a meta tensor.  Meta tensors can
-  /// also have other designations.
-  bool is_meta() const {
-    return impl_->is_meta();
-  }
-
-  /// Returns if a `Tensor` is an inference tensor.
-  bool is_inference() const {
-    return impl_->is_inference();
-  }
-
-  /// If a tensor is a quantized tensor, returns its quantizer
-  /// TODO: it's not in native_functions.yaml yet as it's not exposed to python
-  QuantizerPtr quantizer() const;
-
-  /// Returns if a `Tensor` has any dimension names
-  bool has_names() const {
-    // If a user is using unnamed tensors, then we can short-circuit right here.
-    // Otherwise, impl::has_names attempts to retrieve names.
-    if (!impl_->has_named_tensor_meta()) {
-      return false;
-    }
-    return impl::has_names(unsafeGetTensorImpl());
-  }
-
-  /// Returns a `Tensor`'s dimension names data structure
-  const NamedTensorMeta* get_named_tensor_meta() const {
-    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
-  }
-
-  NamedTensorMeta* get_named_tensor_meta() {
-    return static_cast<NamedTensorMeta*>(impl_->named_tensor_meta());
-  }
-
-  /// Returns the `TensorOptions` corresponding to this `Tensor`. Defined in
-  /// TensorOptions.h.
-  TensorOptions options() const {
-    return TensorOptions().dtype(dtype())
-                          .device(device())
-                          .layout(layout());
-  }
-
-  void* data_ptr() const {
-    return this->unsafeGetTensorImpl()->data();
-  }
-
-  template <typename T>
-  T * data_ptr() const;
-
   template<typename T>
   C10_DEPRECATED_MESSAGE("Tensor.data<T>() is deprecated. Please use Tensor.data_ptr<T>() instead.")
   T * data() const {
@@ -579,45 +226,6 @@ class TORCH_API Tensor {
   // Purposely not defined here to avoid inlining
   void print() const;
 
-  // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar type and
-  // dimension.
-  template<typename T, size_t N>
-  TensorAccessor<T,N> accessor() const& {
-    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
-    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
-    return TensorAccessor<T,N>(data_ptr<T>(),sizes().data(),strides().data());
-  }
-  template<typename T, size_t N>
-  TensorAccessor<T,N> accessor() && = delete;
-
-  // Return a `GenericPackedTensorAccessor` for CUDA `Tensor`s. You have to specify scalar type and
-  // dimension. You can optionally specify RestrictPtrTraits as a template parameter to
-  // cast the data pointer to a __restrict__ pointer.
-  // In order to use this, your CUDA kernel has to take a corresponding GenericPackedTensorAccessor
-  // as an argument.
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
-  GenericPackedTensorAccessor<T,N,PtrTraits,index_t> generic_packed_accessor() const& {
-    static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
-    TORCH_CHECK(dim() == N, "TensorAccessor expected ", N, " dims but tensor has ", dim());
-    return GenericPackedTensorAccessor<T,N,PtrTraits,index_t>(static_cast<typename PtrTraits<T>::PtrType>(data_ptr<T>()),sizes().data(),strides().data());
-  }
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
-  GenericPackedTensorAccessor<T,N> generic_packed_accessor() && = delete;
-
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
-  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() const& {
-    return generic_packed_accessor<T,N,PtrTraits,int32_t>();
-  }
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
-  PackedTensorAccessor32<T,N,PtrTraits> packed_accessor32() && = delete;
-
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
-  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() const& {
-    return generic_packed_accessor<T,N,PtrTraits,int64_t>();
-  }
-  template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits>
-  PackedTensorAccessor64<T,N,PtrTraits> packed_accessor64() && = delete;
-
   template<typename T, size_t N, template <typename U> class PtrTraits = DefaultPtrTraits, typename index_t = int64_t>
   C10_DEPRECATED_MESSAGE("packed_accessor is deprecated, use packed_accessor32 or packed_accessor64 instead")
   GenericPackedTensorAccessor<T,N,PtrTraits,index_t> packed_accessor() const & {
@@ -785,12 +393,9 @@ class TORCH_API Tensor {
   /// populated during :func:`backward`, ``false`` otherwise.
 
   const Tensor& set_requires_grad(bool requires_grad) const {
-    impl_->set_requires_grad(requires_grad);
+    TensorBase::set_requires_grad(requires_grad);
     return *this;
   }
-  bool requires_grad() const {
-    return impl_->requires_grad();
-  }
 
   /// Return a mutable reference to the gradient. This is conventionally
   /// used as `t.grad() = x` to set a gradient to a completely new tensor.
@@ -829,7 +434,7 @@ class TORCH_API Tensor {
   /// Note that the given new_grad might not be used directly if it has different
   /// metadata (size/stride/storage offset) compared to this Tensor. In that case,
   /// new_grad content will be copied into a new Tensor
-  void _set_fw_grad(const Tensor& new_grad, uint64_t level, bool is_inplace_op) const {
+  void _set_fw_grad(const TensorBase& new_grad, uint64_t level, bool is_inplace_op) const {
     impl_->_set_fw_grad(new_grad, *this, level, is_inplace_op);
   }
 
@@ -878,7 +483,9 @@ class TORCH_API Tensor {
   /// returned `Tensor`'s tensor metadata (e.g. sizes / strides / storage / storage_offset)
   /// will not update the original `Variable`, due to the fact that this function
   /// shallow-copies the `Variable`'s underlying TensorImpl.
-  at::Tensor tensor_data() const;
+  at::Tensor tensor_data() const {
+    return TensorBase::tensor_data();
+  }
 
   /// NOTE: `var.variable_data()` in C++ has the same semantics as `tensor.data`
   /// in Python, which create a new `Variable` that shares the same storage and
@@ -891,19 +498,9 @@ class TORCH_API Tensor {
   /// `allow_tensor_metadata_change_` to false to make such changes explicitly illegal,
   /// in order to prevent users from changing metadata of `var.variable_data()`
   /// and expecting the original variable `var` to also be updated.
-  at::Tensor variable_data() const;
-
-  // Gradient Node and Edges
-  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-  /// Gets the gradient function of the `Variable`. If this is a leaf variable,
-  /// the pointer returned will be null.
-  ///
-  /// For View Variables:
-  /// Gets the up-to-date grad_fn. If the shared data or base was modified, we
-  /// re-create the grad_fn to express the up-to-date view relationship between
-  /// this and the base Variable.
-  const std::shared_ptr<torch::autograd::Node>& grad_fn() const;
+  at::Tensor variable_data() const {
+    return TensorBase::variable_data();
+  }
 
   // Hooks
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -948,55 +545,19 @@ class TORCH_API Tensor {
   template <typename T>
   hook_return_var_t<T> register_hook(T&& hook) const;
 
-private:
-  unsigned _register_hook(std::function<Tensor(const Tensor&)> hook) const;
-
-public:
-
-  /// Remove hook at given position
-  void remove_hook(unsigned pos) const;
-
   // Variable methods
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  bool is_leaf() const;
-
-  int64_t output_nr() const;
-
-  void set_data(const Tensor & new_data) const;
-
-  Tensor data() const;
-
-  int64_t _version() const;
-
-  void retain_grad() const;
-
-  bool retains_grad() const;
+  Tensor data() const {
+    return TensorBase::data();
+  }
 
   void _backward(TensorList inputs, const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph, bool create_graph) const;
 
-  const Tensor& requires_grad_(bool _requires_grad=true) const;
-
-  // View Variables
-  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-  /// Returns true if this `Variable` is a view of another `Variable`.
-  bool is_view() const;
-
-  /// Returns the `Variable` that this `Variable` is a view of. If this
-  /// `Variable` is not a view, throw a `std::runtime_error`.
-  const Tensor& _base() const;
-
-  // Miscellaneous
-  //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-  const std::string& name() const;
-
-protected:
-  friend class ::caffe2::Tensor;
-
-  void enforce_invariants();
-  c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl> impl_;
+  const Tensor& requires_grad_(bool _requires_grad=true) const {
+    TensorBase::requires_grad_(_requires_grad);
+    return *this;
+  }
 };
 
 // For "multiple ... operators specified" warnings, closing brace of class
@@ -1005,27 +566,6 @@ protected:
 #pragma warning( pop )
 #endif
 
-inline int64_t get_device(const Tensor& self) {
-  return self.get_device();
-}
-
-template <typename T>
-auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_void_t<T> {
-  // Return the grad argument in case of a hook with void return type to have an
-  // std::function with Tensor return type
-  static_assert(std::is_same<decltype(hook(Tensor())), void>::value,
-                "Expected hook to return void");
-  return _register_hook([fn=std::forward<T>(hook)](const Tensor& grad) {
-    fn(grad);
-    return Tensor();
-  });
-}
-
-template <typename T>
-auto Tensor::register_hook(T&& hook) const -> Tensor::hook_return_var_t<T> {
-  return _register_hook(std::forward<T>(hook));
-}
-
 namespace detail {
 // Helper creator for Tensor class which doesn't requires the users to pass
 // in an intrusive_ptr instead it just converts the argument passed to
@@ -1037,10 +577,6 @@ Tensor make_tensor(Args&&... args) {
 
 } // namespace detail
 
-static inline DispatchKey legacyExtractDispatchKey(const Tensor& t) {
-  return legacyExtractDispatchKey(t.key_set());
-}
-
 } // namespace at
 
 // See Note [Avoiding Include Cycles In Static Dispatch]
@@ -1114,25 +650,7 @@ struct ExclusivelyOwnedTraits<at::Tensor> {
   }
 
   static void destroyOwned(at::Tensor& x) {
-    TensorImpl*const toDestroy = x.unsafeReleaseTensorImpl();
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(toDestroy != nullptr, "Tensor somehow got null TensorImpl?");
-    // May be 0 because UndefinedTensorImpl doesn't get its refcount
-    // incremented.
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()),
-        "ExclusivelyOwned<Tensor> destroyed with refcount ", toDestroy->refcount_, ", expected 1!");
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
-        toDestroy->weakcount_ == 1 || (toDestroy->weakcount_ == 0 && toDestroy == UndefinedTensorImpl::singleton()),
-        "ExclusivelyOwned<Tensor> destroyed with weakcount ", toDestroy->weakcount_, ", expected 1!");
-    if (toDestroy != UndefinedTensorImpl::singleton()) {
-#ifndef NDEBUG
-      // Needed to pass the debug assertions in ~intrusive_ptr_target.
-      toDestroy->refcount_ = 0;
-      toDestroy->weakcount_ = 0;
-#endif
-      toDestroy->release_resources();
-      delete toDestroy;
-    }
+    return ExclusivelyOwnedTraits<at::TensorBase>::destroyOwned(x);
   }
 
   static at::Tensor take(at::Tensor& x) {
index 84b48bd..29a43a6 100644 (file)
@@ -3,9 +3,9 @@
 
 namespace at {
 
-#define DEFINE_CAST(T, name)                                        \
+#define DEFINE_CAST(T, name)                                         \
    template <>                                                       \
-   TORCH_API T* Tensor::data_ptr() const {                           \
+   TORCH_API T* TensorBase::data_ptr() const {                       \
      TORCH_CHECK(                                                    \
          scalar_type() == ScalarType::name,                          \
          "expected scalar type "                                     \
index de829c4..2fd5c64 100644 (file)
@@ -83,8 +83,9 @@ const at::Tensor& TensorImpl::grad() const {
   return autograd_meta_->grad();
 }
 
-const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self)
-    const {
+const at::Tensor& TensorImpl::_fw_grad(
+    uint64_t level,
+    const at::TensorBase& self) const {
   // See TensorImpl::grad() above for explanation about the line below
   if (!autograd_meta_)
     return impl::GetAutogradMetaFactory()->undefined_tensor();
@@ -92,8 +93,8 @@ const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self)
 }
 
 void TensorImpl::_set_fw_grad(
-    const at::Tensor& new_grad,
-    const at::Tensor& self,
+    const at::TensorBase& new_grad,
+    const at::TensorBase& self,
     uint64_t level,
     bool is_inplace_op) {
   if (!autograd_meta_)
index 7051e36..cee3600 100644 (file)
@@ -38,6 +38,7 @@ C10_DECLARE_int64(caffe2_max_keep_on_shrink_memory);
 
 namespace at {
 class Tensor;
+class TensorBase;
 }
 
 namespace c10 {
@@ -151,11 +152,11 @@ struct C10_API AutogradMetaInterface {
   virtual bool requires_grad() const = 0;
   virtual at::Tensor& mutable_grad() = 0;
   virtual const at::Tensor& grad() const = 0;
-  virtual const at::Tensor& fw_grad(uint64_t level, const at::Tensor& self)
+  virtual const at::Tensor& fw_grad(uint64_t level, const at::TensorBase& self)
       const = 0;
   virtual void set_fw_grad(
-      const at::Tensor& new_grad,
-      const at::Tensor& self,
+      const at::TensorBase& new_grad,
+      const at::TensorBase& self,
       uint64_t level,
       bool is_inplace_op) = 0;
   virtual ~AutogradMetaInterface();
@@ -1055,7 +1056,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    *   - "self" should represent the Tensor whose forward grad is accessed. It
    * is required when dealing with view.
    */
-  const at::Tensor& _fw_grad(uint64_t level, const at::Tensor& self) const;
+  const at::Tensor& _fw_grad(uint64_t level, const at::TensorBase& self) const;
 
   /**
    * Sets the forward gradient for this Tensor.
@@ -1078,8 +1079,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * better error checking.
    */
   void _set_fw_grad(
-      const at::Tensor& new_grad,
-      const at::Tensor& self,
+      const at::TensorBase& new_grad,
+      const at::TensorBase& self,
       uint64_t level,
       bool is_inplace_op);
 
index 646c5cc..f5fb7d1 100644 (file)
@@ -299,7 +299,7 @@ void Tensor::CopyFrom(const Tensor& src, bool async) {
 
 #if defined(EXPOSE_C2_OPS) || \
     !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE)
-Tensor::Tensor(at::Tensor tensor) : impl_(std::move(tensor.impl_)) {
+Tensor::Tensor(at::Tensor tensor) : impl_(tensor.unsafeReleaseIntrusivePtr()) {
   enforce_invariants();
 }
 
index 927d884..58206b7 100644 (file)
@@ -3558,8 +3558,8 @@ namespace detail {
         bias_k = multihead_attn_module->bias_k.detach();
         bias_v = multihead_attn_module->bias_v.detach();
       } else {
-        bias_k = {};
-        bias_v = {};
+        bias_k.reset();
+        bias_v.reset();
       }
 
       torch::Tensor _Q = decoder_state_tensor.unsqueeze(1).transpose(0, 1);
index 42ade6d..cddad59 100644 (file)
@@ -645,8 +645,8 @@ inline std::tuple<Tensor, Tensor> multi_head_attention_forward(
 
       if (!key.defined()) {
         TORCH_INTERNAL_ASSERT(!value.defined());
-        k = {};
-        v = {};
+        k.reset();
+        v.reset();
       } else {
         // This is inline in_proj function with in_proj_weight and in_proj_bias
         _b = in_proj_bias;
index e724a75..a770c15 100644 (file)
@@ -511,8 +511,8 @@ void MultiheadAttentionImpl::reset() {
     bias_k = register_parameter("bias_k", torch::empty({1, 1, options.embed_dim()}));
     bias_v = register_parameter("bias_v", torch::empty({1, 1, options.embed_dim()}));
   } else {
-    bias_k = {};
-    bias_v = {};
+    bias_k.reset();
+    bias_v.reset();
   }
   _reset_parameters();
 }
index f35c122..2ad2778 100644 (file)
@@ -101,7 +101,7 @@ namespace {
 
 // This function is will ensure that the fw_grad_ is properly a view of the base for inplace ops on
 // Tensors that do not have forward grad originally.
-void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self, uint64_t level, bool is_inplace_op) {
+void AutogradMeta::set_fw_grad(const at::TensorBase& new_grad_base, const at::TensorBase& self_base, uint64_t level, bool is_inplace_op) {
   // Lazy initialization
   {
     std::lock_guard<std::mutex> lock(mutex_);
@@ -112,21 +112,23 @@ void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self,
   if (fw_grad_->contains(level)) {
     // Setting the forward grad again is only allowed if it is a no-op.
     // We do allow this case to simplify writing codegen for inplace ops.
-    TORCH_INTERNAL_ASSERT(new_grad_.defined(), "Cannot set a forward grad that is an undefined Tensor. Use "
+    TORCH_INTERNAL_ASSERT(new_grad_base.defined(), "Cannot set a forward grad that is an undefined Tensor. Use "
                           "_fw_primal(level) to get a new Tensor with this forward grad unset.");
 
     TORCH_INTERNAL_ASSERT(is_inplace_op, "Only inplace operations can re-set the forward grad of a Tensor that "
                           "already has one.");
 
-    TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_), "Cannot set a value of a forward grad if it "
+    TORCH_INTERNAL_ASSERT(fw_grad_->value(level).is_same(new_grad_base), "Cannot set a value of a forward grad if it "
                           "already exists. Inplace operations should modify it inplace.");
   } else {
     // TODO(alband) remove this spurious version counter bump
-    auto new_grad = new_grad_;
+    Tensor new_grad(new_grad_base);
+    at::OptionalTensorRef self_ref(self_base);
+    const Tensor &self = *self_ref;
 
-    TORCH_CHECK(self.is_same_size(new_grad_), "Trying to set a forward gradient that has a different size than that "
+    TORCH_CHECK(self.is_same_size(new_grad), "Trying to set a forward gradient that has a different size than that "
                 "of the original Tensor, this is not supported. Tensor is of size ", self.sizes(), " while the given "
-                "forward gradient is of size ", new_grad_.sizes(), ".");
+                "forward gradient is of size ", new_grad.sizes(), ".");
 
     if (is_inplace_op && is_view_) {
       auto this_view_meta = static_cast<DifferentiableViewMeta*>(this);
@@ -182,7 +184,7 @@ void AutogradMeta::set_fw_grad(const Variable& new_grad_, const Variable& self,
   }
 }
 
-const Variable& AutogradMeta::fw_grad(uint64_t level, const Variable& self) const {
+const Variable& AutogradMeta::fw_grad(uint64_t level, const at::TensorBase& self) const {
   // TLS that disables forward AD
   // This is only used for custom Function implementation
   if (!c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
index d7bfd04..e593354 100644 (file)
@@ -4,7 +4,7 @@
 
 namespace {
 using torch::autograd::Variable;
-void check_single_result (Variable value, Variable result, std::string hook_name) {
+void check_single_result (const at::TensorBase &value, const at::TensorBase &result, std::string hook_name) {
   if (!value.defined()) {
     throw std::runtime_error("can't replace a empty gradient with a non-empty value");
   }
@@ -34,7 +34,7 @@ variable_list CppFunctionPreHook::operator()(const variable_list& values) {
       continue;
     }
     check_single_result(value, res, c10::to_string(i));
-    value = res;
+    value = std::move(res);
   }
   variable_list results(values);
   results[value_idx_] = value;
index a5f1829..c3d9135 100644 (file)
@@ -5,7 +5,7 @@
 
 namespace torch { namespace autograd {
 
-using hooks_list = std::vector<std::function<Variable(const Variable&)>>;
+using hooks_list = std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;
 
 struct CppFunctionPreHook : public FunctionPreHook {
   CppFunctionPreHook(const std::shared_ptr<hooks_list> &hooks, int value_idx);
index 1bb4cb8..be02659 100644 (file)
@@ -364,7 +364,7 @@ optional_variable_list _wrap_outputs(const variable_list &input_vars,
   return outputs;
 }
 
-void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) {
+void check_variable_result(const at::TensorBase& original, const at::TensorBase& result, std::string hook_name) {
   if (!original.options().type_equal(result.options())) {
     std::stringstream ss;
     ss << "hook '" << hook_name << "' has changed the type of value (";
index 94e62bf..7e67a10 100644 (file)
@@ -20,8 +20,8 @@ TORCH_API std::vector<c10::optional<Variable>> _wrap_outputs(
   const std::shared_ptr<Node> &cdata,
   _jvp_fn_t jvp_user_function);
 
-TORCH_API void check_variable_result(const Variable& original,
-  const Variable& result, std::string hook_name);
+TORCH_API void check_variable_result(const at::TensorBase& original,
+    const at::TensorBase& result, std::string hook_name);
 
 // Get the return type of the forward function of the custom Function class X
 template<typename X, typename... Args>
index 50d6eb9..5f115ee 100644 (file)
@@ -158,7 +158,7 @@ static PyObject* THPVariable_NewWithVar(
 }
 
 // TODO: Make this take Variable by const reference
-PyObject * THPVariable_Wrap(Variable var)
+PyObject * THPVariable_Wrap(at::TensorBase var)
 {
   if (!var.defined()) {
     Py_RETURN_NONE;
index 4078235..407d5dc 100644 (file)
@@ -22,7 +22,7 @@ THP_API PyObject *THPVariableClass;
 THP_API PyObject *ParameterClass;
 
 bool THPVariable_initModule(PyObject *module);
-THP_API PyObject * THPVariable_Wrap(torch::autograd::Variable var);
+THP_API PyObject * THPVariable_Wrap(at::TensorBase var);
 
 static inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) {
   // Check that a python object is a `Tensor`, but not a `Tensor` subclass.
index 7ae1ac0..b123b7c 100644 (file)
@@ -130,7 +130,7 @@ static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(&meta_fa
 
 namespace impl {
 
-  AutogradMeta* materialize_autograd_meta(const Variable& self) {
+  AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
     TORCH_CHECK(self.defined(), "cannot call materialize_autograd_meta() on undefined tensor");
     auto p = self.unsafeGetTensorImpl();
     if (!p->autograd_meta()) {
@@ -165,7 +165,7 @@ namespace impl {
     set_gradient_edge(self, std::move(gradient_edge));
   }
 
-  void create_cpp_hook(const Variable& self) {
+  void create_cpp_hook(const at::TensorBase& self) {
     auto &list = materialize_autograd_meta(self)->cpp_hooks_list_;
     // NOLINTNEXTLINE(modernize-make-shared)
     list.reset(new hooks_list());
@@ -277,7 +277,7 @@ namespace impl {
   // Hooks
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  void add_hook(const Variable& self, std::shared_ptr<FunctionPreHook> hook) {
+  void add_hook(const at::TensorBase& self, std::shared_ptr<FunctionPreHook> hook) {
     materialize_autograd_meta(self)->hooks_.push_back(std::move(hook));
   }
 
@@ -296,7 +296,7 @@ namespace impl {
     }
   }
 
-  void clear_hooks(const Variable& self) {
+  void clear_hooks(const at::TensorBase& self) {
     // This is a little goofy, but usually this should be a no oop
     materialize_autograd_meta(self)->hooks_.clear();
   }
@@ -308,13 +308,13 @@ namespace impl {
   // Miscellaneous
   //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-  AutogradMeta* get_autograd_meta(const Variable& self) {
+  AutogradMeta* get_autograd_meta(const at::TensorBase& self) {
     // NB: could return nullptr
     TORCH_CHECK(self.defined(), "cannot call get_autograd_meta() on undefined tensor");
     return static_cast<AutogradMeta*>(self.unsafeGetTensorImpl()->autograd_meta());
   }
 
-  DifferentiableViewMeta* get_view_autograd_meta(const Variable& self) {
+  DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) {
     // NB: return nullptr if self is not a view
     AutogradMeta* meta = get_autograd_meta(self);
     if (meta && meta->is_view_) {
@@ -329,31 +329,32 @@ namespace impl {
 using at::Tensor;
 
 struct VariableHooks final : at::impl::VariableHooksInterface {
-  Tensor tensor_data(const Tensor&) const override;
-  Tensor variable_data(const Tensor&) const override;
-  const std::shared_ptr<torch::autograd::Node>& grad_fn(const Tensor&) const override;
-  unsigned _register_hook(const Tensor&, std::function<Tensor(const Tensor&)> hook) const override;
-  void remove_hook(const Tensor&, unsigned pos) const override;
-  bool is_view(const Tensor&) const override;
-  const Tensor& base(const Tensor&) const override;
-  const std::string& name(const Tensor&) const override;
-  bool is_leaf(const Tensor&) const override;
-  int64_t output_nr(const Tensor&) const override;
-  void set_data(const Tensor & self, const Tensor & new_data) const override;
-  Tensor data(const Tensor & self) const override;
-  int64_t _version(const Tensor & self) const override;
-  void retain_grad(const Tensor& self) const override;
-  bool retains_grad(const Tensor& self) const override;
+  at::TensorBase tensor_data(const at::TensorBase&) const override;
+  at::TensorBase variable_data(const at::TensorBase&) const override;
+  const std::shared_ptr<torch::autograd::Node>& grad_fn(const at::TensorBase&) const override;
+  unsigned _register_hook(
+      const at::TensorBase&, std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
+  void remove_hook(const at::TensorBase&, unsigned pos) const override;
+  bool is_view(const at::TensorBase&) const override;
+  const at::TensorBase& base(const at::TensorBase&) const override;
+  const std::string& name(const at::TensorBase&) const override;
+  bool is_leaf(const at::TensorBase&) const override;
+  int64_t output_nr(const at::TensorBase&) const override;
+  void set_data(const at::TensorBase & self, const at::TensorBase & new_data) const override;
+  at::TensorBase data(const at::TensorBase & self) const override;
+  int64_t _version(const at::TensorBase & self) const override;
+  void retain_grad(const at::TensorBase& self) const override;
+  bool retains_grad(const at::TensorBase& self) const override;
   void _backward(const Tensor& self, at::TensorList inputs,
     const c10::optional<Tensor>& gradient, c10::optional<bool> keep_graph,
     bool create_graph) const override;
-  void requires_grad_(const Tensor& self, bool _requires_grad) const override;
+  void requires_grad_(const at::TensorBase& self, bool _requires_grad) const override;
 };
 
 VariableHooks variableHooks;
 at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);
 
-Tensor VariableHooks::variable_data(const Tensor& self) const {
+at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const {
   TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor");
   auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
     /*version_counter=*/0,
@@ -362,7 +363,7 @@ Tensor VariableHooks::variable_data(const Tensor& self) const {
   return at::Tensor(self_impl_copy);
 }
 
-Tensor VariableHooks::tensor_data(const Tensor& self) const {
+at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const {
   TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
   auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
     /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(),
@@ -370,7 +371,7 @@ Tensor VariableHooks::tensor_data(const Tensor& self) const {
   return at::Tensor(self_impl_copy);
 }
 
-bool VariableHooks::is_leaf(const Tensor & self) const {
+bool VariableHooks::is_leaf(const at::TensorBase & self) const {
   if (impl::get_autograd_meta(self)) {
     return impl::get_autograd_meta(self)->grad_fn_ == nullptr;
   } else {
@@ -378,7 +379,7 @@ bool VariableHooks::is_leaf(const Tensor & self) const {
   }
 }
 
-int64_t VariableHooks::output_nr(const Tensor & self) const {
+int64_t VariableHooks::output_nr(const at::TensorBase & self) const {
   if (impl::get_autograd_meta(self)) {
     return impl::get_autograd_meta(self)->output_nr_;
   } else {
@@ -386,7 +387,12 @@ int64_t VariableHooks::output_nr(const Tensor & self) const {
   }
 }
 
-void VariableHooks::set_data(const Tensor & self, const Tensor & new_data) const {
+void VariableHooks::set_data(const at::TensorBase & self_base, const at::TensorBase & new_data_base) const {
+  at::OptionalTensorRef self_ref(self_base);
+  const Tensor &self = *self_ref;
+  at::OptionalTensorRef new_data_ref(new_data_base);
+  const Tensor &new_data = *new_data_ref;
+
   // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
   // from `new_data` to `var`. It requires that `new_data` and `var` have compatible
   // tensor type.
@@ -420,15 +426,15 @@ void VariableHooks::set_data(const Tensor & self, const Tensor & new_data) const
   self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr());
 }
 
-Tensor VariableHooks::data(const Tensor & self) const {
+at::TensorBase VariableHooks::data(const at::TensorBase & self) const {
   return self.variable_data();
 }
 
-int64_t VariableHooks::_version(const Tensor & self) const {
+int64_t VariableHooks::_version(const at::TensorBase & self) const {
   return self.unsafeGetTensorImpl()->version_counter().current_version();
 }
 
-void VariableHooks::retain_grad(const Tensor& self) const {
+void VariableHooks::retain_grad(const at::TensorBase& self) const {
   TORCH_CHECK(self.requires_grad(), "can't retain_grad on Tensor that has requires_grad=False");
   if (self.is_leaf()) {  // no-op for leaves
     return;
@@ -438,7 +444,7 @@ void VariableHooks::retain_grad(const Tensor& self) const {
   }
   c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());
 
-  std::function<void(Tensor)> retain_grad_hook([weak_self](const Tensor& grad) {
+  auto retain_grad_hook = [weak_self](const at::Tensor& grad) {
     if (weak_self.expired()) {
       return;
     } else {
@@ -453,13 +459,13 @@ void VariableHooks::retain_grad(const Tensor& self) const {
         var->mutable_grad() = var->grad() + grad;
       }
     }
-  });
+  };
 
-  self.register_hook(retain_grad_hook);
+  at::OptionalTensorRef(self)->register_hook(retain_grad_hook);
   impl::get_autograd_meta(self)->retains_grad_ = true;
 }
 
-bool VariableHooks::retains_grad(const Tensor& self) const {
+bool VariableHooks::retains_grad(const at::TensorBase& self) const {
   if (impl::get_autograd_meta(self)) {
     return impl::get_autograd_meta(self)->retains_grad_;
   } else {
@@ -480,7 +486,7 @@ void VariableHooks::_backward(
   torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars);
 }
 
-void VariableHooks::requires_grad_(const Tensor& self, bool _requires_grad) const {
+void VariableHooks::requires_grad_(const at::TensorBase& self, bool _requires_grad) const {
   if (!self.is_leaf() && !_requires_grad) {
     throw std::runtime_error(
       autograd::utils::requires_grad_leaf_error(_requires_grad)
@@ -492,7 +498,7 @@ void VariableHooks::requires_grad_(const Tensor& self, bool _requires_grad) cons
 // Backward View Variables
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-bool VariableHooks::is_view(const Tensor& self) const {
+bool VariableHooks::is_view(const at::TensorBase& self) const {
   auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
   if (diff_view_meta) {
     return diff_view_meta->has_bw_view();
@@ -501,7 +507,7 @@ bool VariableHooks::is_view(const Tensor& self) const {
   }
 }
 
-const Tensor& VariableHooks::base(const Tensor& self) const {
+const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const {
   auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
   if (diff_view_meta) {
     TORCH_CHECK(diff_view_meta->has_bw_view(), "Can't get base of non-backward view Tensor");
@@ -515,7 +521,7 @@ namespace {
   std::string singleton_string;
 }
 
-const std::string& VariableHooks::name(const Tensor& self) const {
+const std::string& VariableHooks::name(const at::TensorBase& self) const {
   TORCH_CHECK(self.defined(), "cannot call variable_data() on undefined tensor");
   if (torch::autograd::impl::get_autograd_meta(self)) {
     return torch::autograd::impl::get_autograd_meta(self)->name_;
@@ -528,7 +534,7 @@ namespace {
   std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
 }
 
-const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tensor& self) const {
+const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const at::TensorBase& self) const {
   auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
   if (diff_view_meta && diff_view_meta->has_bw_view()) {
     // See NOTE [ View + Inplace detection ]
@@ -596,14 +602,15 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(const Tenso
   }
 }
 
-void VariableHooks::remove_hook(const Tensor& self, unsigned pos) const {
+void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos) const {
   auto &list = torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_;
   TORCH_CHECK(list && pos < list->size() , "Invalid index, no hook at position ", pos);
   // Hook will be ignored
   (*list)[pos] = nullptr;
 }
 
-unsigned VariableHooks::_register_hook(const Tensor& self, std::function<Tensor(const Tensor&)> hook) const {
+unsigned VariableHooks::_register_hook(
+    const at::TensorBase& self, std::function<at::TensorBase(const at::TensorBase&)> hook) const {
   TORCH_CHECK(self.requires_grad(), "cannot register a hook on a variable that "
                            "doesn't require gradient");
   // NB: materialize_autograd_meta unnecessary due to requires grad check
index 5f0846a..3a81cfb 100644 (file)
@@ -107,15 +107,15 @@ namespace impl {
 
   // WARNING: This may return a nullptr.  If you require AutogradMeta to return
   // a materialized structure, use materialize_autograd_meta instead.
-  TORCH_API AutogradMeta* get_autograd_meta(const Variable&);
+  TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);
 
   // WARNING: This will return a nullptr if the Tensor is not a view.
-  TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const Variable&);
+  TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);
 
   // Returns the current autograd meta, materializing it if it was previously
   // none.  This counts as a *mutating* operation, so do not call it on
   // "read-only" operators; in particular, this is NOT thread safe
-  TORCH_API AutogradMeta* materialize_autograd_meta(const Variable&);
+  TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);
 
   /// Set the gradient accumulator of the `Variable`. This is only applicable to
   /// leaf variables. Interior variables should call `set_gradient_edge()`.
@@ -171,11 +171,11 @@ namespace impl {
 
   TORCH_API void set_name(const Variable&, const std::string& name);
 
-  TORCH_API void add_hook(const Variable&, std::shared_ptr<FunctionPreHook> hook);
+  TORCH_API void add_hook(const at::TensorBase&, std::shared_ptr<FunctionPreHook> hook);
   TORCH_API const std::vector<std::shared_ptr<FunctionPreHook>>& hooks(const Variable&);
-  TORCH_API void clear_hooks(const Variable&);
+  TORCH_API void clear_hooks(const at::TensorBase&);
 
-  TORCH_API void create_cpp_hook(const Variable&);
+  TORCH_API void create_cpp_hook(const at::TensorBase&);
 }
 
 //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -252,9 +252,9 @@ struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
     return grad_;
   }
 
-  const Variable& fw_grad(uint64_t level, const Variable& self) const override;
+  const Variable& fw_grad(uint64_t level, const at::TensorBase& self) const override;
 
-  void set_fw_grad(const Variable& new_grad, const Variable& self, uint64_t level, bool is_inplace_op) override;
+  void set_fw_grad(const at::TensorBase& new_grad, const at::TensorBase& self, uint64_t level, bool is_inplace_op) override;
 
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   AutogradMeta(at::TensorImpl* self_impl = nullptr, bool requires_grad = false, Edge gradient_edge = Edge() ) {