From 2113ea6fbf20cb820ea6a504b1dfe30297ce4d7c Mon Sep 17 00:00:00 2001
From: Gregory Chanan
Date: Wed, 3 Apr 2019 07:52:54 -0700
Subject: [PATCH] Add device and dtype to storage. (#18749)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18749
ghimport-source-id: 9026a037f5e11cdb9ccd386f4b6b5768b9c3259b

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18751 Disallow changing the device of a tensor via set_.
* #18750 Use non-legacy constructors for tensor deserialization.
* **#18749 Add device and dtype to storage.**

The goal here is to fix our serialization, which currently depends on the
legacy constructors. Having dtype and device on Storage allows us to use the
non-legacy constructors.

This also fits somewhat with our goal of removing Storage, by having Storage
act like a Tensor.

Differential Revision: D14729516

fbshipit-source-id: bf4a3e8669ad4859931f4a3fa56df605cbc08dcb
---
 test/test_torch.py             | 22 ++++++++++++++++++++++
 torch/csrc/Storage.cpp         |  2 ++
 torch/csrc/cuda/Storage.cpp    |  2 ++
 torch/csrc/generic/Storage.cpp | 23 +++++++++++++++++++++++
 4 files changed, 49 insertions(+)

diff --git a/test/test_torch.py b/test/test_torch.py
index 97169ad..ffc8943 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -9489,54 +9489,76 @@ class _TestTorchMixin(object):
         self.assertEqual(storage.size(), 6)
         self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(storage.type(), 'torch.IntStorage')
+        self.assertIs(storage.dtype, torch.int32)
 
         floatStorage = storage.float()
         self.assertEqual(floatStorage.size(), 6)
         self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
         self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(floatStorage.dtype, torch.float32)
 
         halfStorage = storage.half()
         self.assertEqual(halfStorage.size(), 6)
         self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
         self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(halfStorage.dtype, torch.float16)
 
         longStorage = storage.long()
         self.assertEqual(longStorage.size(), 6)
         self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(longStorage.type(), 'torch.LongStorage')
         self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(longStorage.dtype, torch.int64)
 
         shortStorage = storage.short()
         self.assertEqual(shortStorage.size(), 6)
         self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
         self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
         self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(shortStorage.dtype, torch.int16)
 
         doubleStorage = storage.double()
         self.assertEqual(doubleStorage.size(), 6)
         self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
         self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
         self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(doubleStorage.dtype, torch.float64)
 
         charStorage = storage.char()
         self.assertEqual(charStorage.size(), 6)
         self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
         self.assertEqual(charStorage.type(), 'torch.CharStorage')
         self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+        self.assertIs(charStorage.dtype, torch.int8)
 
         byteStorage = storage.byte()
         self.assertEqual(byteStorage.size(), 6)
         self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
         self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
         self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
+        self.assertIs(byteStorage.dtype, torch.uint8)
 
         boolStorage = storage.bool()
         self.assertEqual(boolStorage.size(), 6)
         self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
         self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
         self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
+        self.assertIs(boolStorage.dtype, torch.bool)
+
+    def test_storage_device(self):
+        devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+        for device in devices:
+            x = torch.tensor([], device=device)
+            self.assertEqual(x.dtype, x.storage().dtype)
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected')
+    def test_storage_multigpu(self):
+        devices = ['cuda:0', 'cuda:1']
+        for device in devices:
+            x = torch.tensor([], device=device)
+            self.assertEqual(x.dtype, x.storage().dtype)
 
     @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
     def test_from_file(self):
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index b150dad..fdc389d 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -17,6 +17,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 
diff --git a/torch/csrc/cuda/Storage.cpp b/torch/csrc/cuda/Storage.cpp
index 05c8645..e5d281e 100644
--- a/torch/csrc/cuda/Storage.cpp
+++ b/torch/csrc/cuda/Storage.cpp
@@ -13,6 +13,8 @@
 #include
 #include
 #include
+#include
+#include
 
 #define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
 #include
diff --git a/torch/csrc/generic/Storage.cpp b/torch/csrc/generic/Storage.cpp
index bdd872e..4e9382f 100644
--- a/torch/csrc/generic/Storage.cpp
+++ b/torch/csrc/generic/Storage.cpp
@@ -283,6 +283,28 @@ static struct PyMemberDef THPStorage_(members)[] = {
   {nullptr}
 };
 
+static PyObject * THPStorage_(device)(THPStorage* self) {
+  HANDLE_TH_ERRORS
+  return THPDevice_New(self->cdata->device());
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * THPStorage_(dtype)(THPStorage *self)
+{
+  HANDLE_TH_ERRORS
+  return torch::autograd::utils::wrap(
+      torch::getDtype(at::typeMetaToScalarType(self->cdata->dtype())));
+  END_HANDLE_TH_ERRORS
+}
+
+typedef PyObject *(*getter)(PyObject *, void *);
+
+static struct PyGetSetDef THPStorage_(properties)[] = {
+  {"device", (getter)THPStorage_(device), nullptr, nullptr, nullptr},
+  {"dtype", (getter)THPStorage_(dtype), nullptr, nullptr, nullptr},
+  {nullptr}
+};
+
 extern THPCopyList THWStorage_(copy_functions);
 THPCopyList THWStorage_(copy_functions);
 
@@ -346,6 +368,7 @@ bool THPStorage_(init)(PyObject *module)
 
   THPStorageType.tp_methods = methods.data();
   THPStorageType.tp_members = THPStorage_(members);
+  THPStorageType.tp_getset = THPStorage_(properties);
   if (PyType_Ready(&THPStorageType) < 0)
     return false;
   Py_INCREF(&THPStorageType);
-- 
2.7.4
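
With this patch applied, a Python storage exposes `dtype` and `device` as read-only attributes, backed by the `PyGetSetDef` table registered on `THPStorageType` above. A minimal usage sketch of the user-visible behavior, assuming a post-patch build; the dtype mappings mirror the assertions added to test_torch.py:

```python
import torch

s = torch.tensor([-1.0, 0.0, 1.0]).storage()  # a FloatStorage on CPU

# New read-only properties added by this patch.
assert s.dtype is torch.float32
assert s.device.type == 'cpu'

# Converting the storage converts the reported dtype as well,
# e.g. int() -> torch.int32, byte() -> torch.uint8 (as asserted in the tests).
assert s.int().dtype is torch.int32
assert s.byte().dtype is torch.uint8

if torch.cuda.is_available():
    g = torch.tensor([1.0], device='cuda').storage()
    assert g.device.type == 'cuda'
```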
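The motivation in the commit message is serialization: once a storage carries its own dtype and device, a deserializer no longer needs a legacy, type-specific constructor (e.g. `torch.FloatTensor(storage)`) to rebuild a tensor around it. The sketch below illustrates that direction only; the helper name `rebuild_tensor` is hypothetical and is not the actual change made in #18750:

```python
import torch

def rebuild_tensor(storage, storage_offset, size, stride):
    # Hypothetical illustration: start from an empty tensor matching the
    # storage's own dtype/device, then point it at the storage. No legacy,
    # per-type constructor (torch.FloatTensor, torch.cuda.IntTensor, ...)
    # is needed because the storage now knows its dtype and device.
    t = torch.tensor([], dtype=storage.dtype, device=storage.device)
    return t.set_(storage, storage_offset, size, stride)

x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
y = rebuild_tensor(x.storage(), x.storage_offset(), x.size(), x.stride())
assert torch.equal(x, y)
```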