More sharded_tensor creation ops: sharded_tensor.zeros, sharded_tensor.full, sharded_tensor.rand
author     Bo Wang <bowangbj@fb.com>
Thu, 26 Aug 2021 23:00:16 +0000 (16:00 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 23:01:38 +0000 (16:01 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63732
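
This adds `_sharded_tensor.zeros`, `_sharded_tensor.rand`, and `_sharded_tensor.full` alongside the existing `empty` and `ones` creation ops, together with the corresponding `CreateOp` entries and `_create_tensor_from_params` branches. A minimal usage sketch (the placements and sizes below are illustrative only; the default process group is assumed to already be initialized on every participating rank):

    import torch
    from torch.distributed import _sharded_tensor
    from torch.distributed._sharding_spec import ChunkShardingSpec

    spec = ChunkShardingSpec(
        dim=0,
        placements=["rank:0/cuda:0", "rank:1/cuda:1"],
    )
    # Each rank materializes only its local shard of the 10x20 global tensor.
    st_zeros = _sharded_tensor.zeros(spec, 10, 20)
    st_rand = _sharded_tensor.rand(spec, 10, 20, dtype=torch.double)
    st_full = _sharded_tensor.full(spec, size=(10, 20), fill_value=1234, dtype=torch.int32)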

Test Plan:
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py  --v

$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestCreateTensorFromParams --v
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestShardedTensorChunked --v

Imported from OSS

Differential Revision: D30472621

Reviewed By: pritamdamania87

Pulled By: bowangbj

fbshipit-source-id: fd8ebf9b815fdc292ad1aad521f9f4f454163d0e

test/distributed/_sharded_tensor/test_sharded_tensor.py
torch/distributed/_sharded_tensor/__init__.py
torch/distributed/_sharded_tensor/api.py

diff --git a/test/distributed/_sharded_tensor/test_sharded_tensor.py b/test/distributed/_sharded_tensor/test_sharded_tensor.py
index 6c03d9f..718b594 100644
@@ -126,8 +126,9 @@ def with_comms(func):
 class TestCreateTensorFromParams(TestCase):
     @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
     def test_empty(self):
+        expected_dtype = torch.double
         tensor_properties = TensorProperties(
-            dtype=torch.double,
+            dtype=expected_dtype,
             layout=torch.strided,
             requires_grad=False,
             pin_memory=False,
@@ -138,14 +139,15 @@ class TestCreateTensorFromParams(TestCase):
         local_tensor = _create_tensor_from_params(
             5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
         self.assertEqual(local_device, local_tensor.device)
-        self.assertEqual(torch.double, local_tensor.dtype)
+        self.assertEqual(expected_dtype, local_tensor.dtype)
         self.assertEqual(torch.strided, local_tensor.layout)
         self.assertEqual(False, local_tensor.requires_grad)
 
     @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
     def test_ones(self):
+        expected_dtype = torch.double
         tensor_properties = TensorProperties(
-            dtype=torch.double,
+            dtype=expected_dtype,
             layout=torch.strided,
             requires_grad=False,
             pin_memory=False,
@@ -153,9 +155,98 @@ class TestCreateTensorFromParams(TestCase):
         tensor_init_params = TensorInitParams(
             create_op=CreateOp.ONES, tensor_properties=tensor_properties)
         local_device = torch.device('cuda:0')
+        h, w = 5, 10
         local_tensor = _create_tensor_from_params(
-            5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
-        expected_tensor = torch.ones(5, 10, device=local_device, dtype=torch.double)
+            h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+        expected_tensor = torch.ones(h, w, device=local_device, dtype=expected_dtype)
+        self.assertEqual(expected_tensor, local_tensor)
+
+    @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+    def test_zeros(self):
+        expected_dtype = torch.int32
+        tensor_properties = TensorProperties(
+            dtype=expected_dtype,
+            layout=torch.strided,
+            requires_grad=False,
+            pin_memory=False,
+            memory_format=torch.contiguous_format,
+        )
+        tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, )
+        local_device = torch.device('cuda:0')
+        h, w = 5, 10
+        local_tensor = _create_tensor_from_params(
+            h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+        expected_tensor = torch.zeros(h, w, device=local_device, dtype=expected_dtype)
+        self.assertEqual(expected_tensor, local_tensor)
+
+    @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+    def test_rand(self):
+        expected_dtype = torch.double
+        tensor_properties = TensorProperties(
+            dtype=expected_dtype,
+            layout=torch.strided,
+            requires_grad=False,
+            pin_memory=False,
+            memory_format=torch.contiguous_format,
+        )
+        tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, )
+        local_device = torch.device('cuda:0')
+        h, w = 5, 10
+        seed = 13
+        torch.cuda.manual_seed(seed)
+        local_tensor = _create_tensor_from_params(
+            h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+        # reset seed to ensure same random numbers are generated
+        torch.cuda.manual_seed(seed)
+        expected_tensor = torch.rand(h, w, device=local_device, dtype=expected_dtype)
+        self.assertEqual(expected_tensor, local_tensor)
+
+    @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+    def test_full_with_dtype_inferred(self):
+        fill_value = 23.5
+        tensor_properties = TensorProperties(
+            # tensor's dtype can be inferred from fill_value
+            dtype=None,
+            layout=torch.strided,
+            requires_grad=False,
+            pin_memory=False,
+            memory_format=torch.contiguous_format,
+        )
+        tensor_init_params = TensorInitParams(
+            create_op=CreateOp.FULL,
+            fill_value=fill_value,
+            tensor_properties=tensor_properties, )
+        local_device = torch.device('cuda:0')
+        h, w = 5, 10
+        local_tensor = _create_tensor_from_params(
+            h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+        # local_tensor.dtype is inferred from fill_value (float32).
+        self.assertEqual(torch.float32, local_tensor.dtype)
+        expected_tensor = torch.full((h, w), fill_value=fill_value, device=local_device)
+        self.assertEqual(expected_tensor, local_tensor)
+
+    @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+    def test_full_with_dtype_overridden(self):
+        fill_value = 23.5
+        tensor_properties = TensorProperties(
+            # dtype is set explicitly and overrides the type inferred from fill_value
+            dtype=torch.double,
+            layout=torch.strided,
+            requires_grad=False,
+            pin_memory=False,
+            memory_format=torch.contiguous_format,
+        )
+        tensor_init_params = TensorInitParams(
+            create_op=CreateOp.FULL,
+            fill_value=fill_value,
+            tensor_properties=tensor_properties, )
+        local_device = torch.device('cuda:0')
+        h, w = 5, 10
+        local_tensor = _create_tensor_from_params(
+            h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+        # local_tensor.dtype is overridden.
+        self.assertEqual(torch.double, local_tensor.dtype)
+        expected_tensor = torch.full((h, w), fill_value=fill_value, device=local_device, dtype=torch.double)
         self.assertEqual(expected_tensor, local_tensor)
 
 class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
@@ -295,6 +386,102 @@ class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
     @with_comms
     @skip_if_lt_x_gpu(4)
     @requires_nccl()
+    def test_create_sharded_tensor_with_zeros(self):
+        """ Test _sharded_tensor.zeros(...) """
+
+        spec = ChunkShardingSpec(
+            dim=0,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        )
+        h, w = 10, 20
+        sharded_tensor = _sharded_tensor.zeros(spec, h, w)
+
+        # Validate local shard is initialized with torch.zeros
+        local_shards = sharded_tensor.local_shards()
+        self.assertEqual(1, len(local_shards))
+        local_shard = local_shards[0].tensor
+        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+        # Chunking along dim 0: ranks 0-2 each get ceil(h / 4) = 3 rows; rank 3 gets the remaining 1 row.
+        expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
+        self.assertEqual((expected_h, w), local_shard.size())
+        self.assertEqual(local_shard, torch.zeros(expected_h, w))
+
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_create_sharded_tensor_with_rand(self):
+        """ Test _sharded_tensor.rand(...) """
+
+        spec = ChunkShardingSpec(
+            dim=0,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        )
+        h, w = 8, 2
+        seed = 1234
+
+        expected_h = 2
+        expected_device = torch.device(f"cuda:{self.rank}")
+        dtype = torch.double
+        torch.manual_seed(seed)
+        expected = torch.rand(expected_h, w, device=expected_device, dtype=dtype)
+        # reset seed to ensure the same random numbers are generated
+        torch.manual_seed(seed)
+        sharded_tensor = _sharded_tensor.rand(spec, h, w, dtype=dtype)
+
+        # Validate local shard is initialized with torch.rand
+        local_shards = sharded_tensor.local_shards()
+        self.assertEqual(1, len(local_shards))
+        local_shard = local_shards[0].tensor
+        self.assertEqual(expected_device, local_shard.device)
+        self.assertEqual((expected_h, w), local_shard.size())
+        self.assertEqual(expected, local_shard)
+
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_create_sharded_tensor_with_full(self):
+        """ Test _sharded_tensor.full(...) """
+
+        spec = ChunkShardingSpec(
+            dim=0,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        )
+        h, w = 10, 20
+        fill_value = 1234
+        sharded_tensor = _sharded_tensor.full(spec, size=(h, w), fill_value=fill_value, dtype=torch.int32)
+
+        # Validate local shard is initialized with torch.full
+        local_shards = sharded_tensor.local_shards()
+        self.assertEqual(1, len(local_shards))
+        local_shard = local_shards[0].tensor
+        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+        # Chunking along dim 0: ranks 0-2 each get ceil(h / 4) = 3 rows; rank 3 gets the remaining 1 row.
+        expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
+        self.assertEqual((expected_h, w), local_shard.size())
+        self.assertEqual(local_shard,
+                         torch.full(size=(expected_h, w), fill_value=fill_value, dtype=torch.int32))
+
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
     def test_partial_world_size(self):
 
         spec = ChunkShardingSpec(
diff --git a/torch/distributed/_sharded_tensor/__init__.py b/torch/distributed/_sharded_tensor/__init__.py
index 4cbdded..4f8646d 100644
@@ -1,7 +1,5 @@
-from typing import List
+# coding=utf-8
 
-import torch
-from torch.distributed._sharding_spec import ShardingSpec
 from .api import (
     CreateOp,
     Shard,
@@ -11,6 +9,9 @@ from .api import (
     TensorProperties,
     load_with_process_group,
 )
+from torch.distributed._sharding_spec import ShardingSpec
+from typing import List
+import torch
 
 
 def empty(sharding_spec: ShardingSpec,
@@ -23,7 +24,8 @@ def empty(sharding_spec: ShardingSpec,
           process_group=None,
           init_rrefs=False):
     """
-    Creates an empty :class:`ShardedTensor`. Needs to be called on all ranks in an SPMD fashion.
+    Returns a :class:`ShardedTensor` filled with uninitialized data.
+        Needs to be called on all ranks in an SPMD fashion.
 
     Args:
         sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
@@ -74,7 +76,8 @@ def ones(sharding_spec: ShardingSpec,
          process_group=None,
          init_rrefs=False):
     """
-    Creates a ones :class:`ShardedTensor`. Needs to be called on all ranks in an SPMD fashion.
+    Returns a :class:`ShardedTensor` filled with the scalar value 1.
+        Needs to be called on all ranks in an SPMD fashion.
 
     Args:
         sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
@@ -113,10 +116,172 @@ def ones(sharding_spec: ShardingSpec,
         init_rrefs=init_rrefs,
     )
 
-def init_from_local_shards(local_shards: List[Shard],
-                           sharded_tensor_metadata: ShardedTensorMetadata,
-                           process_group=None,
-                           init_rrefs=False):
+
+def rand(sharding_spec: ShardingSpec,
+         *size,
+         dtype=None,
+         layout=torch.strided,
+         requires_grad=False,
+         pin_memory=False,
+         memory_format=torch.contiguous_format,
+         process_group=None,
+         init_rrefs=False):
+    """
+    Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the
+        interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a sequence of integers defining the shape of the output
+            tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    tensor_properties = TensorProperties(
+        dtype=dtype, layout=layout, requires_grad=requires_grad,
+        pin_memory=pin_memory, memory_format=memory_format
+    )
+    tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, )
+    return ShardedTensor(
+        sharding_spec,
+        *size,
+        tensor_init_params=tensor_init_params,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+
+
+def zeros(sharding_spec: ShardingSpec,
+          *size,
+          dtype=None,
+          layout=torch.strided,
+          requires_grad=False,
+          pin_memory=False,
+          memory_format=torch.contiguous_format,
+          process_group=None,
+          init_rrefs=False):
+    """
+    Returns a :class:`ShardedTensor` filled with the scalar value 0.
+        Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a sequence of integers defining the shape of the output
+            tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    tensor_properties = TensorProperties(
+        dtype=dtype, layout=layout, requires_grad=requires_grad,
+        pin_memory=pin_memory, memory_format=memory_format,
+    )
+    tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, )
+    return ShardedTensor(
+        sharding_spec,
+        *size,
+        tensor_init_params=tensor_init_params,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+
+
+def full(sharding_spec: ShardingSpec,
+         size,
+         fill_value: torch.types.Number,
+         dtype=None,
+         layout=torch.strided,
+         requires_grad=False,
+         pin_memory=False,
+         memory_format=torch.contiguous_format,
+         process_group=None,
+         init_rrefs=False):
+    """
+    Creates a :class:`ShardedTensor` filled with fill_value. The tensor’s dtype
+        is inferred from fill_value. If dtype is specified, it will override the
+        inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...):  a list, tuple, or `torch.Size` of integers defining the shape of the
+            output tensor.
+        fill_value (Scalar): the value to fill the output tensor with.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    tensor_properties = TensorProperties(
+        dtype=dtype, layout=layout, requires_grad=requires_grad,
+        pin_memory=pin_memory, memory_format=memory_format,
+    )
+    tensor_init_params = TensorInitParams(
+        create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties)
+    return ShardedTensor(
+        sharding_spec,
+        *size,
+        tensor_init_params=tensor_init_params,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+
+
+def init_from_local_shards(
+        local_shards: List[Shard],
+        sharded_tensor_metadata: ShardedTensorMetadata,
+        process_group=None,
+        init_rrefs=False):
     """
     Creates an :class:`ShardedTensor` from local shards and the global metadata.
     Needs to be called on all ranks in an SPMD fashion.
diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py
index ae1a3a9..3b7476d 100644
@@ -22,7 +22,7 @@ from torch.distributed._sharding_spec._internals import (
     check_tensor,
     validate_non_overlapping_shards_metadata
 )
-
+from torch.types import Number
 
 # Tracking for sharded tensor objects.
 _sharded_tensor_lock = threading.Lock()
@@ -143,17 +143,28 @@ def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]]
 
 class CreateOp(Enum):
     EMPTY = 0
-    ONES = 1
+    FULL = 1
+    ONES = 2
+    RAND = 3
+    ZEROS = 4
 
 
 @dataclass
 class TensorInitParams(object):
     """ Container for list of common params to create new local tensor. """
 
-    __slots__ = ['create_op', 'tensor_properties']
-
     create_op: CreateOp
-    tensor_properties: TensorProperties
+
+    # needed when create_op is FULL
+    # default set to False (not None) since None is incompatible with Number.
+    fill_value: Number = field(default=False)
+
+    tensor_properties: TensorProperties = field(
+        default=TensorProperties(dtype=torch.get_default_dtype(),
+                                 layout=torch.strided,
+                                 requires_grad=False,
+                                 memory_format=torch.contiguous_format,
+                                 pin_memory=False))
 
 
 class ShardedTensor(object):
@@ -684,5 +695,26 @@ def _create_tensor_from_params(*size, local_device, tensor_init_params: TensorIn
                            device=local_device, requires_grad=requires_grad,
                            # NB: memory_format param is not accepted by torch.ones
                            memory_format=memory_format, pin_memory=pin_memory,)
+    elif tensor_init_params.create_op == CreateOp.ZEROS:
+        return torch.zeros(*size,
+                           dtype=dtype,
+                           layout=layout,
+                           device=local_device,
+                           pin_memory=pin_memory,
+                           requires_grad=requires_grad,)
+    elif tensor_init_params.create_op == CreateOp.RAND:
+        return torch.rand(*size,
+                          dtype=dtype,
+                          layout=layout,
+                          device=local_device,
+                          pin_memory=pin_memory,
+                          requires_grad=requires_grad,)
+    elif tensor_init_params.create_op == CreateOp.FULL:
+        return torch.full(size=size,
+                          fill_value=tensor_init_params.fill_value,
+                          layout=layout,
+                          dtype=dtype,
+                          requires_grad=requires_grad,
+                          device=local_device, )
     else:
         raise ValueError(f'Unsupported create_op: {tensor_init_params.create_op}')