Move TensorImpl::CopyFrom to caffe2::Tensor (1/2) (#14656)
authorSebastian Messmer <messmer@fb.com>
Fri, 14 Dec 2018 02:38:54 +0000 (18:38 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 02:41:23 +0000 (18:41 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14656

This diff doesn't move it yet, but prepares it to be moved, i.e. removes all access to class internals.

dzhulgakov: Please comment on if you think it still makes sense to land this even though it's not blocking anymore since we're going to move at::CopyBytes anyhow.

ezyang: There's some changes in the implementation, especially handling undefined dest tensors. Please review carefully.

Reviewed By: ezyang

Differential Revision: D13287688

fbshipit-source-id: 17800ca8a79ab1633f23be58d96f99a160d8ed24

c10/core/TensorImpl.h
caffe2/core/tensor.h

index 0d617bc..310900c 100644 (file)
@@ -830,26 +830,23 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         src.storage_initialized(),
         "Cannot copy from an uninitialized Tensor");
 
-    if ((void*)&src == (void*)this) {
+    if (&src == this) {
       return;
     }
 
     // Test if we need to allocate a new storage
     // Uninitialized storages are guaranteed to be uniquely owned,
     // so we don't need to swap in this case.
-    if (storage_initialized()) {
-      // If the dtype changed, we need to reallocate storage.
-      if (data_type_ != src.dtype()) {
-        // NB: copy preserves device_type
-        // This storage will get initialized by the mutable_data call below.
-        storage_ = Storage(device_type(), src.dtype());
-      }
+    // If the dtype changed, we need to reallocate storage.
+    if (dtype() != src.dtype()) {
+      // NB: copy preserves device_type
+      // This storage will get initialized by the mutable_data call below.
+      set_storage(at::Storage(device_type(), src.dtype()));
     }
-    data_type_ = src.dtype();
     Resize(src.sizes());
 
     if (numel() > 0) {
-      if (data_type_.copy()) {
+      if (dtype().copy()) {
         AT_ASSERTM(
             device_type() == DeviceType::CPU,
             "In CopyFrom source and dest tensors must both be CPU for "
@@ -860,7 +857,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
             "In CopyFrom source and dest tensors must both be CPU for "
             "non-POD copy, but src tensor was ",
             src.device_type());
-        data_type_.copy()(src.data(), raw_mutable_data(data_type_), numel());
+        dtype().copy()(src.data(), raw_mutable_data(data_type_), numel());
       } else {
         // The following copy uses the current (thread local) stream for copying
         // and also takes the GPU id from the device() field passed in.
@@ -873,7 +870,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
         // properly.
         //
         // note: raw_mutable_data initializes device here
-        void* new_data = raw_mutable_data(data_type_);
+        void* new_data = raw_mutable_data(dtype());
         CopyBytes(
             numel() * itemsize(),
             src.data(),
@@ -1233,6 +1230,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return data_type_ != caffe2::TypeMeta();
   }
 
+  void set_storage(at::Storage storage) {
+    storage_ = std::move(storage);
+    data_type_ = storage_.dtype();
+  }
+
 private:
 
   // The Caffe2 Resize() method supports being called both as Resize({2,2}) as
index fc24e5d..c12eb81 100644 (file)
@@ -105,7 +105,7 @@ class CAFFE2_API Tensor final {
   }
 
   void CopyFrom(const Tensor& src, bool async = false) const {
-    impl_.get()->CopyFrom(*src.impl_.get(), async);
+    impl_.get()->CopyFrom(*src.impl_, async);
   }
 
   /**