From ede1f4ad05bbf6c7d483ea105d2d8bc451ead437 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 8 Jan 2019 10:55:26 -0800
Subject: [PATCH] Remove caffe2::ShareData (#15418)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15418

Previously we were using Resize + ShareData. Instead, we now provide a function
on Tensor that clones itself with the same storage. Suppose we want `t` to
share data with `t0`.

Before:
```
Tensor t(dims, CPU);
t.Resize(t0.sizes());
t.ShareData(t0);
```

Now:
```
Tensor t = t0.Alias();
```

Reviewed By: dzhulgakov

Differential Revision: D13507609

fbshipit-source-id: 6e4275d02f4c3356cbce91127f1b01111dc86b9f
---
 aten/src/ATen/test/tensor_interop_test.cpp       |  8 ++---
 c10/core/TensorImpl.h                            | 41 ------------------------
 caffe2/core/blob_gpu_test.cc                     | 16 ++++-----
 caffe2/core/blob_test.cc                         | 24 ++++++--------
 caffe2/core/operator.h                           |  9 +++++-
 caffe2/core/tensor.h                             | 35 +++++++++++++++++---
 caffe2/ideep/operators/operator_fallback_ideep.h |  4 +--
 caffe2/operators/softmax_ops.cu                  | 34 +++++++++++---------
 caffe2/operators/string_ops_test.cc              |  6 ++--
 caffe2/operators/utility_ops.h                   |  3 +-
 caffe2/predictor/predictor_test.cc               |  8 ++---
 11 files changed, 83 insertions(+), 105 deletions(-)

diff --git a/aten/src/ATen/test/tensor_interop_test.cpp b/aten/src/ATen/test/tensor_interop_test.cpp
index ec3886b..f926312 100644
--- a/aten/src/ATen/test/tensor_interop_test.cpp
+++ b/aten/src/ATen/test/tensor_interop_test.cpp
@@ -66,11 +66,11 @@ TEST(TestTensorInterop, PytorchToCaffe2Op) {
   auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
   auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());

-  // Test ShareData as well
+  // Test Alias
   {
-    auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU);
-    c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr());
-    c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr());
+    caffe2::Tensor c2_tensor_from_aten(at_tensor_c.getIntrusivePtr());
+    BlobSetTensor(workspace.CreateBlob("c"), c2_tensor_from_aten.Alias());
+
   }

   {
diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h
index 5d8a3a0..7e7ea61 100644
--- a/c10/core/TensorImpl.h
+++ b/c10/core/TensorImpl.h
@@ -1068,47 +1068,6 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     storage_offset_ = 0;
   }

-  /**
-   * @brief Shares the data with another tensor.
-   *
-   * To share data between two tensors, the sizes of the two tensors must be
-   * equal already. The reason we do not implicitly do a Resize to make the two
-   * tensors have the same shape is that we want to allow tensors of different
-   * shapes but the same number of items to still be able to share data. This
-   * allows one to e.g. have a n-dimensional Tensor and a flattened version
-   * sharing the same underlying storage.
-   *
-   * The source tensor should already have its data allocated.
-   */
-  void ShareData(const TensorImpl& src) {
-    // Right now, we are assuming the device_type are the same, since it is
-    // inherently the same in the non-templatized code. We should probably add
-    // an assert here which might affect perf a little bit.
-    AT_ASSERTM(
-        src.numel_ == numel_,
-        "Size mismatch - did you call reshape before sharing the data?");
-    // It is possible that the source tensor hasn't called mutable_data() yet,
-    // in which case ShareData() doesn't make much sense since we don't really
-    // know what to share yet.
-    // TODO: Add the assert after all uninitialized states are eliminated
-    // AT_ASSERTM(src.dtype_initialized(),
-    //            "Source tensor don't have a data type (did you call mutable_data on the tensor?)");
-    if (!src.dtype_initialized()) {
-      C10_LOG_EVERY_MS(WARNING, 1000) <<
-          "Source tensor don't have a data type (did you call mutable_data on the tensor?)";
-    }
-    AT_ASSERTM(
-        src.storage_initialized(),
-        "Source tensor has no content and has size > 0");
-    // Finally, do sharing.
-    /* Since we create new Storage whenever we need to change data_type/capacity
-     * this still keeps the original semantics
-     */
-    storage_ = src.storage();
-    data_type_ = src.dtype();
-    storage_offset_ = src.storage_offset();
-  }
-
   void ShareExternalPointer(
       DataPtr&& data_ptr,
       const caffe2::TypeMeta& data_type,
diff --git a/caffe2/core/blob_gpu_test.cc b/caffe2/core/blob_gpu_test.cc
index 07e5638..329b667 100644
--- a/caffe2/core/blob_gpu_test.cc
+++ b/caffe2/core/blob_gpu_test.cc
@@ -61,22 +61,21 @@ TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) {
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
 }

-TYPED_TEST(TensorGPUTest, TensorShareData) {
+TYPED_TEST(TensorGPUTest, TensorAlias) {
   if (!HasCudaGPU()) return;
   vector<int64_t> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CUDA);
-  Tensor other_tensor(dims, CUDA);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
   EXPECT_TRUE(other_tensor.data<TypeParam>() != nullptr);
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
 }

-TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
+TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) {
   if (!HasCudaGPU()) return;
   vector<int64_t> dims(3);
   dims[0] = 2;
@@ -85,9 +84,9 @@ TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
   vector<int64_t> alternate_dims(1);
   alternate_dims[0] = 2 * 3 * 5;
   Tensor tensor(dims, CUDA);
-  Tensor other_tensor(alternate_dims, CUDA);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
+  other_tensor.Resize(alternate_dims);
   EXPECT_EQ(other_tensor.dim(), 1);
   EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]);
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
@@ -95,16 +94,15 @@ TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
 }

-TYPED_TEST(TensorGPUTest, NoLongerSharesAfterResize) {
+TYPED_TEST(TensorGPUTest, NoLongerAliasAfterNumelChanges) {
   if (!HasCudaGPU()) return;
   vector<int64_t> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CUDA);
-  Tensor other_tensor(dims, CUDA);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
   auto* old_pointer = other_tensor.data<TypeParam>();

diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc
index 290e310..3dd8292 100644
--- a/caffe2/core/blob_test.cc
+++ b/caffe2/core/blob_test.cc
@@ -212,8 +212,7 @@ TEST(TensorNonTypedTest, TensorChangeType) {

   // share the data with other tensor so that the pointer won't be reused
   // when we reallocate
-  Tensor other_tensor(dims, CPU);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
   // but double is bigger, so it should allocate a new one
   auto* doubleptr = tensor.mutable_data<double>();
   EXPECT_TRUE(doubleptr != (double*)ptr);
@@ -337,15 +336,14 @@ TYPED_TEST(TensorCPUTest, TensorInitializedScalar) {
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
 }

-TYPED_TEST(TensorCPUTest, TensorShareData) {
+TYPED_TEST(TensorCPUTest, TensorAlias) {
   vector<int64_t> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CPU);
-  Tensor other_tensor(dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
   EXPECT_TRUE(other_tensor.data<TypeParam>() != nullptr);
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
@@ -391,7 +389,7 @@ TYPED_TEST(TensorCPUTest, TensorShareDataRawPointerWithMeta) {
   }
 }

-TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
+TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) {
   vector<int64_t> dims(3);
   dims[0] = 2;
   dims[1] = 3;
@@ -399,9 +397,9 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   vector<int64_t> alternate_dims(1);
   alternate_dims[0] = 2 * 3 * 5;
   Tensor tensor(dims, CPU);
-  Tensor other_tensor(alternate_dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
+  other_tensor.Resize(alternate_dims);
   EXPECT_EQ(other_tensor.dim(), 1);
   EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]);
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
@@ -415,15 +413,14 @@ TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
 }

-TYPED_TEST(TensorCPUTest, NoLongerSharesAfterResize) {
+TYPED_TEST(TensorCPUTest, NoLongerAliasAfterNumelChanges) {
   vector<int64_t> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CPU);
-  Tensor other_tensor(dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
   auto* old_pointer = other_tensor.data<TypeParam>();

@@ -433,15 +430,14 @@ TYPED_TEST(TensorCPUTest, NoLongerSharesAfterResize) {
   EXPECT_NE(old_pointer, tensor.mutable_data<TypeParam>());
 }

-TYPED_TEST(TensorCPUTest, NoLongerSharesAfterFreeMemory) {
+TYPED_TEST(TensorCPUTest, NoLongerAliasAfterFreeMemory) {
   vector<int64_t> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CPU);
-  Tensor other_tensor(dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  other_tensor.ShareData(tensor);
+  Tensor other_tensor = tensor.Alias();
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
   auto* old_pointer = other_tensor.data<TypeParam>();

diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h
index 5e1a9ab..da3df5a 100644
--- a/caffe2/core/operator.h
+++ b/caffe2/core/operator.h
@@ -223,6 +223,12 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
     return t;
   }

+  Tensor* OutputTensorAlias(int idx, const Tensor& src) {
+    return BlobSetTensor(OutputBlob(idx),
+        src.Alias());
+  }
+
+
   template <typename T>
   inline T* Output(int idx, T* allocated) {
     outputs_.at(idx)->Reset(allocated);
@@ -788,7 +794,8 @@ class Operator : public OperatorBase {
   /* using override */ using OperatorBase::Output;             \
   /* using override */ using OperatorBase::Input;              \
   /* using override */ using OperatorBase::OutputSize;         \
-  /* using override */ using OperatorBase::IsInputOutputAlias
+  /* using override */ using OperatorBase::IsInputOutputAlias; \
+  /* using override */ using OperatorBase::OutputTensorAlias

 #define USE_OPERATOR_FUNCTIONS(context) \
   USE_OPERATOR_BASE_FUNCTIONS;          \
diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h
index feacc64..9f24e2d 100644
--- a/caffe2/core/tensor.h
+++ b/caffe2/core/tensor.h
@@ -102,6 +102,35 @@ class CAFFE2_API Tensor final {
     return x;
   }

+  /**
+   * Clone self as a Tensor that shares the same Storage,
+   * that is, both Tensors are views on the same Storage.
+   * If we change the sizes or strides of one Tensor, it
+   * does not affect the other Tensor that it shares Storage
+   * with.
+   * A similar yet different usage is `Tensor x = y;`, which
+   * makes x and y point to the same Tensor, so resizing
+   * one of them will resize the other as well.
+   *
+   * TODO: Deduplicate this with THTensor_(newWithTensor)
+   * (exposed in ATen as at::alias but not otherwise available)
+   */
+  Tensor Alias() const {
+    Tensor x(sizes(), GetDevice());
+    if (!dtype_initialized()) {
+      C10_LOG_EVERY_MS(WARNING, 1000) <<
+          "Cloning a tensor that doesn't have a data type (did you call mutable_data on the tensor?)";
+    }
+    AT_ASSERTM(
+        storage_initialized(),
+        "Cloning a tensor that has no content and has size > 0");
+    // set_storage already sets data_type_ of TensorImpl
+    x.impl_->set_storage(storage());
+    x.impl_->set_storage_offset(impl_->storage_offset());
+    x.impl_->set_sizes_and_strides(sizes(), strides());
+    return x;
+  }
+
   DeviceType GetDeviceType() const {
     return impl_->device_type();
   }
@@ -279,10 +308,6 @@ class CAFFE2_API Tensor final {
     std::swap(*impl_.get(), *other.impl_.get());
   }

-  void ShareData(const Tensor& src) const {
-    impl_.get()->ShareData(*src.impl_.get());
-  }
-
   /**
    * @brief Shares the data with an externally managed pointer.
    *
@@ -458,7 +483,7 @@ class CAFFE2_API Tensor final {
     return impl_.get()->stride(dim);
   }

-  inline at::IntList strides() {
+  inline at::IntList strides() const {
     return impl_.get()->strides();
   }

diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h
index 2d31489..118f92d 100644
--- a/caffe2/ideep/operators/operator_fallback_ideep.h
+++ b/caffe2/ideep/operators/operator_fallback_ideep.h
@@ -151,9 +151,7 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
         VLOG(2) << "Output " << base_def_.output(i) << " as CPUTensor";
         Blob* dst = OperatorBase::OutputBlob(i);
         dst->Reset(new Tensor(CPU));
-        auto dtensor = BlobGetMutableTensor(dst, CPU);
-        dtensor->Resize(src_dims);
-        dtensor->ShareData(src);
+        BlobSetTensor(dst, src.Alias());
       }
     }
     return true;
diff --git a/caffe2/operators/softmax_ops.cu b/caffe2/operators/softmax_ops.cu
index 81955f9..be8b455 100644
--- a/caffe2/operators/softmax_ops.cu
+++ b/caffe2/operators/softmax_ops.cu
@@ -483,20 +483,21 @@ bool SoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
   auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss
   const float* weights = (InputSize() > 4 ? Input(2).data<float>() : NULL);

-  auto* dX = Output(0);
-  dX->ResizeLike(X);
+  Tensor* dX;
+  if (only_loss_) {
+    // Memory saving trick to share the buffer with the softmax output.
+    // Softmax output is thus overwritten.
+    dX = OutputTensorAlias(0, P);
+    dX->ResizeLike(X);
+  } else {
+    dX = Output(0, X.sizes(), at::dtype<float>());
+  }

   const auto canonical_axis = X.canonical_axis_index(axis_);
   int N, D;
   N = X.size_to_dim(canonical_axis); // batch size
   D = X.size_from_dim(canonical_axis);

-  if (only_loss_) {
-    // Memory saving trick to share the buffer with the softmax output.
-    // Softmax output is thus overwritten.
-    dX->ShareData(P);
-  }
-
   total_weight_ptr_.Resize(1);

   if (label_prob_mode_) {
@@ -598,20 +599,21 @@ bool SpatialSoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
   auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss
   const float* weights = (InputSize() > 4 ? Input(2).data<float>() : NULL);

-  auto* dX = Output(0);
-  dX->ResizeLike(X);
+  Tensor* dX;
+  if (only_loss_) {
+    // Memory saving trick to share the buffer with the softmax output.
+    // Softmax output is thus overwritten.
+    dX = OutputTensorAlias(0, P);
+    dX->ResizeLike(X);
+  } else {
+    dX = Output(0, X.sizes(), at::dtype<float>());
+  }

   const auto canonical_axis = X.canonical_axis_index(1);
   int N, D;
   N = X.dim32(0);
   D = X.dim32(1);

-  if (only_loss_) {
-    // Memory saving trick to share the buffer with the softmax output.
-    // Softmax output is thus overwritten.
-    dX->ShareData(P);
-  }
-
   total_weight_ptr_.Resize(1);

   // Spatial mode, compute softmax for each x, y location
   CAFFE_ENFORCE_EQ(X.ndim(), 4);
diff --git a/caffe2/operators/string_ops_test.cc b/caffe2/operators/string_ops_test.cc
index f325e72..0856229 100644
--- a/caffe2/operators/string_ops_test.cc
+++ b/caffe2/operators/string_ops_test.cc
@@ -7,11 +7,9 @@ namespace caffe2 {

 class StringJoinOpTest : public testing::Test {
  public:
-  bool runOp(const TensorCPU& input) {
+  bool runOp(const Tensor& input) {
     auto* blob = ws_.CreateBlob("X");
-    auto* tensor = BlobGetMutableTensor(blob, CPU);
-    tensor->ResizeLike(input);
-    tensor->ShareData(input);
+    BlobSetTensor(blob, input.Alias());

     OperatorDef def;
     def.set_name("test");
diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h
index 0bc30e1..9097bbb 100644
--- a/caffe2/operators/utility_ops.h
+++ b/caffe2/operators/utility_ops.h
@@ -169,8 +169,7 @@ class AliasOp final : public Operator<Context> {
   bool RunOnDevice() override {
     auto& input = Input(0);
     CAFFE_ENFORCE_GE(input.numel(), 0, "Tensor is not initialized");
-    Output(0)->ResizeLike(input);
-    Output(0)->ShareData(input);
+    OutputTensorAlias(0, input);
     return true;
   }
 };
diff --git a/caffe2/predictor/predictor_test.cc b/caffe2/predictor/predictor_test.cc
index 0d3bc33..da97d59 100644
--- a/caffe2/predictor/predictor_test.cc
+++ b/caffe2/predictor/predictor_test.cc
@@ -179,10 +179,8 @@ class PredictorTest : public testing::Test {
 TEST_F(PredictorTest, SimpleBatchSized) {
   auto inputData = randomTensor({1, 4}, ctx_.get());
   Predictor::TensorList input;
-  input.emplace_back(CPU);
   auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
-  input.back().ResizeLike(*tensor);
-  input.back().ShareData(*tensor);
+  input.emplace_back(tensor->Alias());
   Predictor::TensorList output;
   (*p_)(input, &output);
   EXPECT_EQ(output.size(), 1);
@@ -195,10 +193,8 @@ TEST_F(PredictorTest, SimpleBatchSizedMapInput) {
   auto inputData = randomTensor({1, 4}, ctx_.get());
   Predictor::TensorMap input;
-  auto iter = input.emplace("data", Tensor(CPU));
   auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
-  iter.first->second.ResizeLike(*tensor);
-  iter.first->second.ShareData(*tensor);
+  input.emplace("data", tensor->Alias());
   Predictor::TensorList output;
   (*p_)(input, &output);

-- 
2.7.4
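
Usage note (illustrative, not part of the patch): the sketch below shows the aliasing
semantics described in the summary and in the new `Tensor::Alias()` doc comment. It
assumes a caffe2 build that includes this change; the constructor and accessor calls
mirror the updated `blob_test.cc` cases, and `main`, the dimensions, and the `float`
element type are arbitrary choices for the example.

```
// Minimal sketch of the new aliasing semantics; assumes caffe2 with this patch.
#include "caffe2/core/tensor.h"

#include <iostream>
#include <vector>

int main() {
  using caffe2::Tensor;
  using caffe2::CPU;

  std::vector<int64_t> dims{2, 3, 5};
  Tensor t0(dims, CPU);
  t0.mutable_data<float>();  // allocate storage; Alias() expects initialized storage

  // Old pattern removed by this patch:
  //   Tensor t(dims, CPU);
  //   t.Resize(t0.sizes());
  //   t.ShareData(t0);
  Tensor t = t0.Alias();  // new pattern: same Storage, independent sizes/strides

  std::cout << (t.data<float>() == t0.data<float>()) << std::endl;  // 1: shared buffer

  // Resizing the alias (same numel) leaves t0's shape untouched, unlike
  // `Tensor t = t0;`, where both handles refer to the same TensorImpl.
  t.Resize(30);
  std::cout << t0.dim() << " " << t.dim() << std::endl;  // 3 1
  return 0;
}
```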
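Operator-side note (illustrative, not part of the patch): inside operators, the new
`OutputTensorAlias(idx, src)` helper replaces the removed `ResizeLike` + `ShareData`
pair, as in the updated `AliasOp` and the `only_loss_` branch of the softmax gradient
kernels above. A hypothetical operator using it might look like the sketch below; the
class name is made up, and the constructor and macro usage follow the conventions
visible elsewhere in caffe2.

```
// Hypothetical operator sketch; assumes the caffe2 headers with this patch applied.
#include "caffe2/core/operator.h"

namespace caffe2 {

template <class Context>
class PassThroughAliasOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  PassThroughAliasOp(const OperatorDef& def, Workspace* ws)
      : Operator<Context>(def, ws) {}

  bool RunOnDevice() override {
    const auto& input = Input(0);
    CAFFE_ENFORCE_GE(input.numel(), 0, "Tensor is not initialized");
    // One call replaces the old Output(0)->ResizeLike(input) + ShareData(input):
    // output blob 0 now holds an alias that shares input's Storage.
    OutputTensorAlias(0, input);
    return true;
  }
};

} // namespace caffe2
```

The softmax gradient changes use the same helper so that, when only the loss is
requested, dX overwrites the softmax output P instead of allocating a second buffer.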
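Blob-level note (illustrative, not part of the patch): in the tests and predictor code,
`BlobGetMutableTensor` + `ResizeLike` + `ShareData` collapses into a single
`BlobSetTensor(blob, t.Alias())`, as in `string_ops_test.cc` and
`tensor_interop_test.cpp` above. A minimal sketch follows, assuming the caffe2
Workspace/Blob API and header paths used in those tests; the helper function name is
made up for the example.

```
// Sketch of the blob-level replacement pattern; assumes caffe2 with this patch.
#include "caffe2/core/blob.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/workspace.h"

namespace caffe2 {

// Hypothetical helper: publish an alias of `src` into the workspace blob `name`.
Tensor* PublishAlias(Workspace* ws, const std::string& name, const Tensor& src) {
  // Previously: BlobGetMutableTensor(blob, CPU) + ResizeLike(src) + ShareData(src).
  Blob* blob = ws->CreateBlob(name);
  return BlobSetTensor(blob, src.Alias());
}

} // namespace caffe2
```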