Add device and dtype to storage. (#18749)
authorGregory Chanan <gchanan@fb.com>
Wed, 3 Apr 2019 14:52:54 +0000 (07:52 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 14:59:02 +0000 (07:59 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18749
ghimport-source-id: 9026a037f5e11cdb9ccd386f4b6b5768b9c3259b

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18751 Disallow changing the device of a tensor via set_.
* #18750 Use non-legacy constructors for tensor deserialization.
* **#18749 Add device and dtype to storage.**

The goal here is to fix our serialization, which currently depends on the legacy constructors.  Having dtype and device on Storage allows us to use the non-legacy constructors.

This fits somewhat along our goal of removing Storage, my having Storage act like a Tensor.

Differential Revision: D14729516

fbshipit-source-id: bf4a3e8669ad4859931f4a3fa56df605cbc08dcb

test/test_torch.py
torch/csrc/Storage.cpp
torch/csrc/cuda/Storage.cpp
torch/csrc/generic/Storage.cpp

index 97169ad..ffc8943 100644 (file)
@@ -9489,54 +9489,76 @@ class _TestTorchMixin(object):
         self.assertEqual(storage.size(), 6)
         self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(storage.type(), 'torch.IntStorage')
+        self.assertIs(storage.dtype, torch.int32)
 
         floatStorage = storage.float()
         self.assertEqual(floatStorage.size(), 6)
         self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
         self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(floatStorage.dtype, torch.float32)
 
         halfStorage = storage.half()
         self.assertEqual(halfStorage.size(), 6)
         self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
         self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(halfStorage.dtype, torch.float16)
 
         longStorage = storage.long()
         self.assertEqual(longStorage.size(), 6)
         self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(longStorage.type(), 'torch.LongStorage')
         self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(longStorage.dtype, torch.int64)
 
         shortStorage = storage.short()
         self.assertEqual(shortStorage.size(), 6)
         self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
         self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(shortStorage.dtype, torch.int16)
 
         doubleStorage = storage.double()
         self.assertEqual(doubleStorage.size(), 6)
         self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
         self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
         self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(doubleStorage.dtype, torch.float64)
 
         charStorage = storage.char()
         self.assertEqual(charStorage.size(), 6)
         self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
         self.assertEqual(charStorage.type(), 'torch.CharStorage')
         self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(charStorage.dtype, torch.int8)
 
         byteStorage = storage.byte()
         self.assertEqual(byteStorage.size(), 6)
         self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
         self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
         self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
+        self.assertIs(byteStorage.dtype, torch.uint8)
 
         boolStorage = storage.bool()
         self.assertEqual(boolStorage.size(), 6)
         self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
         self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
         self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
+        self.assertIs(boolStorage.dtype, torch.bool)
+
+    def test_storage_device(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.tensor([], device=device)
+            self.assertEqual(x.dtype, x.storage().dtype)
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected')
+    def test_storage_multigpu(self):
+        devices = ['cuda:0', 'cuda:1']
+        for device in devices:
+            x = torch.tensor([], device=device)
+            self.assertEqual(x.dtype, x.storage().dtype)
 
     @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
     def test_from_file(self):
index b150dad..fdc389d 100644 (file)
@@ -17,6 +17,8 @@
 #include <torch/csrc/copy_utils.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/CudaIPCTypes.h>
+#include <torch/csrc/Device.h>
+#include <torch/csrc/autograd/utils/wrap_outputs.h>
 
 #include <torch/csrc/generic/Storage.cpp>
 #include <TH/THGenerateAllTypes.h>
index 05c8645..e5d281e 100644 (file)
@@ -13,6 +13,8 @@
 #include <torch/csrc/copy_utils.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/CudaIPCTypes.h>
+#include <torch/csrc/Device.h>
+#include <torch/csrc/autograd/utils/wrap_outputs.h>
 
 #define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
 #include <THC/THCGenerateAllTypes.h>
index bdd872e..4e9382f 100644 (file)
@@ -283,6 +283,28 @@ static struct PyMemberDef THPStorage_(members)[] = {
   {nullptr}
 };
 
+static PyObject * THPStorage_(device)(THPStorage* self) {
+  HANDLE_TH_ERRORS
+  return THPDevice_New(self->cdata->device());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPStorage_(dtype)(THPStorage *self)
+{
+  HANDLE_TH_ERRORS
+  return torch::autograd::utils::wrap(
+      torch::getDtype(at::typeMetaToScalarType(self->cdata->dtype())));
+  END_HANDLE_TH_ERRORS
+}
+
+typedef PyObject *(*getter)(PyObject *, void *);
+
+static struct PyGetSetDef THPStorage_(properties)[] = {
+  {"device", (getter)THPStorage_(device), nullptr, nullptr, nullptr},
+  {"dtype",  (getter)THPStorage_(dtype), nullptr, nullptr, nullptr},
+  {nullptr}
+};
+
 extern THPCopyList THWStorage_(copy_functions);
 THPCopyList THWStorage_(copy_functions);
 
@@ -346,6 +368,7 @@ bool THPStorage_(init)(PyObject *module)
 
   THPStorageType.tp_methods = methods.data();
   THPStorageType.tp_members = THPStorage_(members);
+  THPStorageType.tp_getset = THPStorage_(properties);
   if (PyType_Ready(&THPStorageType) < 0)
     return false;
   Py_INCREF(&THPStorageType);