From 17f05ad5e562830127dd06be6c11c10453a86d0e Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Thu, 18 Apr 2019 14:07:30 -0700 Subject: [PATCH] Moving at::Tensor into caffe2::Tensor without bumping refcount (#19388) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19388 The old implementation forced a refcount bump when converting at::Tensor to caffe2::Tensor. Now, it is possible to move it without a refcount bump. Reviewed By: dzhulgakov Differential Revision: D14986815 fbshipit-source-id: 92b4b0a6f323ed38376ffad75f960cad250ecd9b --- aten/src/ATen/core/Tensor.h | 5 +++++ aten/src/ATen/templates/Tensor.h | 5 +++++ caffe2/core/tensor.h | 4 ++-- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index e6552f2..5fbdb4e 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -14,6 +14,9 @@ #include #include +namespace caffe2 { +class Tensor; +} namespace c10{ struct TensorOptions; } @@ -761,6 +764,8 @@ class CAFFE2_API Tensor { friend struct WeakTensor; protected: + friend class ::caffe2::Tensor; + void enforce_invariants(); c10::intrusive_ptr impl_; }; diff --git a/aten/src/ATen/templates/Tensor.h b/aten/src/ATen/templates/Tensor.h index 3944676..f9d9da0 100644 --- a/aten/src/ATen/templates/Tensor.h +++ b/aten/src/ATen/templates/Tensor.h @@ -14,6 +14,9 @@ #include #include +namespace caffe2 { +class Tensor; +} namespace c10{ struct TensorOptions; } @@ -348,6 +351,8 @@ class CAFFE2_API Tensor { friend struct WeakTensor; protected: + friend class ::caffe2::Tensor; + void enforce_invariants(); c10::intrusive_ptr impl_; }; diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 4846d9a..d6a5030 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -115,8 +115,8 @@ class CAFFE2_API Tensor final { * The tensor will share the same instance (data, strides, sizes, etc) but * a different subset of APIs would be available */ - explicit Tensor(const at::Tensor& tensor) - : impl_(std::move(tensor.getIntrusivePtr())) { + explicit Tensor(at::Tensor tensor) + : impl_(std::move(tensor.impl_)) { enforce_invariants(); } -- 2.7.4