Disallow changing the device of a tensor via set_. (#18832)
authorGregory Chanan <gchanan@fb.com>
Thu, 4 Apr 2019 18:12:13 +0000 (11:12 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 18:15:37 +0000 (11:15 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18832
ghimport-source-id: fde4ad90541ba52dfa02bdd83466f17e6541e535

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18833 [STACK] Cache device on TensorImpl; clean up TensorImpl constructors.
* **#18832 [STACK] Disallow changing the device of a tensor via set_.**
* #18831 [STACK] Stop swapping in Storages of the wrong device for Tensors.

This is necessary to cache the device on a TensorImpl.

Differential Revision: D14766231

fbshipit-source-id: bba61634b2d6252ac0697b96033c9eea680956e8

aten/src/TH/THTensor.cpp
test/test_torch.py
torch/csrc/jit/import.cpp

index bef67f9..2be5ce7 100644 (file)
@@ -56,9 +56,10 @@ void THTensor_setStorageNd(THTensor *self, THStorage *storage, ptrdiff_t storage
   }
 
   /* storageOffset */
-  if(storageOffset < 0)
+  if(storageOffset < 0) {
     THError("Tensor: invalid storage offset");
-    self->set_storage_offset(storageOffset);
+  }
+  self->set_storage_offset(storageOffset);
 
   /* size and stride */
   THTensor_resizeNd(self, nDimension, size, stride);
@@ -160,5 +161,15 @@ void THTensor_stealAndSetStoragePtr(THTensor* tensor, THStorage* storage) {
   // Caffe2 might have tensors whose storages are null, but we
   // don't allow it in PyTorch.
   AT_ASSERT(storage);
+  // Caffe2 also has uninitialized dtype states, which we disallow here
+  AT_ASSERT(tensor->storage().dtype() == storage->dtype());
+
+  // We used to allow this, but this breaks device caching,
+  // see Note [We regret making Variable hold a Tensor]
+  // Let's put an actual error message for this one.
+  AT_CHECK(tensor->storage().device() == storage->device(),
+            "Attempted to set the storage of a tensor on device ", tensor->storage().device(),
+             " to a storage on different device ", storage->device(),
+            ".  This is no longer allowed; the devices must match.");
   tensor->set_storage(at::Storage(c10::intrusive_ptr<THStorage>::reclaim(storage)));
 }
index 29bb6b7..443b8db 100644 (file)
@@ -8352,6 +8352,42 @@ class _TestTorchMixin(object):
         t1.set_(t2)
         self.assertEqual(t1.storage()._cdata, t2.storage()._cdata)
 
+    def test_tensor_set_errors(self):
+        f_cpu = torch.randn((2, 3), dtype=torch.float32)
+        d_cpu = torch.randn((2, 3), dtype=torch.float64)
+
+        # change dtype
+        self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu.storage()))
+        self.assertRaises(RuntimeError,
+                          lambda: f_cpu.set_(d_cpu.storage(), 0, d_cpu.size(), d_cpu.stride()))
+        self.assertRaises(RuntimeError, lambda: f_cpu.set_(d_cpu))
+
+        # change device
+        if torch.cuda.is_available():
+            f_cuda = torch.randn((2, 3), dtype=torch.float32, device='cuda')
+
+            # cpu -> cuda
+            self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda.storage()))
+            self.assertRaises(RuntimeError,
+                              lambda: f_cpu.set_(f_cuda.storage(), 0, f_cuda.size(), f_cuda.stride()))
+            self.assertRaises(RuntimeError, lambda: f_cpu.set_(f_cuda))
+
+            # cuda -> cpu
+            self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu.storage()))
+            self.assertRaises(RuntimeError,
+                              lambda: f_cuda.set_(f_cpu.storage(), 0, f_cpu.size(), f_cpu.stride()))
+            self.assertRaises(RuntimeError, lambda: f_cuda.set_(f_cpu))
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected')
+    def test_tensor_set_errors_multigpu(self):
+        f_cuda0 = torch.randn((2, 3), dtype=torch.float32, device='cuda:0')
+        f_cuda1 = torch.randn((2, 3), dtype=torch.float32, device='cuda:1')
+
+        self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1.storage()))
+        self.assertRaises(RuntimeError,
+                          lambda: f_cuda0.set_(f_cuda1.storage(), 0, f_cuda1.size(), f_cuda1.stride()))
+        self.assertRaises(RuntimeError, lambda: f_cuda0.set_(f_cuda1))
+
     def test_equal(self):
         # Contiguous, 1D
         t1 = torch.Tensor((3, 4, 9, 10))
index fd63c33..d803886 100644 (file)
@@ -229,7 +229,7 @@ at::Tensor ScriptModuleDeserializer::loadTensor(
             .set_(storage_it->second, tensor_proto.offset(), dims, strides);
   } else if (device.type() == at::DeviceType::CUDA) {
     result =
-        at::empty({0}, at::CUDA(type).options())
+        at::empty({0}, c10::TensorOptions(type).device(storage_it->second.device()))
             .set_(storage_it->second, tensor_proto.offset(), dims, strides);
   }
   AT_ASSERT(result.defined());