[sharded_tensor] add readonly tensor properties (#63679)
author    Wanchao Liang <wanchaol@fb.com>
          Sat, 21 Aug 2021 05:15:55 +0000 (22:15 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Sat, 21 Aug 2021 05:17:11 +0000 (22:17 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63679

This PR adds read-only tensor properties (dtype, layout, shape, requires_grad, plus is_contiguous() and is_pinned()) to ShardedTensor, to match torch.Tensor behavior.
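
A minimal usage sketch of the new surface (illustrative only; the import paths, spec dim, and placements below are assumptions modeled on the test file, while the property and method names match the api.py diff below):

    import torch
    import torch.distributed._sharded_tensor as _sharded_tensor
    from torch.distributed._sharding_spec import ChunkShardingSpec

    # assumes a process group is already initialized on the participating ranks
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
    st = _sharded_tensor.empty(spec, 10, 20)

    st.shape            # torch.Size([10, 20])
    st.dtype            # torch.float
    st.layout           # torch.strided
    st.requires_grad    # False
    st.is_contiguous()  # True
    st.is_pinned()      # False

    st.requires_grad = True  # raises AttributeError: can't set attribute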

Test Plan: test_sharded_tensor_metadata

Reviewed By: pritamdamania87

Differential Revision: D30459343

fbshipit-source-id: 9aec8ecfe76479eed25f3b843495e5719ed2956d

test/distributed/_sharded_tensor/test_sharded_tensor.py
torch/distributed/_sharded_tensor/api.py

test/distributed/_sharded_tensor/test_sharded_tensor.py
index 5067f30..26a176b 100644
@@ -174,19 +174,17 @@ class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
         sharded_tensor = _sharded_tensor.empty(spec, 10, 20, init_rrefs=True)
         sharded_tensor_metadata = sharded_tensor.metadata()
         self.assertEqual(torch.Size([10, 20]), sharded_tensor_metadata.size)
-        self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
-        self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
-        self.assertEqual(False, sharded_tensor_metadata.requires_grad)
-        self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format)
-        self.assertEqual(False, sharded_tensor_metadata.pin_memory)
+        self.assertEqual(torch.float, sharded_tensor.dtype)
+        self.assertEqual(torch.strided, sharded_tensor.layout)
+        self.assertEqual(False, sharded_tensor.requires_grad)
+        self.assertTrue(sharded_tensor.is_contiguous())
+        self.assertFalse(sharded_tensor.is_pinned())
 
         sharded_tensor = _sharded_tensor.empty(spec, 10, 20, requires_grad=True, init_rrefs=True)
-        sharded_tensor_metadata = sharded_tensor.metadata()
-        self.assertEqual(True, sharded_tensor_metadata.requires_grad)
+        self.assertEqual(True, sharded_tensor.requires_grad)
 
         sharded_tensor = _sharded_tensor.empty(spec, 10, 20, dtype=torch.double, init_rrefs=True)
-        sharded_tensor_metadata = sharded_tensor.metadata()
-        self.assertEqual(torch.double, sharded_tensor_metadata.dtype)
+        self.assertEqual(torch.double, sharded_tensor.dtype)
 
         # Need CPU for pin_memory
         spec = ChunkShardingSpec(
@@ -200,8 +198,12 @@ class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
         )
 
         sharded_tensor = _sharded_tensor.empty(spec, 10, 20, pin_memory=True, init_rrefs=True)
-        sharded_tensor_metadata = sharded_tensor.metadata()
-        self.assertEqual(True, sharded_tensor_metadata.pin_memory)
+        self.assertEqual(True, sharded_tensor.is_pinned())
+
+        # Test the read-only properties; they are read-only because we can't simply
+        # change the global metadata without also changing each underlying shard's properties
+        with self.assertRaisesRegex(AttributeError, "can't set attribute"):
+            sharded_tensor.requires_grad = True
 
     @with_comms
     @skip_if_lt_x_gpu(4)
@@ -782,19 +784,17 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase, MultiProcessTestCase):
         sharded_tensor = _sharded_tensor.empty(spec, 10, 10, init_rrefs=True)
         sharded_tensor_metadata = sharded_tensor.metadata()
         self.assertEqual(torch.Size([10, 10]), sharded_tensor_metadata.size)
-        self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
-        self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
-        self.assertEqual(False, sharded_tensor_metadata.requires_grad)
-        self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format)
-        self.assertEqual(False, sharded_tensor_metadata.pin_memory)
+        self.assertEqual(torch.float, sharded_tensor.dtype)
+        self.assertEqual(torch.strided, sharded_tensor.layout)
+        self.assertEqual(False, sharded_tensor.requires_grad)
+        self.assertTrue(sharded_tensor.is_contiguous())
+        self.assertFalse(sharded_tensor.is_pinned())
 
         sharded_tensor = _sharded_tensor.empty(spec, 10, 10, requires_grad=True, init_rrefs=True)
-        sharded_tensor_metadata = sharded_tensor.metadata()
-        self.assertEqual(True, sharded_tensor_metadata.requires_grad)
+        self.assertEqual(True, sharded_tensor.requires_grad)
 
         sharded_tensor = _sharded_tensor.empty(spec, 10, 10, dtype=torch.double, init_rrefs=True)
-        sharded_tensor_metadata = sharded_tensor.metadata()
-        self.assertEqual(torch.double, sharded_tensor_metadata.dtype)
+        self.assertEqual(torch.double, sharded_tensor.dtype)
 
         # Need CPU for pin_memory
         spec = EnumerableShardingSpec([
@@ -821,8 +821,7 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase, MultiProcessTestCase):
         ])
 
         sharded_tensor = _sharded_tensor.empty(spec, 10, 10, pin_memory=True, init_rrefs=True)
-        sharded_tensor_metadata = sharded_tensor.metadata()
-        self.assertEqual(True, sharded_tensor_metadata.pin_memory)
+        self.assertTrue(sharded_tensor.is_pinned())
 
     @with_comms
     @skip_if_lt_x_gpu(4)
torch/distributed/_sharded_tensor/api.py
index 2b6720b..5f501b7 100644
@@ -551,6 +551,35 @@ class ShardedTensor(object):
         """
         return self._metadata.size
 
+    def is_pinned(self) -> bool:
+        """
+        Returns True if the sharded tensor (each local shard) resides in pinned memory.
+        """
+        return self._metadata.pin_memory
+
+    def is_contiguous(self) -> bool:
+        """
+        Returns True if the sharded tensor (each local shard) is contiguous in memory
+        in the order specified by memory format.
+        """
+        return self._metadata.memory_format == torch.contiguous_format
+
+    @property
+    def shape(self):
+        return self._metadata.size
+
+    @property
+    def requires_grad(self):
+        return self._metadata.requires_grad
+
+    @property
+    def dtype(self):
+        return self._metadata.dtype
+
+    @property
+    def layout(self):
+        return self._metadata.layout
+
     def _register_remote_shards(self, remote_shards: List[rpc.RRef[Shard]], rpc_rank: int):
         self._remote_shards[rpc_rank] = remote_shards
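
Design note: the properties added above simply expose fields of self._metadata and define no setters, so assignment raises Python's standard "can't set attribute" AttributeError, which is exactly what the new test asserts. A self-contained sketch of that pattern (the class name here is hypothetical, not part of the PR):

    class _ReadOnlyProps:
        def __init__(self, requires_grad=False):
            self._requires_grad = requires_grad

        @property
        def requires_grad(self):
            # property with no setter: reads work, writes raise AttributeError
            return self._requires_grad

    p = _ReadOnlyProps()
    p.requires_grad          # False
    p.requires_grad = True   # AttributeError: can't set attribute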