From 6371bc76a9576df7761d6c1058b6313b9239f3f8 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 12 Jan 2019 07:04:49 -0800 Subject: [PATCH] Back out "[pt1][tensor] Remove caffe2::ShareData" (#15983) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15983 Original commit changeset: 6e4275d02f4c Reviewed By: supertopher, Yangqing Differential Revision: D13644123 fbshipit-source-id: 4b15a4c62995c0e68aad58465600409e302e6504 --- aten/src/ATen/test/tensor_interop_test.cpp | 8 ++--- c10/core/TensorImpl.h | 41 ++++++++++++++++++++++++ caffe2/core/blob_gpu_test.cc | 16 +++++---- caffe2/core/blob_test.cc | 24 ++++++++------ caffe2/core/operator.h | 9 +----- caffe2/core/tensor.h | 35 +++----------------- caffe2/ideep/operators/operator_fallback_ideep.h | 4 ++- caffe2/operators/softmax_ops.cu | 34 +++++++++----------- caffe2/operators/string_ops_test.cc | 6 ++-- caffe2/operators/utility_ops.h | 3 +- caffe2/predictor/predictor_test.cc | 8 +++-- 11 files changed, 105 insertions(+), 83 deletions(-) diff --git a/aten/src/ATen/test/tensor_interop_test.cpp b/aten/src/ATen/test/tensor_interop_test.cpp index f926312..ec3886b 100644 --- a/aten/src/ATen/test/tensor_interop_test.cpp +++ b/aten/src/ATen/test/tensor_interop_test.cpp @@ -66,11 +66,11 @@ TEST(TestTensorInterop, PytorchToCaffe2Op) { auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr()); auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr()); - // Test Alias + // Test ShareData as well { - caffe2::Tensor c2_tensor_from_aten(at_tensor_c.getIntrusivePtr()); - BlobSetTensor(workspace.CreateBlob("c"), c2_tensor_from_aten.Alias()); - + auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU); + c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr()); + c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr()); } { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 7e7ea61..5d8a3a0 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1068,6 +1068,47 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { storage_offset_ = 0; } + /** + * @brief Shares the data with another tensor. + * + * To share data between two tensors, the sizes of the two tensors must be + * equal already. The reason we do not implicitly do a Resize to make the two + * tensors have the same shape is that we want to allow tensors of different + * shapes but the same number of items to still be able to share data. This + * allows one to e.g. have a n-dimensional Tensor and a flattened version + * sharing the same underlying storage. + * + * The source tensor should already have its data allocated. + */ + void ShareData(const TensorImpl& src) { + // Right now, we are assuming the device_type are the same, since it is + // inherently the same in the non-templatized code. We should probably add + // an assert here which might affect perf a little bit. + AT_ASSERTM( + src.numel_ == numel_, + "Size mismatch - did you call reshape before sharing the data?"); + // It is possible that the source tensor hasn't called mutable_data() yet, + // in which case ShareData() doesn't make much sense since we don't really + // know what to share yet. + // TODO: Add the assert after all uninitialized states are eliminated + // AT_ASSERTM(src.dtype_initialized(), + // "Source tensor don't have a data type (did you call mutable_data on the tensor?)"); + if (!src.dtype_initialized()) { + C10_LOG_EVERY_MS(WARNING, 1000) << + "Source tensor don't have a data type (did you call mutable_data on the tensor?)"; + } + AT_ASSERTM( + src.storage_initialized(), + "Source tensor has no content and has size > 0"); + // Finally, do sharing. + /* Since we create new Storage whenever we need to change data_type/capacity + * this still keeps the original semantics + */ + storage_ = src.storage(); + data_type_ = src.dtype(); + storage_offset_ = src.storage_offset(); + } + void ShareExternalPointer( DataPtr&& data_ptr, const caffe2::TypeMeta& data_type, diff --git a/caffe2/core/blob_gpu_test.cc b/caffe2/core/blob_gpu_test.cc index 329b667..07e5638 100644 --- a/caffe2/core/blob_gpu_test.cc +++ b/caffe2/core/blob_gpu_test.cc @@ -61,21 +61,22 @@ TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) { EXPECT_TRUE(tensor.data() != nullptr); } -TYPED_TEST(TensorGPUTest, TensorAlias) { +TYPED_TEST(TensorGPUTest, TensorShareData) { if (!HasCudaGPU()) return; vector dims(3); dims[0] = 2; dims[1] = 3; dims[2] = 5; Tensor tensor(dims, CUDA); + Tensor other_tensor(dims, CUDA); EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); + other_tensor.ShareData(tensor); EXPECT_TRUE(tensor.data() != nullptr); EXPECT_TRUE(other_tensor.data() != nullptr); EXPECT_EQ(tensor.data(), other_tensor.data()); } -TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) { +TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) { if (!HasCudaGPU()) return; vector dims(3); dims[0] = 2; @@ -84,9 +85,9 @@ TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) { vector alternate_dims(1); alternate_dims[0] = 2 * 3 * 5; Tensor tensor(dims, CUDA); + Tensor other_tensor(alternate_dims, CUDA); EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - other_tensor.Resize(alternate_dims); + other_tensor.ShareData(tensor); EXPECT_EQ(other_tensor.dim(), 1); EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]); EXPECT_TRUE(tensor.data() != nullptr); @@ -94,15 +95,16 @@ TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) { EXPECT_EQ(tensor.data(), other_tensor.data()); } -TYPED_TEST(TensorGPUTest, NoLongerAliasAfterNumelChanges) { +TYPED_TEST(TensorGPUTest, NoLongerSharesAfterResize) { if (!HasCudaGPU()) return; vector dims(3); dims[0] = 2; dims[1] = 3; dims[2] = 5; Tensor tensor(dims, CUDA); + Tensor other_tensor(dims, CUDA); EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); + other_tensor.ShareData(tensor); EXPECT_EQ(tensor.data(), other_tensor.data()); auto* old_pointer = other_tensor.data(); diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc index 3dd8292..290e310 100644 --- a/caffe2/core/blob_test.cc +++ b/caffe2/core/blob_test.cc @@ -212,7 +212,8 @@ TEST(TensorNonTypedTest, TensorChangeType) { // share the data with other tensor so that the pointer won't be reused // when we reallocate - Tensor other_tensor = tensor.Alias(); + Tensor other_tensor(dims, CPU); + other_tensor.ShareData(tensor); // but double is bigger, so it should allocate a new one auto* doubleptr = tensor.mutable_data(); EXPECT_TRUE(doubleptr != (double*)ptr); @@ -336,14 +337,15 @@ TYPED_TEST(TensorCPUTest, TensorInitializedScalar) { EXPECT_TRUE(tensor.data() != nullptr); } -TYPED_TEST(TensorCPUTest, TensorAlias) { +TYPED_TEST(TensorCPUTest, TensorShareData) { vector dims(3); dims[0] = 2; dims[1] = 3; dims[2] = 5; Tensor tensor(dims, CPU); + Tensor other_tensor(dims, CPU); EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); + other_tensor.ShareData(tensor); EXPECT_TRUE(tensor.data() != nullptr); EXPECT_TRUE(other_tensor.data() != nullptr); EXPECT_EQ(tensor.data(), other_tensor.data()); @@ -389,7 +391,7 @@ TYPED_TEST(TensorCPUTest, TensorShareDataRawPointerWithMeta) { } } -TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) { +TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) { vector dims(3); dims[0] = 2; dims[1] = 3; @@ -397,9 +399,9 @@ TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) { vector alternate_dims(1); alternate_dims[0] = 2 * 3 * 5; Tensor tensor(dims, CPU); + Tensor other_tensor(alternate_dims, CPU); EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); - other_tensor.Resize(alternate_dims); + other_tensor.ShareData(tensor); EXPECT_EQ(other_tensor.dim(), 1); EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]); EXPECT_TRUE(tensor.data() != nullptr); @@ -413,14 +415,15 @@ TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) { } -TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) { +TYPED_TEST(TensorCPUTest, NoLongerSharesAfterResize) { vector dims(3); dims[0] = 2; dims[1] = 3; dims[2] = 5; Tensor tensor(dims, CPU); + Tensor other_tensor(dims, CPU); EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); + other_tensor.ShareData(tensor); EXPECT_EQ(tensor.data(), other_tensor.data()); auto* old_pointer = other_tensor.data(); @@ -430,14 +433,15 @@ TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) { EXPECT_NE(old_pointer, tensor.mutable_data()); } -TYPED_TEST(TensorCPUTest, NoLongerAliasAfterFreeMemory) { +TYPED_TEST(TensorCPUTest, NoLongerSharesAfterFreeMemory) { vector dims(3); dims[0] = 2; dims[1] = 3; dims[2] = 5; Tensor tensor(dims, CPU); + Tensor other_tensor(dims, CPU); EXPECT_TRUE(tensor.mutable_data() != nullptr); - Tensor other_tensor = tensor.Alias(); + other_tensor.ShareData(tensor); EXPECT_EQ(tensor.data(), other_tensor.data()); auto* old_pointer = other_tensor.data(); diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h index da3df5a..5e1a9ab 100644 --- a/caffe2/core/operator.h +++ b/caffe2/core/operator.h @@ -223,12 +223,6 @@ class CAFFE2_API OperatorBase : public Observable { return t; } - Tensor* OutputTensorAlias(int idx, const Tensor& src) { - return BlobSetTensor(OutputBlob(idx), - src.Alias()); - } - - template inline T* Output(int idx, T* allocated) { outputs_.at(idx)->Reset(allocated); @@ -794,8 +788,7 @@ class Operator : public OperatorBase { /* using override */ using OperatorBase::Output; \ /* using override */ using OperatorBase::Input; \ /* using override */ using OperatorBase::OutputSize; \ - /* using override */ using OperatorBase::IsInputOutputAlias; \ - /* using override */ using OperatorBase::OutputTensorAlias + /* using override */ using OperatorBase::IsInputOutputAlias #define USE_OPERATOR_FUNCTIONS(context) \ USE_OPERATOR_BASE_FUNCTIONS; \ diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 627293f..7a02659 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -118,35 +118,6 @@ class CAFFE2_API Tensor final { return x; } - /** - * Clone self as a Tensor that share the same Storage, - * that is, both Tensors are views on the same Storage. - * If we change the sizes or strides of one Tensor, it - * does not affect the other Tensor that it shares Storage - * with. - * A similar yet different usage is `Tensor x = y;`, this - * will make x and y pointing to the same Tensor and resizing - * one of them will resize the other as well. - * - * TODO: Deduplicate this with THTensor_(newWithTensor) - * (exposed in ATen as at::alias but not otherwise available) - */ - Tensor Alias() const { - Tensor x(sizes(), GetDevice()); - if (!dtype_initialized()) { - C10_LOG_EVERY_MS(WARNING, 1000) << - "Cloning a tensor that don't have a data type (did you call mutable_data on the tensor?)"; - } - AT_ASSERTM( - storage_initialized(), - "Cloning a tensor that has no content and has size > 0"); - // set_storage already sets data_type_ of TensorImpl - x.impl_->set_storage(storage()); - x.impl_->set_storage_offset(impl_->storage_offset()); - x.impl_->set_sizes_and_strides(sizes(), strides()); - return x; - } - DeviceType GetDeviceType() const { return impl_->device_type(); } @@ -324,6 +295,10 @@ class CAFFE2_API Tensor final { std::swap(*impl_.get(), *other.impl_.get()); } + void ShareData(const Tensor& src) const { + impl_.get()->ShareData(*src.impl_.get()); + } + /** * @brief Shares the data with an externally managed pointer. * @@ -499,7 +474,7 @@ class CAFFE2_API Tensor final { return impl_.get()->stride(dim); } - inline at::IntList strides() const { + inline at::IntList strides() { return impl_.get()->strides(); } diff --git a/caffe2/ideep/operators/operator_fallback_ideep.h b/caffe2/ideep/operators/operator_fallback_ideep.h index 4372807..e834ed9 100644 --- a/caffe2/ideep/operators/operator_fallback_ideep.h +++ b/caffe2/ideep/operators/operator_fallback_ideep.h @@ -157,7 +157,9 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator { dtensor->CopyFrom(src); } else { dst->Reset(new Tensor(CPU)); - BlobSetTensor(dst, src.Alias()); + auto dtensor = BlobGetMutableTensor(dst, CPU); + dtensor->Resize(src_dims); + dtensor->ShareData(src); } } } diff --git a/caffe2/operators/softmax_ops.cu b/caffe2/operators/softmax_ops.cu index a58ebc8..0876ad8 100644 --- a/caffe2/operators/softmax_ops.cu +++ b/caffe2/operators/softmax_ops.cu @@ -487,21 +487,20 @@ bool SoftmaxWithLossGradientOp::RunOnDevice() { auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss const float* weights = (InputSize() > 4 ? Input(2).data() : NULL); - Tensor* dX; - if (only_loss_) { - // Memory saving trick to share the buffer with the softmax output. - // Softmax output is thus overwritten. - dX = OutputTensorAlias(0, P); - dX->ResizeLike(X); - } else { - dX = Output(0, X.sizes(), at::dtype()); - } + auto* dX = Output(0); + dX->ResizeLike(X); const auto canonical_axis = X.canonical_axis_index(axis_); int N, D; N = X.size_to_dim(canonical_axis); // batch size D = X.size_from_dim(canonical_axis); + if (only_loss_) { + // Memory saving trick to share the buffer with the softmax output. + // Softmax output is thus overwritten. + dX->ShareData(P); + } + ReinitializeTensor(&total_weight_ptr_, {1}, at::dtype().device(CUDA)); if (label_prob_mode_) { @@ -603,21 +602,20 @@ bool SpatialSoftmaxWithLossGradientOp::RunOnDevice() { auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss const float* weights = (InputSize() > 4 ? Input(2).data() : NULL); - Tensor* dX; - if (only_loss_) { - // Memory saving trick to share the buffer with the softmax output. - // Softmax output is thus overwritten. - dX = OutputTensorAlias(0, P); - dX->ResizeLike(X); - } else { - dX = Output(0, X.sizes(), at::dtype()); - } + auto* dX = Output(0); + dX->ResizeLike(X); const auto canonical_axis = X.canonical_axis_index(1); int N, D; N = X.dim32(0); D = X.dim32(1); + if (only_loss_) { + // Memory saving trick to share the buffer with the softmax output. + // Softmax output is thus overwritten. + dX->ShareData(P); + } + ReinitializeTensor(&total_weight_ptr_, {1}, at::dtype().device(CUDA)); // Spatial mode, compute softmax for each x, y location CAFFE_ENFORCE_EQ(X.ndim(), 4); diff --git a/caffe2/operators/string_ops_test.cc b/caffe2/operators/string_ops_test.cc index 0856229..f325e72 100644 --- a/caffe2/operators/string_ops_test.cc +++ b/caffe2/operators/string_ops_test.cc @@ -7,9 +7,11 @@ namespace caffe2 { class StringJoinOpTest : public testing::Test { public: - bool runOp(const Tensor& input) { + bool runOp(const TensorCPU& input) { auto* blob = ws_.CreateBlob("X"); - BlobSetTensor(blob, input.Alias()); + auto* tensor = BlobGetMutableTensor(blob, CPU); + tensor->ResizeLike(input); + tensor->ShareData(input); OperatorDef def; def.set_name("test"); diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h index dd34102..8426f49 100644 --- a/caffe2/operators/utility_ops.h +++ b/caffe2/operators/utility_ops.h @@ -169,7 +169,8 @@ class AliasOp final : public Operator { bool RunOnDevice() override { auto& input = Input(0); CAFFE_ENFORCE_GE(input.numel(), 0, "Tensor is not initialized"); - OutputTensorAlias(0, input); + Output(0)->ResizeLike(input); + Output(0)->ShareData(input); return true; } }; diff --git a/caffe2/predictor/predictor_test.cc b/caffe2/predictor/predictor_test.cc index da97d59..0d3bc33 100644 --- a/caffe2/predictor/predictor_test.cc +++ b/caffe2/predictor/predictor_test.cc @@ -179,8 +179,10 @@ class PredictorTest : public testing::Test { TEST_F(PredictorTest, SimpleBatchSized) { auto inputData = randomTensor({1, 4}, ctx_.get()); Predictor::TensorList input; + input.emplace_back(CPU); auto tensor = BlobGetMutableTensor(inputData.get(), CPU); - input.emplace_back(tensor->Alias()); + input.back().ResizeLike(*tensor); + input.back().ShareData(*tensor); Predictor::TensorList output; (*p_)(input, &output); EXPECT_EQ(output.size(), 1); @@ -193,8 +195,10 @@ TEST_F(PredictorTest, SimpleBatchSized) { TEST_F(PredictorTest, SimpleBatchSizedMapInput) { auto inputData = randomTensor({1, 4}, ctx_.get()); Predictor::TensorMap input; + auto iter = input.emplace("data", Tensor(CPU)); auto tensor = BlobGetMutableTensor(inputData.get(), CPU); - input.emplace("data", tensor->Alias()); + iter.first->second.ResizeLike(*tensor); + iter.first->second.ShareData(*tensor); Predictor::TensorList output; (*p_)(input, &output); -- 2.7.4