sharded_tensor = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
sharded_tensor_metadata = sharded_tensor.metadata()
self.assertEqual(torch.Size([10, 20]), sharded_tensor_metadata.size)
- self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
- self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
- self.assertEqual(False, sharded_tensor_metadata.requires_grad)
- self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format)
- self.assertEqual(False, sharded_tensor_metadata.pin_memory)
+ self.assertEqual(torch.float, sharded_tensor.dtype)
+ self.assertEqual(torch.strided, sharded_tensor.layout)
+ self.assertEqual(False, sharded_tensor.requires_grad)
+ self.assertTrue(sharded_tensor.is_contiguous())
+ self.assertFalse(sharded_tensor.is_pinned())
sharded_tensor = _sharded_tensor.empty(spec, 10, 20, requires_grad=True, init_rrefs=True)
- sharded_tensor_metadata = sharded_tensor.metadata()
- self.assertEqual(True, sharded_tensor_metadata.requires_grad)
+ self.assertEqual(True, sharded_tensor.requires_grad)
sharded_tensor = _sharded_tensor.empty(spec, 10, 20, dtype=torch.double, init_rrefs=True)
- sharded_tensor_metadata = sharded_tensor.metadata()
- self.assertEqual(torch.double, sharded_tensor_metadata.dtype)
+ self.assertEqual(torch.double, sharded_tensor.dtype)
# Need CPU for pin_memory
spec = ChunkShardingSpec(
)
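(The spec body is elided in this hunk. For orientation only, a chunk spec that satisfies the pin_memory requirement would place every shard on CPU; the dim and the four ranks below are assumptions for a 4-process run, not part of the original diff.)
# Sketch only: the real placements are elided above. pin_memory requires
# the shards to live on CPU, so every placement targets cpu.
spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cpu",
        "rank:1/cpu",
        "rank:2/cpu",
        "rank:3/cpu",
    ],
)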
sharded_tensor = _sharded_tensor.empty(spec, 10, 20, pin_memory=True, init_rrefs=True)
- sharded_tensor_metadata = sharded_tensor.metadata()
- self.assertEqual(True, sharded_tensor_metadata.pin_memory)
+ self.assertTrue(sharded_tensor.is_pinned())
+
+ # Test read-only properties; they're read-only because we can't simply change
+ # the global metadata without also changing the underlying shards' properties.
+ with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+ sharded_tensor.requires_grad = True
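(The comment above states the invariant that motivates keeping these properties read-only: the tensor-level flags are derived from the shards. A minimal sketch of that relationship, assuming the standard ShardedTensor.local_shards() accessor, which returns shards with a .tensor field:)
# Sketch: the tensor-level properties mirror each local shard's tensor.
for shard in sharded_tensor.local_shards():
    self.assertEqual(sharded_tensor.dtype, shard.tensor.dtype)
    self.assertEqual(sharded_tensor.requires_grad, shard.tensor.requires_grad)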
@with_comms
@skip_if_lt_x_gpu(4)
sharded_tensor = _sharded_tensor.empty(spec, 10, 10, init_rrefs=True)
sharded_tensor_metadata = sharded_tensor.metadata()
self.assertEqual(torch.Size([10, 10]), sharded_tensor_metadata.size)
- self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
- self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
- self.assertEqual(False, sharded_tensor_metadata.requires_grad)
- self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format)
- self.assertEqual(False, sharded_tensor_metadata.pin_memory)
+ self.assertEqual(torch.float, sharded_tensor.dtype)
+ self.assertEqual(torch.strided, sharded_tensor.layout)
+ self.assertEqual(False, sharded_tensor.requires_grad)
+ self.assertTrue(sharded_tensor.is_contiguous())
+ self.assertFalse(sharded_tensor.is_pinned())
sharded_tensor = _sharded_tensor.empty(spec, 10, 10, requires_grad=True, init_rrefs=True)
- sharded_tensor_metadata = sharded_tensor.metadata()
- self.assertEqual(True, sharded_tensor_metadata.requires_grad)
+ self.assertEqual(True, sharded_tensor.requires_grad)
sharded_tensor = _sharded_tensor.empty(spec, 10, 10, dtype=torch.double, init_rrefs=True)
- sharded_tensor_metadata = sharded_tensor.metadata()
- self.assertEqual(torch.double, sharded_tensor_metadata.dtype)
+ self.assertEqual(torch.double, sharded_tensor.dtype)
# Need CPU for pin_memory
spec = EnumerableShardingSpec([
])
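(Again the shard list is elided in this hunk. Illustratively, an enumerable spec keeping every shard on CPU might carve the 10x10 tensor into four 5x5 quadrants; the ShardMetadata field names and rank assignments below are assumptions based on the _sharding_spec module of this era, not part of the original diff.)
# Sketch only: the real shard list is elided above.
spec = EnumerableShardingSpec([
    ShardMetadata(shard_offsets=[0, 0], shard_lengths=[5, 5], placement="rank:0/cpu"),
    ShardMetadata(shard_offsets=[0, 5], shard_lengths=[5, 5], placement="rank:1/cpu"),
    ShardMetadata(shard_offsets=[5, 0], shard_lengths=[5, 5], placement="rank:2/cpu"),
    ShardMetadata(shard_offsets=[5, 5], shard_lengths=[5, 5], placement="rank:3/cpu"),
])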
sharded_tensor = _sharded_tensor.empty(spec, 10, 10, pin_memory=True, init_rrefs=True)
- sharded_tensor_metadata = sharded_tensor.metadata()
- self.assertEqual(True, sharded_tensor_metadata.pin_memory)
+ self.assertTrue(sharded_tensor.is_pinned())
@with_comms
@skip_if_lt_x_gpu(4)