Back out "[pt1][tensor] Remove caffe2::ShareData" (#15983)
authorJerry Zhang <jerryzh@fb.com>
Sat, 12 Jan 2019 15:04:49 +0000 (07:04 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 12 Jan 2019 15:07:22 +0000 (07:07 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15983

Original commit changeset: 6e4275d02f4c

Reviewed By: supertopher, Yangqing

Differential Revision: D13644123

fbshipit-source-id: 4b15a4c62995c0e68aad58465600409e302e6504

aten/src/ATen/test/tensor_interop_test.cpp
c10/core/TensorImpl.h
caffe2/core/blob_gpu_test.cc
caffe2/core/blob_test.cc
caffe2/core/operator.h
caffe2/core/tensor.h
caffe2/ideep/operators/operator_fallback_ideep.h
caffe2/operators/softmax_ops.cu
caffe2/operators/string_ops_test.cc
caffe2/operators/utility_ops.h
caffe2/predictor/predictor_test.cc

index f926312..ec3886b 100644 (file)
@@ -66,11 +66,11 @@ TEST(TestTensorInterop, PytorchToCaffe2Op) {
   auto* c2_tensor_a = BlobSetTensor(workspace.CreateBlob("a"), at_tensor_a.getIntrusivePtr());
   auto* c2_tensor_b = BlobSetTensor(workspace.CreateBlob("b"), at_tensor_b.getIntrusivePtr());
 
-  // Test Alias
+  // Test ShareData as well
   {
-    caffe2::Tensor c2_tensor_from_aten(at_tensor_c.getIntrusivePtr());
-    BlobSetTensor(workspace.CreateBlob("c"), c2_tensor_from_aten.Alias());
-
+    auto c2_tensor_c = XBlobGetMutableTensor(workspace.CreateBlob("c"), {0}, at::kCPU);
+    c2_tensor_c.ResizeLike(at_tensor_c.getIntrusivePtr());
+    c2_tensor_c.ShareData(at_tensor_c.getIntrusivePtr());
   }
 
   {
index 7e7ea61..5d8a3a0 100644 (file)
@@ -1068,6 +1068,47 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     storage_offset_ = 0;
   }
 
+  /**
+   * @brief Shares the data with another tensor.
+   *
+   * To share data between two tensors, the sizes of the two tensors must be
+   * equal already. The reason we do not implicitly do a Resize to make the two
+   * tensors have the same shape is that we want to allow tensors of different
+   * shapes but the same number of items to still be able to share data. This
+   * allows one to e.g. have a n-dimensional Tensor and a flattened version
+   * sharing the same underlying storage.
+   *
+   * The source tensor should already have its data allocated.
+   */
+  void ShareData(const TensorImpl& src) {
+    // Right now, we are assuming the device_type are the same, since it is
+    // inherently the same in the non-templatized code. We should probably add
+    // an assert here which might affect perf a little bit.
+    AT_ASSERTM(
+        src.numel_ == numel_,
+        "Size mismatch - did you call reshape before sharing the data?");
+    // It is possible that the source tensor hasn't called mutable_data() yet,
+    // in which case ShareData() doesn't make much sense since we don't really
+    // know what to share yet.
+    // TODO: Add the assert after all uninitialized states are eliminated
+    // AT_ASSERTM(src.dtype_initialized(),
+    //            "Source tensor don't have a data type (did you call mutable_data<T> on the tensor?)");
+    if (!src.dtype_initialized()) {
+      C10_LOG_EVERY_MS(WARNING, 1000) <<
+                   "Source tensor don't have a data type (did you call mutable_data<T> on the tensor?)";
+    }
+    AT_ASSERTM(
+        src.storage_initialized(),
+        "Source tensor has no content and has size > 0");
+    // Finally, do sharing.
+    /* Since we create new Storage whenever we need to change data_type/capacity
+     * this still keeps the original semantics
+     */
+    storage_ = src.storage();
+    data_type_ = src.dtype();
+    storage_offset_ = src.storage_offset();
+  }
+
   void ShareExternalPointer(
       DataPtr&& data_ptr,
       const caffe2::TypeMeta& data_type,
index 329b667..07e5638 100644 (file)
@@ -61,21 +61,22 @@ TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) {
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
 }
 
-TYPED_TEST(TensorGPUTest, TensorAlias) {
+TYPED_TEST(TensorGPUTest, TensorShareData) {
   if (!HasCudaGPU()) return;
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CUDA);
+  Tensor other_tensor(dims, CUDA);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  Tensor other_tensor = tensor.Alias();
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
   EXPECT_TRUE(other_tensor.data<TypeParam>() != nullptr);
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
 }
 
-TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) {
+TYPED_TEST(TensorGPUTest, TensorShareDataCanUseDifferentShapes) {
   if (!HasCudaGPU()) return;
   vector<int> dims(3);
   dims[0] = 2;
@@ -84,9 +85,9 @@ TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) {
   vector<int> alternate_dims(1);
   alternate_dims[0] = 2 * 3 * 5;
   Tensor tensor(dims, CUDA);
+  Tensor other_tensor(alternate_dims, CUDA);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  Tensor other_tensor = tensor.Alias();
-  other_tensor.Resize(alternate_dims);
+  other_tensor.ShareData(tensor);
   EXPECT_EQ(other_tensor.dim(), 1);
   EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]);
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
@@ -94,15 +95,16 @@ TYPED_TEST(TensorGPUTest, TensorAliasCanUseDifferentShapes) {
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
 }
 
-TYPED_TEST(TensorGPUTest, NoLongerAliasAfterNumelChanges) {
+TYPED_TEST(TensorGPUTest, NoLongerSharesAfterResize) {
   if (!HasCudaGPU()) return;
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CUDA);
+  Tensor other_tensor(dims, CUDA);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  Tensor other_tensor = tensor.Alias();
+  other_tensor.ShareData(tensor);
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
   auto* old_pointer = other_tensor.data<TypeParam>();
 
index 3dd8292..290e310 100644 (file)
@@ -212,7 +212,8 @@ TEST(TensorNonTypedTest, TensorChangeType) {
 
   // share the data with other tensor so that the pointer won't be reused
   // when we reallocate
-  Tensor other_tensor = tensor.Alias();
+  Tensor other_tensor(dims, CPU);
+  other_tensor.ShareData(tensor);
   // but double is bigger, so it should allocate a new one
   auto* doubleptr = tensor.mutable_data<double>();
   EXPECT_TRUE(doubleptr != (double*)ptr);
@@ -336,14 +337,15 @@ TYPED_TEST(TensorCPUTest, TensorInitializedScalar) {
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
 }
 
-TYPED_TEST(TensorCPUTest, TensorAlias) {
+TYPED_TEST(TensorCPUTest, TensorShareData) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CPU);
+  Tensor other_tensor(dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  Tensor other_tensor = tensor.Alias();
+  other_tensor.ShareData(tensor);
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
   EXPECT_TRUE(other_tensor.data<TypeParam>() != nullptr);
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
@@ -389,7 +391,7 @@ TYPED_TEST(TensorCPUTest, TensorShareDataRawPointerWithMeta) {
   }
 }
 
-TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) {
+TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
@@ -397,9 +399,9 @@ TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) {
   vector<int> alternate_dims(1);
   alternate_dims[0] = 2 * 3 * 5;
   Tensor tensor(dims, CPU);
+  Tensor other_tensor(alternate_dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  Tensor other_tensor = tensor.Alias();
-  other_tensor.Resize(alternate_dims);
+  other_tensor.ShareData(tensor);
   EXPECT_EQ(other_tensor.dim(), 1);
   EXPECT_EQ(other_tensor.dim32(0), alternate_dims[0]);
   EXPECT_TRUE(tensor.data<TypeParam>() != nullptr);
@@ -413,14 +415,15 @@ TYPED_TEST(TensorCPUTest, TensorAliasCanUseDifferentShapes) {
 }
 
 
-TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) {
+TYPED_TEST(TensorCPUTest, NoLongerSharesAfterResize) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CPU);
+  Tensor other_tensor(dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  Tensor other_tensor = tensor.Alias();
+  other_tensor.ShareData(tensor);
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
   auto* old_pointer = other_tensor.data<TypeParam>();
 
@@ -430,14 +433,15 @@ TYPED_TEST(TensorCPUTest, NoLongerAliassAfterNumelChanges) {
   EXPECT_NE(old_pointer, tensor.mutable_data<TypeParam>());
 }
 
-TYPED_TEST(TensorCPUTest, NoLongerAliasAfterFreeMemory) {
+TYPED_TEST(TensorCPUTest, NoLongerSharesAfterFreeMemory) {
   vector<int> dims(3);
   dims[0] = 2;
   dims[1] = 3;
   dims[2] = 5;
   Tensor tensor(dims, CPU);
+  Tensor other_tensor(dims, CPU);
   EXPECT_TRUE(tensor.mutable_data<TypeParam>() != nullptr);
-  Tensor other_tensor = tensor.Alias();
+  other_tensor.ShareData(tensor);
   EXPECT_EQ(tensor.data<TypeParam>(), other_tensor.data<TypeParam>());
   auto* old_pointer = other_tensor.data<TypeParam>();
 
index da3df5a..5e1a9ab 100644 (file)
@@ -223,12 +223,6 @@ class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
     return t;
   }
 
-  Tensor* OutputTensorAlias(int idx, const Tensor& src) {
-    return BlobSetTensor(OutputBlob(idx),
-                  src.Alias());
-  }
-
-
   template <typename T>
   inline T* Output(int idx, T* allocated) {
     outputs_.at(idx)->Reset(allocated);
@@ -794,8 +788,7 @@ class Operator : public OperatorBase {
   /* using override */ using OperatorBase::Output;                  \
   /* using override */ using OperatorBase::Input;                   \
   /* using override */ using OperatorBase::OutputSize;              \
-  /* using override */ using OperatorBase::IsInputOutputAlias;      \
-  /* using override */ using OperatorBase::OutputTensorAlias
+  /* using override */ using OperatorBase::IsInputOutputAlias
 
 #define USE_OPERATOR_FUNCTIONS(context)                     \
   USE_OPERATOR_BASE_FUNCTIONS;                              \
index 627293f..7a02659 100644 (file)
@@ -118,35 +118,6 @@ class CAFFE2_API Tensor final {
     return x;
   }
 
-  /**
-   * Clone self as a Tensor that share the same Storage,
-   * that is, both Tensors are views on the same Storage.
-   * If we change the sizes or strides of one Tensor, it
-   * does not affect the other Tensor that it shares Storage
-   * with.
-   * A similar yet different usage is `Tensor x = y;`, this
-   * will make x and y pointing to the same Tensor and resizing
-   * one of them will resize the other as well.
-   *
-   * TODO: Deduplicate this with THTensor_(newWithTensor)
-   * (exposed in ATen as at::alias but not otherwise available)
-   */
-  Tensor Alias() const {
-    Tensor x(sizes(), GetDevice());
-    if (!dtype_initialized()) {
-      C10_LOG_EVERY_MS(WARNING, 1000) <<
-                   "Cloning a tensor that don't have a data type (did you call mutable_data<T> on the tensor?)";
-    }
-    AT_ASSERTM(
-        storage_initialized(),
-        "Cloning a tensor that has no content and has size > 0");
-    // set_storage already sets data_type_ of TensorImpl
-    x.impl_->set_storage(storage());
-    x.impl_->set_storage_offset(impl_->storage_offset());
-    x.impl_->set_sizes_and_strides(sizes(), strides());
-    return x;
-  }
-
   DeviceType GetDeviceType() const {
     return impl_->device_type();
   }
@@ -324,6 +295,10 @@ class CAFFE2_API Tensor final {
     std::swap(*impl_.get(), *other.impl_.get());
   }
 
+  void ShareData(const Tensor& src) const {
+    impl_.get()->ShareData(*src.impl_.get());
+  }
+
   /**
    * @brief Shares the data with an externally managed pointer.
    *
@@ -499,7 +474,7 @@ class CAFFE2_API Tensor final {
     return impl_.get()->stride(dim);
   }
 
-  inline at::IntList strides() const {
+  inline at::IntList strides() {
     return impl_.get()->strides();
   }
 
index 4372807..e834ed9 100644 (file)
@@ -157,7 +157,9 @@ class C10_EXPORT IDEEPFallbackOp final : public IDEEPOperator {
           dtensor->CopyFrom(src);
         } else {
           dst->Reset(new Tensor(CPU));
-          BlobSetTensor(dst, src.Alias());
+          auto dtensor = BlobGetMutableTensor(dst, CPU);
+          dtensor->Resize(src_dims);
+          dtensor->ShareData(src);
         }
       }
     }
index a58ebc8..0876ad8 100644 (file)
@@ -487,21 +487,20 @@ bool SoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
   auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss
   const float* weights = (InputSize() > 4 ? Input(2).data<float>() : NULL);
 
-  Tensor* dX;
-  if (only_loss_) {
-    // Memory saving trick to share the buffer with the softmax output.
-    // Softmax output is thus overwritten.
-    dX = OutputTensorAlias(0, P);
-    dX->ResizeLike(X);
-  } else {
-    dX = Output(0, X.sizes(), at::dtype<float>());
-  }
+  auto* dX = Output(0);
+  dX->ResizeLike(X);
 
   const auto canonical_axis = X.canonical_axis_index(axis_);
   int N, D;
   N = X.size_to_dim(canonical_axis); // batch size
   D = X.size_from_dim(canonical_axis);
 
+  if (only_loss_) {
+    // Memory saving trick to share the buffer with the softmax output.
+    // Softmax output is thus overwritten.
+    dX->ShareData(P);
+  }
+
   ReinitializeTensor(&total_weight_ptr_, {1}, at::dtype<float>().device(CUDA));
 
   if (label_prob_mode_) {
@@ -603,21 +602,20 @@ bool SpatialSoftmaxWithLossGradientOp<float, CUDAContext>::RunOnDevice() {
   auto& d_avg_loss = Input(InputSize() - 1); // Gradient w.r.t. avg loss
   const float* weights = (InputSize() > 4 ? Input(2).data<float>() : NULL);
 
-  Tensor* dX;
-  if (only_loss_) {
-    // Memory saving trick to share the buffer with the softmax output.
-    // Softmax output is thus overwritten.
-    dX = OutputTensorAlias(0, P);
-    dX->ResizeLike(X);
-  } else {
-    dX = Output(0, X.sizes(), at::dtype<float>());
-  }
+  auto* dX = Output(0);
+  dX->ResizeLike(X);
 
   const auto canonical_axis = X.canonical_axis_index(1);
   int N, D;
   N = X.dim32(0);
   D = X.dim32(1);
 
+  if (only_loss_) {
+    // Memory saving trick to share the buffer with the softmax output.
+    // Softmax output is thus overwritten.
+    dX->ShareData(P);
+  }
+
   ReinitializeTensor(&total_weight_ptr_, {1}, at::dtype<float>().device(CUDA));
   // Spatial mode, compute softmax for each x, y location
   CAFFE_ENFORCE_EQ(X.ndim(), 4);
index 0856229..f325e72 100644 (file)
@@ -7,9 +7,11 @@ namespace caffe2 {
 
 class StringJoinOpTest : public testing::Test {
  public:
-  bool runOp(const Tensor& input) {
+  bool runOp(const TensorCPU& input) {
     auto* blob = ws_.CreateBlob("X");
-    BlobSetTensor(blob, input.Alias());
+    auto* tensor = BlobGetMutableTensor(blob, CPU);
+    tensor->ResizeLike(input);
+    tensor->ShareData(input);
 
     OperatorDef def;
     def.set_name("test");
index dd34102..8426f49 100644 (file)
@@ -169,7 +169,8 @@ class AliasOp final : public Operator<Context> {
   bool RunOnDevice() override {
     auto& input = Input(0);
     CAFFE_ENFORCE_GE(input.numel(), 0, "Tensor is not initialized");
-    OutputTensorAlias(0, input);
+    Output(0)->ResizeLike(input);
+    Output(0)->ShareData(input);
     return true;
   }
 };
index da97d59..0d3bc33 100644 (file)
@@ -179,8 +179,10 @@ class PredictorTest : public testing::Test {
 TEST_F(PredictorTest, SimpleBatchSized) {
   auto inputData = randomTensor({1, 4}, ctx_.get());
   Predictor::TensorList input;
+  input.emplace_back(CPU);
   auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
-  input.emplace_back(tensor->Alias());
+  input.back().ResizeLike(*tensor);
+  input.back().ShareData(*tensor);
   Predictor::TensorList output;
   (*p_)(input, &output);
   EXPECT_EQ(output.size(), 1);
@@ -193,8 +195,10 @@ TEST_F(PredictorTest, SimpleBatchSized) {
 TEST_F(PredictorTest, SimpleBatchSizedMapInput) {
   auto inputData = randomTensor({1, 4}, ctx_.get());
   Predictor::TensorMap input;
+  auto iter = input.emplace("data", Tensor(CPU));
   auto tensor = BlobGetMutableTensor(inputData.get(), CPU);
-  input.emplace("data", tensor->Alias());
+  iter.first->second.ResizeLike(*tensor);
+  iter.first->second.ShareData(*tensor);
 
   Predictor::TensorList output;
   (*p_)(input, &output);