From 8ac55a6812884d76d6116aa72aa7beb4a6bda832 Mon Sep 17 00:00:00 2001 From: Sebastian Messmer Date: Tue, 8 Jan 2019 20:22:41 -0800 Subject: [PATCH] Convert caffe2/aten Tensors to/from c10 Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/14820 Reviewed By: dzhulgakov Differential Revision: D13348044 fbshipit-source-id: 95008e6ead3cfc478696b1c203769241d4cf6ca8 --- aten/src/ATen/core/Tensor.h | 12 ++++++++++++ aten/src/ATen/templates/Tensor.h | 12 ++++++++++++ caffe2/core/tensor.h | 12 ++++++++++++ 3 files changed, 36 insertions(+) diff --git a/aten/src/ATen/core/Tensor.h b/aten/src/ATen/core/Tensor.h index 936a881..f23c54c 100644 --- a/aten/src/ATen/core/Tensor.h +++ b/aten/src/ATen/core/Tensor.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace c10{ @@ -57,6 +58,17 @@ public: Tensor(const Tensor&) = default; Tensor(Tensor&&) = default; + explicit Tensor(C10Tensor tensor) + : impl_(std::move(tensor).impl()) {} + + explicit operator C10Tensor() const & { + return C10Tensor(impl_); + } + + explicit operator C10Tensor() && { + return C10Tensor(std::move(impl_)); + } + int64_t dim() const { return impl_->dim(); } diff --git a/aten/src/ATen/templates/Tensor.h b/aten/src/ATen/templates/Tensor.h index a9bf3c8..23618b5 100644 --- a/aten/src/ATen/templates/Tensor.h +++ b/aten/src/ATen/templates/Tensor.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace c10{ @@ -57,6 +58,17 @@ public: Tensor(const Tensor&) = default; Tensor(Tensor&&) = default; + explicit Tensor(C10Tensor tensor) + : impl_(std::move(tensor).impl()) {} + + explicit operator C10Tensor() const & { + return C10Tensor(impl_); + } + + explicit operator C10Tensor() && { + return C10Tensor(std::move(impl_)); + } + int64_t dim() const { return impl_->dim(); } diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index dbd3851..acd14b2 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -8,6 +8,7 @@ #include #include "ATen/core/Tensor.h" #include +#include namespace caffe2 { @@ -96,6 +97,17 @@ class CAFFE2_API Tensor final { CopyFrom(src); } + explicit Tensor(C10Tensor tensor) + : impl_(std::move(tensor).impl()) {} + + explicit operator C10Tensor() const & { + return C10Tensor(impl_); + } + + explicit operator C10Tensor() && { + return C10Tensor(std::move(impl_)); + } + Tensor Clone() const { Tensor x(GetDevice()); x.CopyFrom(*this); -- 2.7.4