self.assertEqual(storage.size(), 6)
self.assertEqual(storage.tolist(), [-1, 0, 1, 2, 3, 4])
self.assertEqual(storage.type(), 'torch.IntStorage')
+ self.assertIs(storage.dtype, torch.int32)
floatStorage = storage.float()
self.assertEqual(floatStorage.size(), 6)
self.assertEqual(floatStorage.tolist(), [-1, 0, 1, 2, 3, 4])
self.assertEqual(floatStorage.type(), 'torch.FloatStorage')
self.assertEqual(floatStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertIs(floatStorage.dtype, torch.float32)
halfStorage = storage.half()
self.assertEqual(halfStorage.size(), 6)
self.assertEqual(halfStorage.tolist(), [-1, 0, 1, 2, 3, 4])
self.assertEqual(halfStorage.type(), 'torch.HalfStorage')
self.assertEqual(halfStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertIs(halfStorage.dtype, torch.float16)
longStorage = storage.long()
self.assertEqual(longStorage.size(), 6)
self.assertEqual(longStorage.tolist(), [-1, 0, 1, 2, 3, 4])
self.assertEqual(longStorage.type(), 'torch.LongStorage')
self.assertEqual(longStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertIs(longStorage.dtype, torch.int64)
shortStorage = storage.short()
self.assertEqual(shortStorage.size(), 6)
self.assertEqual(shortStorage.tolist(), [-1, 0, 1, 2, 3, 4])
self.assertEqual(shortStorage.type(), 'torch.ShortStorage')
self.assertEqual(shortStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertIs(shortStorage.dtype, torch.int16)
doubleStorage = storage.double()
self.assertEqual(doubleStorage.size(), 6)
self.assertEqual(doubleStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
self.assertEqual(doubleStorage.type(), 'torch.DoubleStorage')
self.assertEqual(doubleStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertIs(doubleStorage.dtype, torch.float64)
charStorage = storage.char()
self.assertEqual(charStorage.size(), 6)
self.assertEqual(charStorage.tolist(), [-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
self.assertEqual(charStorage.type(), 'torch.CharStorage')
self.assertEqual(charStorage.int().tolist(), [-1, 0, 1, 2, 3, 4])
+ self.assertIs(charStorage.dtype, torch.int8)
byteStorage = storage.byte()
self.assertEqual(byteStorage.size(), 6)
self.assertEqual(byteStorage.tolist(), [255, 0, 1, 2, 3, 4])
self.assertEqual(byteStorage.type(), 'torch.ByteStorage')
self.assertEqual(byteStorage.int().tolist(), [255, 0, 1, 2, 3, 4])
+ self.assertIs(byteStorage.dtype, torch.uint8)
boolStorage = storage.bool()
self.assertEqual(boolStorage.size(), 6)
self.assertEqual(boolStorage.tolist(), [True, False, True, True, True, True])
self.assertEqual(boolStorage.type(), 'torch.BoolStorage')
self.assertEqual(boolStorage.int().tolist(), [1, 0, 1, 1, 1, 1])
+ self.assertIs(boolStorage.dtype, torch.bool)
+
+ def test_storage_device(self):
+ devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
+ for device in devices:
+ x = torch.tensor([], device=device)
+ self.assertEqual(x.dtype, x.storage().dtype)
+
+ @unittest.skipIf(torch.cuda.device_count() < 2, 'less than 2 GPUs detected')
+ def test_storage_multigpu(self):
+ devices = ['cuda:0', 'cuda:1']
+ for device in devices:
+ x = torch.tensor([], device=device)
+ self.assertEqual(x.dtype, x.storage().dtype)
@unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows")
def test_from_file(self):