Stop swapping in Storages of the wrong device for Tensors. (#18831)
author Gregory Chanan <gchanan@fb.com>
Thu, 4 Apr 2019 13:19:54 +0000 (06:19 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 13:25:33 +0000 (06:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18831
ghimport-source-id: 2741e0d70ebe2c2217572c3af54ddd9d2047e342

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18833 [STACK] Cache device on TensorImpl; clean up TensorImpl constructors.
* #18832 [STACK] Disallow changing the device of a tensor via set_.
* **#18831 [STACK] Stop swapping in Storages of the wrong device for Tensors.**

This is necessary to support device caching; see https://github.com/pytorch/pytorch/pull/18751 and https://github.com/pytorch/pytorch/pull/18578.

In library code, we can end up swapping in a Storage with the wrong device when device_guard is False.  With "view-like" operations this happens as follows:
1) We allocate a tensor on the 'wrong' device (because device_guard is False).
2) We swap out the 'wrong' storage for the 'right' one using e.g. THCTensor_setStorage.

Instead, we can just construct the Tensor with the correct Storage from the beginning.  This is what this change does for 'view'.
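For concreteness, here is a minimal sketch of the invariant this preserves (it assumes a machine with at least two CUDA devices, like the new test below): a view shares its base tensor's Storage, so its device must never depend on the current device at call time.

    import torch

    base = torch.randn(1, 2, 3, device='cuda:1')
    with torch.cuda.device(0):             # current device differs from base's device
        v = base.view(1, 3, 2)             # device_guard: False -- no device switch happens
    assert v.device == base.device         # the Storage (and thus the device) comes from base
    assert v.data_ptr() == base.data_ptr()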

Note there are two other "view-like" cases where this happens:
1) unfold
2) set_()

Because these aren't performance critical, I just removed their `device_guard: False` annotations (restoring the default device guard) instead of applying the above correction.
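As a sketch of the behavior the guard restores (again assuming at least two CUDA devices), unfold and set_ now switch to self's device internally, so their results stay on the tensor's own device even when another device is current:

    import torch

    base = torch.randn(1, 2, 3, device='cuda:1')
    with torch.cuda.device(0):
        u = base.unfold(2, 1, 1)                              # guarded: runs on cuda:1
        t = torch.empty(0, device='cuda:1').set_(base.storage())
    assert u.device == base.device == t.device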

For completeness, this also includes a test verifying that all `device_guard: False` functions behave properly when the current device differs from the device of the tensors involved.

Reviewed By: dzhulgakov

Differential Revision: D14766232

fbshipit-source-id: 0865c3ddae3f415df5da7a9869b1ea9f210e81bc

aten/src/ATen/native/native_functions.yaml
aten/src/TH/THStorageFunctions.hpp
aten/src/THC/generic/THCTensor.cpp
test/test_torch.py

diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 439745d..e0c497a 100644
 - func: set_(Tensor(a!) self) -> Tensor(a!)
   matches_jit_signature: True
   variants: method
-  device_guard: False
 
 - func: is_set_to(Tensor self, Tensor tensor) -> bool
   matches_jit_signature: True
 - func: unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)
   matches_jit_signature: True
   variants: method
-  device_guard: False
 
 - func: equal(Tensor self, Tensor other) -> bool
   matches_jit_signature: True
diff --git a/aten/src/TH/THStorageFunctions.hpp b/aten/src/TH/THStorageFunctions.hpp
index aacccbd..8f381ff 100644
@@ -3,6 +3,7 @@
 // STOP!!! Thinking of including this header directly?  Please
 // read Note [TH abstraction violation]
 
+#include <c10/core/Storage.h>
 #include <c10/core/StorageImpl.h>
 #include <TH/THStorageFunctions.h>
 
@@ -35,3 +36,10 @@ TH_API ptrdiff_t THStorage_size(const THStorage *self);
 
 TH_API void THStorage_retain(THStorage *storage);
 TH_API void THStorage_resize(THStorage *storage, ptrdiff_t size);
+
+// Returns a Storage given a StorageImpl. The StorageImpl remains valid after
+// the Storage is destroyed.
+inline c10::Storage THStorage_wrap(THStorage* storage) {
+  c10::raw::intrusive_ptr::incref(storage);
+  return c10::Storage(c10::intrusive_ptr<c10::StorageImpl>::reclaim(storage));
+}
diff --git a/aten/src/THC/generic/THCTensor.cpp b/aten/src/THC/generic/THCTensor.cpp
index b4a31d1..f704470 100644
@@ -203,7 +203,6 @@ THCTensor *THCTensor_(newTranspose)(THCState *state, THCTensor *tensor, int dime
 THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, at::IntArrayRef size)
 {
   ptrdiff_t numel = THCTensor_(nElement)(state, tensor);
-  THCTensor *self = THCTensor_(new)(state);
   auto inferred_size = at::infer_size(size, numel);
   auto stride = THTensor_compute_stride(tensor->sizes(),
                                         tensor->strides(),
@@ -212,7 +211,21 @@ THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, at::IntArrayR
     "not compatible with input tensor's size and stride (at least one dimension spans "
     "across two contiguous subspaces). Call .contiguous() before .view().");
   auto stride_value = *stride;
+
+  // NOTE: This path of constructing the Tensor directly with the viewed Storage is necessary
+  // to allow `view` not to have a device_guard.  Taking the common TH path of allocating a storage
+  // on the current device [via THCTensor_(new)] and then swapping out the storage later can change
+  // the device out from under the tensor.  Having the device be consistent through a Tensor's lifetime
+  // is an invariant we wish to keep to support caching, simplicity, etc.
+  auto storage = THStorage_wrap(tensor->storage().unsafeGetStorageImpl());
+  THCTensor *self = c10::make_intrusive<at::TensorImpl, at::UndefinedTensorImpl>(
+    std::move(storage),
+    at::CUDATensorId(),
+    false
+  ).release();
+
   THCTensor_setStorage(state, self, THTensor_getStoragePtr(tensor), tensor->storage_offset(), inferred_size, stride_value);
+
   return self;
 }
 
diff --git a/test/test_torch.py b/test/test_torch.py
index 026ce40..29bb6b7 100644
@@ -2665,6 +2665,95 @@ class _TestTorchMixin(object):
         for i in range(num_elements):
             self.assertEqual(r[i], rqr[i])
 
+    @unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
+    def test_device_guard(self):
+        # verify that all operators with `device_guard: False` behave properly with multiple devices.
+        # TODO: if we had operator introspection we could figure out this set of operators automatically...
+        current_device = torch.cuda.current_device()
+        device = torch.device('cuda:1') if current_device == 0 else torch.device('cuda:0')
+        x = torch.randn((1, 2, 3), device=device)
+        y = torch.zeros((1, 3, 2), device=device)
+        scalar = torch.tensor(5, device=device)
+
+        # property ops
+        torch.cudnn_is_acceptable(x)
+        x.is_distributed()
+        x.is_floating_point()
+        x.is_complex()
+        x.is_same_size(y)
+        x.is_signed()
+        x.size(0)
+        x.stride(0)
+        x.numel()
+        x.is_set_to(y)
+        x.data_ptr()
+        scalar.is_nonzero()
+
+        # sparse property ops
+        y[0][1] = 5
+        y_sparse = y.to_sparse()
+        y_sparse.sparse_dim()
+        y_sparse._dimI()
+        y_sparse.dense_dim()
+        y_sparse._dimV()
+        y_sparse._nnz()
+        y_sparse.is_coalesced()
+        y_sparse._indices()
+        y_sparse._values()
+        y_sparse.indices()
+        y_sparse.values()
+
+        # in-place ops
+        def inplace():
+            return torch.randn((1, 2, 3), device=device)
+        inplace().as_strided_(y.size(), y.stride())
+        inplace().resize_(y.size())
+        inplace().squeeze_()
+        inplace().squeeze_(0)
+        inplace().unsqueeze_(2)
+        inplace().transpose_(1, 2)
+        inplace().squeeze_().t_()
+        inplace().set_(x.storage())
+        inplace().set_(x.storage(), x.storage_offset(), x.size(), x.stride())
+        inplace().set_(x)
+        inplace().set_()
+        y_sparse._coalesced_(True)
+
+        # shape modification
+        x.as_strided(y.size(), y.stride())
+        x.expand((5, 2, 3))
+        x.expand_as(x)
+        x.sum_to_size((1,))
+        torch.broadcast_tensors(x, x)
+        x.reshape((1, 3, 2))
+        x.reshape_as(y)
+        x.squeeze()
+        x.squeeze(0)
+        x.squeeze().t()
+        x.transpose(1, 2)
+        x.unsqueeze(2)
+        x.view((1, 3, 2))
+        x.view_as(y)
+
+        # chunk, split, etc.
+        x.chunk(2, dim=1)
+        x.split(1, dim=2)
+        x.split_with_sizes([1, 2], dim=2)
+        x.unfold(dimension=2, size=1, step=1)
+
+        x.narrow(1, 1, 1)
+        x.select(1, 1)
+        torch.isnan(x)
+
+        torch.empty((1, 3, 2), out=y)
+        torch.empty_like(x)
+        torch.empty_like(x, dtype=torch.int64)
+
+        # to
+        x.to(x)
+        x.to(y)
+        x.to(x, copy=True)
+
     def test_to(self):
         def test_copy_behavior(t, non_blocking=False):
             self.assertIs(t, t.to(t, non_blocking=non_blocking))