Extend _sharded_tensor constructor to support other ops like torch.ones (#63378)
author     Bo Wang <bowangbj@fb.com>
           Sat, 21 Aug 2021 00:09:35 +0000 (17:09 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
           Sat, 21 Aug 2021 00:11:34 +0000 (17:11 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63378

a) Introduce TensorInitParams to wrap the common tensor-creation params
b) Factor local tensor initialization through TensorInitParams so that the tensor's fill value is no longer hard-coded in the ShardedTensor constructor
c) Add _sharded_tensor.ones(...) as the first example - note that the memory_format arg is intentionally not exposed, for consistency with torch.ones (see the usage sketch below)
d) Follow-up: add more ops such as torch.full, torch.zeros, torch.rand (a hypothetical zeros sketch follows the __init__.py diff below)
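
For reference, a minimal usage sketch of the new op (assumes 4 GPUs and an initialized NCCL process group, mirroring the chunked test added below):

    import torch.distributed._sharded_tensor as _sharded_tensor
    from torch.distributed._sharding_spec import ChunkShardingSpec

    # Shard dim 0 across 4 ranks; each rank materializes only its own chunk.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    st = _sharded_tensor.ones(spec, 10, 20)    # global size (10, 20)
    local = st.local_shards()[0].tensor        # (3, 20) on ranks 0-2, (1, 20) on rank 3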

Test:
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestCreateTensorFromParams --v
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestShardedTensorChunked.test_create_sharded_tensor_with_ones --v
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py TestShardedTensorEnumerable.test_create_sharded_tensor_with_ones --v

Test Plan: Imported from OSS

Reviewed By: pritamdamania87, wanchaol

Differential Revision: D30359245

Pulled By: bowangbj

fbshipit-source-id: 85768fcb36e9d9d40213036884b1266930a91701

test/distributed/_sharded_tensor/test_sharded_tensor.py
torch/distributed/_sharded_tensor/__init__.py
torch/distributed/_sharded_tensor/api.py

test/distributed/_sharded_tensor/test_sharded_tensor.py
index 829855f6be2c507591eb79d6c4eaccafb6211849..5067f301b559537460d3edd993aa30e845b1eecd 100644 (file)
@@ -1,4 +1,5 @@
 from functools import wraps
+import math
 import io
 import sys
 import torch
@@ -15,6 +16,11 @@ from torch.distributed._sharding_spec import (
     EnumerableShardingSpec,
     ShardMetadata
 )
+from torch.distributed._sharded_tensor.api import (
+    CreateOp,
+    TensorInitParams,
+    _create_tensor_from_params,
+)
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     requires_nccl,
@@ -22,10 +28,11 @@ from torch.testing._internal.common_distributed import (
     TEST_SKIPS,
 )
 from torch.testing._internal.common_utils import (
+    TestCase,
     TEST_WITH_DEV_DBG_ASAN,
     run_tests,
+    sandcastle_skip_if,
 )
-
 if TEST_WITH_DEV_DBG_ASAN:
     print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
     sys.exit(0)
@@ -115,6 +122,38 @@ def with_comms(func):
         self.destroy_comms()
     return wrapper
 
+class TestCreateTensorFromParams(TestCase):
+    @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+    def test_empty(self):
+        tensor_init_params = TensorInitParams(
+            create_op=CreateOp.EMPTY,
+            dtype=torch.double,
+            layout=torch.strided,
+            requires_grad=False,
+            pin_memory=False,
+            memory_format=torch.contiguous_format)
+        local_device = torch.device('cuda:0')
+        local_tensor = _create_tensor_from_params(
+            5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
+        self.assertEqual(local_device, local_tensor.device)
+        self.assertEqual(torch.double, local_tensor.dtype)
+        self.assertEqual(torch.strided, local_tensor.layout)
+        self.assertEqual(False, local_tensor.requires_grad)
+
+    @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+    def test_ones(self):
+        tensor_init_params = TensorInitParams(
+            create_op=CreateOp.ONES,
+            dtype=torch.double,
+            layout=torch.strided,
+            requires_grad=False,
+            pin_memory=False,
+            memory_format=torch.contiguous_format)
+        local_device = torch.device('cuda:0')
+        local_tensor = _create_tensor_from_params(
+            5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
+        expected_tensor = torch.ones(5, 10, device=local_device, dtype=torch.double)
+        self.assertEqual(expected_tensor, local_tensor)
 
 class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
 
@@ -219,6 +258,35 @@ class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
                     else:
                         self.assertEqual((3, 20), shard.tensor.size())
 
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_create_sharded_tensor_with_ones(self):
+        """ Test _sharded_tensor.ones(...) """
+
+        spec = ChunkShardingSpec(
+            dim=0,
+            placements=[
+                "rank:0/cuda:0",
+                "rank:1/cuda:1",
+                "rank:2/cuda:2",
+                "rank:3/cuda:3",
+            ],
+        )
+        h, w = 10, 20
+        sharded_tensor = _sharded_tensor.ones(spec, h, w)
+
+        # Validate local shard is initialized with torch.ones
+        local_shards = sharded_tensor.local_shards()
+        self.assertEqual(1, len(local_shards))
+        local_shard = local_shards[0].tensor
+        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+        # Chunk split along dim 0: ranks 0-2 get ceil(h / 4) = 3 rows each; rank 3 gets the last 1 row.
+        expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
+        self.assertEqual((expected_h, w), local_shard.size())
+        self.assertEqual(local_shard, torch.ones(expected_h, w))
+
     @with_comms
     @skip_if_lt_x_gpu(4)
     @requires_nccl()
@@ -818,6 +886,45 @@ class TestShardedTensorEnumerable(ShardedTensorTestBase, MultiProcessTestCase):
                 shard = remote_shard.to_here()
                 self.assertEqual((5, 5), shard.tensor.size())
 
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_create_sharded_tensor_with_ones(self):
+        """ Test _sharded_tensor.ones(...) """
+
+        spec = EnumerableShardingSpec([
+            ShardMetadata(
+                shard_offsets=[0, 0],
+                shard_lengths=[5, 5],
+                placement="rank:0/cuda:0",
+            ),
+            ShardMetadata(
+                shard_offsets=[0, 5],
+                shard_lengths=[5, 5],
+                placement="rank:1/cuda:1",
+            ),
+            ShardMetadata(
+                shard_offsets=[5, 0],
+                shard_lengths=[5, 5],
+                placement="rank:2/cuda:2",
+            ),
+            ShardMetadata(
+                shard_offsets=[5, 5],
+                shard_lengths=[5, 5],
+                placement="rank:3/cuda:3",
+            )
+        ])
+
+        sharded_tensor = _sharded_tensor.ones(spec, 10, 10, init_rrefs=True)
+        self.assertEqual((10, 10), sharded_tensor.size())
+        self.assertEqual(1, len(sharded_tensor.local_shards()))
+
+        # Verify local shard is initialized with torch.ones
+        local_shard = sharded_tensor.local_shards()[0]
+        self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
+        self.assertEqual((5, 5), local_shard.tensor.size())
+        self.assertEqual(local_shard.tensor, torch.ones(5, 5))
+
     @skip_if_lt_x_gpu(4)
     @requires_nccl()
     def test_uneven_shards(self):
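
Aside: the shard sizes asserted in test_create_sharded_tensor_with_ones (chunked) above come from ChunkShardingSpec splitting dim 0 of a (10, 20) tensor across 4 ranks; a sketch of that arithmetic (not the library's exact helper):

    import math

    h, world_size = 10, 4
    full_rows = math.ceil(h / world_size)            # ranks 0-2 get 3 rows each
    last_rows = h - full_rows * (world_size - 1)     # rank 3 gets the remaining 1 row
    assert (full_rows, last_rows) == (3, 1)
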
torch/distributed/_sharded_tensor/__init__.py
index d9833159dc9de568ecc1f96795a9cd822080408d..ecb7ea1fed8c61141db03e70e31a58841464abb3 100644 (file)
@@ -3,12 +3,15 @@ from typing import List
 import torch
 from torch.distributed._sharding_spec import ShardingSpec
 from .api import (
+    CreateOp,
     Shard,
     ShardedTensor,
     ShardedTensorMetadata,
+    TensorInitParams,
     load_with_process_group,
 )
 
+
 def empty(
         sharding_spec: ShardingSpec,
         *size,
@@ -49,14 +52,62 @@ def empty(
     Returns:
         A :class:`ShardedTensor` object on each rank
     """
+    tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, dtype=dtype, layout=layout,
+                                          requires_grad=requires_grad,
+                                          pin_memory=pin_memory, memory_format=memory_format)
+    return ShardedTensor(
+        sharding_spec,
+        *size,
+        tensor_init_params=tensor_init_params,
+        process_group=process_group,
+        init_rrefs=init_rrefs,
+    )
+
+def ones(
+        sharding_spec: ShardingSpec,
+        *size,
+        dtype=None,
+        layout=torch.strided,
+        requires_grad=False,
+        pin_memory=False,
+        memory_format=torch.contiguous_format,
+        process_group=None,
+        init_rrefs=False):
+    """
+    Creates a ones :class:`ShardedTensor`. Needs to be called on all ranks in an SPMD fashion.
+
+    Args:
+        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+            describing how to shard the Tensor.
+        size (int...): a sequence of integers defining the shape of the output
+            tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+    Keyword args:
+        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
+        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+            Default: ``torch.strided``.
+        requires_grad (bool, optional): If autograd should record operations on the
+            returned tensor. Default: ``False``.
+        pin_memory (bool, optional): If set, returned tensor would be allocated in
+            the pinned memory. Works only for CPU tensors. Default: ``False``.
+        process_group (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used.
+        init_rrefs (bool, optional): Whether or not to initialize
+            :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+            Need to initialize the RPC Framework if specified as ``True``.
+            Default: ``False``.
+
+    Returns:
+        A :class:`ShardedTensor` object on each rank
+    """
+    tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, dtype=dtype, layout=layout,
+                                          requires_grad=requires_grad,
+                                          pin_memory=pin_memory, memory_format=memory_format)
     return ShardedTensor(
         sharding_spec,
         *size,
-        dtype=dtype,
-        layout=layout,
-        requires_grad=requires_grad,
-        pin_memory=pin_memory,
-        memory_format=memory_format,
+        tensor_init_params=tensor_init_params,
         process_group=process_group,
         init_rrefs=init_rrefs,
     )
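
To illustrate follow-up item d), a zeros(...) creation op could follow the same pattern as ones(...) above. Note this is only a sketch: CreateOp.ZEROS is a hypothetical enum member that is not part of this diff and would also need a torch.zeros branch in _create_tensor_from_params.

    def zeros(
            sharding_spec: ShardingSpec,
            *size,
            dtype=None,
            layout=torch.strided,
            requires_grad=False,
            pin_memory=False,
            memory_format=torch.contiguous_format,
            process_group=None,
            init_rrefs=False):
        # CreateOp.ZEROS is hypothetical; it is not defined in this change.
        tensor_init_params = TensorInitParams(
            create_op=CreateOp.ZEROS, dtype=dtype, layout=layout,
            requires_grad=requires_grad, pin_memory=pin_memory,
            memory_format=memory_format)
        return ShardedTensor(
            sharding_spec,
            *size,
            tensor_init_params=tensor_init_params,
            process_group=process_group,
            init_rrefs=init_rrefs,
        )
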
torch/distributed/_sharded_tensor/api.py
index ca9a05abffa0610821695d57e764116a06174468..2b6720b059a851eac5ac06323a4ed17a87ffe184 100644 (file)
@@ -1,6 +1,7 @@
 import collections
 from contextlib import contextmanager
 from dataclasses import dataclass, field
+from enum import Enum
 from typing import (
     Dict,
     List
@@ -22,6 +23,7 @@ from torch.distributed._sharding_spec._internals import (
     validate_non_overlapping_shards_metadata
 )
 
+
 # Tracking for sharded tensor objects.
 _sharded_tensor_lock = threading.Lock()
 _sharded_tensor_current_id = 0
@@ -123,6 +125,26 @@ def _register_remote_shards(sharded_tensor_id: int, rrefs: List[rpc.RRef[Shard]]
         _sharded_tensor_map[sharded_tensor_id]._register_remote_shards(rrefs, rpc_rank)
 
 
+class CreateOp(Enum):
+    EMPTY = 0
+    ONES = 1
+
+
+@dataclass
+class TensorInitParams(object):
+    """ Container for list of common params to create new local tensor. """
+
+    __slots__ = ['create_op', 'dtype', 'layout', 'requires_grad', 'pin_memory',
+                 'memory_format']
+
+    create_op: CreateOp
+    dtype: torch.dtype
+    layout: torch.layout
+    requires_grad: bool
+    pin_memory: bool
+    memory_format: torch.memory_format
+
+
 class ShardedTensor(object):
     """
     ShardedTensor is an abstraction to represent Tensors that are sharded
@@ -136,8 +158,9 @@ class ShardedTensor(object):
     ShardedTensor doesn't provide any Tensor like operations but is a wrapper
     providing the Tensor representing the local shard and the global metadata.
     Using these, users can build their custom distributed sharded computations
-    on top of this primitive. The local shards are all initialized using
-    :meth:`torch.empty`.
+    on top of this primitive. The local shards are all initialized using the
+    create_op specified by tensor_init_params.create_op, e.g. torch.ones or
+    torch.empty.
 
     Args:
         sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
@@ -146,20 +169,7 @@ class ShardedTensor(object):
             tensor. Can be a variable number of arguments or a collection like a list or tuple.
 
     Keyword args:
-        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
-            Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
-        layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
-            Default: ``torch.strided``.
-        requires_grad (bool, optional): If autograd should record operations on the
-            returned tensor. Default: ``False``.
-        pin_memory (bool, optional): If set, returned tensor would be allocated in
-            the pinned memory. Works only for CPU tensors. Default: ``False``.
-        memory_format (:class:`torch.memory_format`, optional): the desired memory format of
-            returned Tensor. Default: ``torch.contiguous_format``.
-        process_group (ProcessGroup, optional): The process group to work on. If None,
-            the default process group will be used. If specified the ShardedTensor is only
-            built on ranks that are part of this process group and the provided ``sharding_spec``
-            is applied in the context of this process group.
+        tensor_init_params (:class:`TensorInitParams`): the common params used to create the tensor.
         init_rrefs (bool, optional): Whether or not to initialize
             :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
             Need to initialize the RPC Framework if specified as ``True``.
@@ -170,11 +180,7 @@ class ShardedTensor(object):
         self,
         sharding_spec: ShardingSpec,
         *size,
-        dtype=None,
-        layout=torch.strided,
-        requires_grad=False,
-        pin_memory=False,
-        memory_format=torch.contiguous_format,
+        tensor_init_params: TensorInitParams,
         process_group=None,
         init_rrefs=False,
     ):
@@ -182,13 +188,13 @@ class ShardedTensor(object):
         # _process_group, _local_shards, etc.
         self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
 
-        if dtype is None:
-            dtype = torch.get_default_dtype()
+        if tensor_init_params.dtype is None:
+            tensor_init_params.dtype = torch.get_default_dtype()
 
-        if layout != torch.strided:
+        if tensor_init_params.layout != torch.strided:
             raise ValueError('Only torch.strided layout is currently supported')
 
-        if memory_format != torch.contiguous_format:
+        if tensor_init_params.memory_format != torch.contiguous_format:
             raise ValueError('Only torch.contiguous_format memory_format is currently supported')
 
         if len(size) == 1 and isinstance(size[0], collections.Sequence):
@@ -203,23 +209,9 @@ class ShardedTensor(object):
         self._sharding_spec = sharding_spec
 
         if isinstance(self._sharding_spec, ChunkShardingSpec):
-            self._init_chunked(
-                dims,
-                dtype,
-                layout,
-                requires_grad,
-                pin_memory,
-                memory_format,
-            )
+            self._init_chunked(dims, tensor_init_params)
         elif isinstance(self._sharding_spec, EnumerableShardingSpec):
-            self._init_enumerable(
-                dims,
-                dtype,
-                layout,
-                requires_grad,
-                pin_memory,
-                memory_format,
-            )
+            self._init_enumerable(dims, tensor_init_params)
         else:
             raise ValueError(f'Unsupported sharding_spec: {self._sharding_spec}')
 
@@ -420,15 +412,7 @@ class ShardedTensor(object):
         sharded_tensor._post_init()
         return sharded_tensor
 
-    def _init_chunked(
-        self,
-        dims,
-        dtype,
-        layout,
-        requires_grad,
-        pin_memory,
-        memory_format,
-    ):
+    def _init_chunked(self, dims, tensor_init_params: TensorInitParams):
         current_rank = dist.get_rank(self._process_group)
         sharding_dim = self._sharding_spec.dim  # type: ignore[attr-defined]
 
@@ -469,38 +453,22 @@ class ShardedTensor(object):
                 # Build the local shard for the current rank if it is involved in the sharding spec.
                 if current_rank == rank:
                     # Initialize the local shard.
-                    local_shard = torch.empty(
-                        *rank_dims,
-                        dtype=dtype,
-                        layout=layout,
-                        device=local_device,
-                        requires_grad=requires_grad,
-                        memory_format=memory_format,
-                        pin_memory=pin_memory,
-                    )
-
+                    local_shard = _create_tensor_from_params(
+                        *rank_dims, local_device=local_device, tensor_init_params=tensor_init_params)
                     self._local_shards.append(Shard(local_shard, shard_metadata))
 
         # Build overall metadata
         self._metadata = ShardedTensorMetadata(
             shards_metadata,
             dims,
-            dtype,
-            layout,
-            requires_grad,
-            memory_format,
-            pin_memory,
+            tensor_init_params.dtype,
+            tensor_init_params.layout,
+            tensor_init_params.requires_grad,
+            tensor_init_params.memory_format,
+            tensor_init_params.pin_memory,
         )
 
-    def _init_enumerable(
-        self,
-        dims,
-        dtype,
-        layout,
-        requires_grad,
-        pin_memory,
-        memory_format,
-    ):
+    def _init_enumerable(self, dims, tensor_init_params: TensorInitParams):
         # Validate the sharding spec is compatible with the tensor.
         check_tensor(self._sharding_spec.shards, dims)  # type: ignore[attr-defined]
 
@@ -513,27 +481,20 @@ class ShardedTensor(object):
 
             if current_rank == rank:
                 # Initialize the local shard.
-                local_shard = torch.empty(
-                    *shard_metadata.shard_lengths,
-                    dtype=dtype,
-                    layout=layout,
-                    device=local_device,
-                    requires_grad=requires_grad,
-                    memory_format=memory_format,
-                    pin_memory=pin_memory,
-                )
-
+                local_shard = _create_tensor_from_params(
+                    *shard_metadata.shard_lengths, local_device=local_device,
+                    tensor_init_params=tensor_init_params)
                 self._local_shards.append(Shard(local_shard, shard_metadata))
 
         # Build overall metadata
         self._metadata = ShardedTensorMetadata(
             shards_metadata,
             dims,
-            dtype,
-            layout,
-            requires_grad,
-            memory_format,
-            pin_memory,
+            tensor_init_params.dtype,
+            tensor_init_params.layout,
+            tensor_init_params.requires_grad,
+            tensor_init_params.memory_format,
+            tensor_init_params.pin_memory,
         )
 
     def _parse_and_validate_remote_device(self, remote_device: torch.distributed._remote_device):
@@ -672,3 +633,26 @@ class ShardedTensor(object):
                 f'but at load time was {global_world_size}')
 
         self._post_init()
+
+
+def _create_tensor_from_params(*size, local_device, tensor_init_params: TensorInitParams):
+    """ Helper to construct tensor from size, device and common params. """
+
+    if tensor_init_params.create_op == CreateOp.ONES:
+        return torch.ones(*size,
+                          dtype=tensor_init_params.dtype,
+                          layout=tensor_init_params.layout,
+                          device=local_device,
+                          pin_memory=tensor_init_params.pin_memory,
+                          requires_grad=tensor_init_params.requires_grad,)
+    elif tensor_init_params.create_op == CreateOp.EMPTY:
+        return torch.empty(*size,
+                           dtype=tensor_init_params.dtype,
+                           layout=tensor_init_params.layout,
+                           device=local_device,
+                           requires_grad=tensor_init_params.requires_grad,
+                           # Note: memory_format is only passed here; torch.ones does not accept it
+                           memory_format=tensor_init_params.memory_format,
+                           pin_memory=tensor_init_params.pin_memory,)
+    else:
+        raise ValueError(f'Unsupported create_op: {tensor_init_params.create_op}')
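
For completeness, the new private helper can be exercised directly; a minimal CPU-only sketch (mirroring TestCreateTensorFromParams, but without requiring a GPU):

    import torch
    from torch.distributed._sharded_tensor.api import (
        CreateOp,
        TensorInitParams,
        _create_tensor_from_params,
    )

    params = TensorInitParams(
        create_op=CreateOp.ONES,
        dtype=torch.float32,
        layout=torch.strided,
        requires_grad=False,
        pin_memory=False,
        memory_format=torch.contiguous_format)
    # Dispatches to torch.ones via CreateOp.ONES and builds a 2x3 tensor on CPU.
    t = _create_tensor_from_params(
        2, 3, local_device=torch.device('cpu'), tensor_init_params=params)
    assert torch.equal(t, torch.ones(2, 3))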