from functools import wraps
+import math
import io
import sys
import torch
EnumerableShardingSpec,
ShardMetadata
)
+from torch.distributed._sharded_tensor.api import (
+ CreateOp,
+ TensorInitParams,
+ _create_tensor_from_params,
+)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
+ TestCase,
TEST_WITH_DEV_DBG_ASAN,
run_tests,
+ sandcastle_skip_if,
)
-
if TEST_WITH_DEV_DBG_ASAN:
print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
sys.exit(0)
self.destroy_comms()
return wrapper
+class TestCreateTensorFromParams(TestCase):
+ @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+ def test_empty(self):
+ tensor_init_params = TensorInitParams(
+ create_op=CreateOp.EMPTY,
+ dtype=torch.double,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format)
+ local_device = torch.device('cuda:0')
+ local_tensor = _create_tensor_from_params(
+ 5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
+ self.assertEqual(local_device, local_tensor.device)
+ self.assertEqual(torch.double, local_tensor.dtype)
+ self.assertEqual(torch.strided, local_tensor.layout)
+ self.assertEqual(False, local_tensor.requires_grad)
+
+ @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+ def test_ones(self):
+ tensor_init_params = TensorInitParams(
+ create_op=CreateOp.ONES,
+ dtype=torch.double,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format)
+ local_device = torch.device('cuda:0')
+ local_tensor = _create_tensor_from_params(
+ 5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
+ expected_tensor = torch.ones(5, 10, device=local_device, dtype=torch.double)
+ self.assertEqual(expected_tensor, local_tensor)
class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
else:
self.assertEqual((3, 20), shard.tensor.size())
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
+ def test_create_sharded_tensor_with_ones(self):
+ """ Test _sharded_tensor.ones(...) """
+
+ spec = ChunkShardingSpec(
+ dim=0,
+ placements=[
+ "rank:0/cuda:0",
+ "rank:1/cuda:1",
+ "rank:2/cuda:2",
+ "rank:3/cuda:3",
+ ],
+ )
+ h, w = 10, 20
+ sharded_tensor = _sharded_tensor.ones(spec, h, w)
+
+ # Validate local shard is initialized with torch.ones
+ local_shards = sharded_tensor.local_shards()
+ self.assertEqual(1, len(local_shards))
+ local_shard = local_shards[0].tensor
+ self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+ # Chunking h=10 rows across 4 ranks: ranks 0-2 each get ceil(h / 4) = 3 rows, rank 3 gets the remaining 1 row.
+ expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
+ self.assertEqual((expected_h, w), local_shard.size())
+ self.assertEqual(local_shard, torch.ones(expected_h, w))
+
@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
shard = remote_shard.to_here()
self.assertEqual((5, 5), shard.tensor.size())
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
+ def test_create_sharded_tensor_with_ones(self):
+ """ Test _sharded_tensor.ones(...) """
+
+ spec = EnumerableShardingSpec([
+ ShardMetadata(
+ shard_offsets=[0, 0],
+ shard_lengths=[5, 5],
+ placement="rank:0/cuda:0",
+ ),
+ ShardMetadata(
+ shard_offsets=[0, 5],
+ shard_lengths=[5, 5],
+ placement="rank:1/cuda:1",
+ ),
+ ShardMetadata(
+ shard_offsets=[5, 0],
+ shard_lengths=[5, 5],
+ placement="rank:2/cuda:2",
+ ),
+ ShardMetadata(
+ shard_offsets=[5, 5],
+ shard_lengths=[5, 5],
+ placement="rank:3/cuda:3",
+ )
+ ])
+
+ sharded_tensor = _sharded_tensor.ones(spec, 10, 10, init_rrefs=True)
+ self.assertEqual((10, 10), sharded_tensor.size())
+ self.assertEqual(1, len(sharded_tensor.local_shards()))
+
+ # Verify local shard is initialized with torch.ones
+ local_shard = sharded_tensor.local_shards()[0]
+ self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
+ self.assertEqual((5, 5), local_shard.tensor.size())
+ self.assertEqual(local_shard.tensor, torch.ones(5, 5))
+
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_uneven_shards(self):
import torch
from torch.distributed._sharding_spec import ShardingSpec
from .api import (
+ CreateOp,
Shard,
ShardedTensor,
ShardedTensorMetadata,
+ TensorInitParams,
load_with_process_group,
)
+
def empty(
sharding_spec: ShardingSpec,
*size,
Returns:
A :class:`ShardedTensor` object on each rank
"""
+ tensor_init_params = TensorInitParams(create_op=CreateOp.EMPTY, dtype=dtype, layout=layout,
+ requires_grad=requires_grad,
+ pin_memory=pin_memory, memory_format=memory_format)
+ return ShardedTensor(
+ sharding_spec,
+ *size,
+ tensor_init_params=tensor_init_params,
+ process_group=process_group,
+ init_rrefs=init_rrefs,
+ )
+
+def ones(
+ sharding_spec: ShardingSpec,
+ *size,
+ dtype=None,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ process_group=None,
+ init_rrefs=False):
+ """
+ Creates a :class:`ShardedTensor` filled with the scalar value 1. Needs to be called on all ranks in an SPMD fashion.
+
+ Args:
+ sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+ describing how to shard the Tensor.
+ size (int...): a sequence of integers defining the shape of the output
+ tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+ Keyword args:
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+ Default: ``torch.strided``.
+ requires_grad (bool, optional): If autograd should record operations on the
+ returned tensor. Default: ``False``.
+ pin_memory (bool, optional): If set, returned tensor would be allocated in
+ the pinned memory. Works only for CPU tensors. Default: ``False``.
+ memory_format (:class:`torch.memory_format`, optional): the desired memory format of
+ returned Tensor. Default: ``torch.contiguous_format``.
+ process_group (ProcessGroup, optional): The process group to work on. If None,
+ the default process group will be used.
+ init_rrefs (bool, optional): Whether or not to initialize
+ :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+ Need to initialize the RPC Framework if specified as ``True``.
+ Default: ``False``.
+
+ Returns:
+ A :class:`ShardedTensor` object on each rank
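+
+ Example (a sketch that mirrors the chunked-sharding test added in this change;
+ it assumes a host with 4 GPUs and an initialized default process group)::
+
+ import torch.distributed._sharded_tensor as _sharded_tensor
+ from torch.distributed._sharding_spec import ChunkShardingSpec
+
+ spec = ChunkShardingSpec(
+ dim=0,
+ placements=[
+ "rank:0/cuda:0",
+ "rank:1/cuda:1",
+ "rank:2/cuda:2",
+ "rank:3/cuda:3",
+ ],
+ )
+ # Every rank ends up with one local shard of the (10, 20) tensor, filled with ones.
+ st = _sharded_tensor.ones(spec, 10, 20)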
+ """
+ tensor_init_params = TensorInitParams(create_op=CreateOp.ONES, dtype=dtype, layout=layout,
+ requires_grad=requires_grad,
+ pin_memory=pin_memory, memory_format=memory_format)
return ShardedTensor(
sharding_spec,
*size,
- dtype=dtype,
- layout=layout,
- requires_grad=requires_grad,
- pin_memory=pin_memory,
- memory_format=memory_format,
+ tensor_init_params=tensor_init_params,
process_group=process_group,
init_rrefs=init_rrefs,
)
import collections
from contextlib import contextmanager
from dataclasses import dataclass, field
+from enum import Enum
from typing import (
Dict,
List
validate_non_overlapping_shards_metadata
)
+
# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
_sharded_tensor_current_id = 0
_sharded_tensor_map[sharded_tensor_id]._register_remote_shards(rrefs, rpc_rank)
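+
+# CreateOp selects which torch factory is used to initialize each local shard:
+# CreateOp.EMPTY maps to torch.empty and CreateOp.ONES maps to torch.ones
+# (see _create_tensor_from_params below).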
+class CreateOp(Enum):
+ EMPTY = 0
+ ONES = 1
+
+
+@dataclass
+class TensorInitParams(object):
+ """ Container for list of common params to create new local tensor. """
+
+ __slots__ = ['create_op', 'dtype', 'layout', 'requires_grad', 'pin_memory',
+ 'memory_format']
+
+ create_op: CreateOp
+ dtype: torch.dtype
+ layout: torch.layout
+ requires_grad: bool
+ pin_memory: bool
+ memory_format: torch.memory_format
+
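+# Example construction (a sketch mirroring the arguments exercised by the tests):
+#   TensorInitParams(create_op=CreateOp.ONES, dtype=torch.double, layout=torch.strided,
+#                    requires_grad=False, pin_memory=False,
+#                    memory_format=torch.contiguous_format)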
+
class ShardedTensor(object):
"""
ShardedTensor is an abstraction to represent Tensors that are sharded
ShardedTensor doesn't provide any Tensor like operations but is a wrapper
providing the Tensor representing the local shard and the global metadata.
Using these, users can build their custom distributed sharded computations
- on top of this primitive. The local shards are all initialized using
- :meth:`torch.empty`.
+ on top of this primitive. The local shards are all initialized using the
+ create_op specified by tensor_init_params.create_op, e.g., torch.ones or
+ torch.empty.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
tensor. Can be a variable number of arguments or a collection like a list or tuple.
Keyword args:
- dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
- Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
- layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
- Default: ``torch.strided``.
- requires_grad (bool, optional): If autograd should record operations on the
- returned tensor. Default: ``False``.
- pin_memory (bool, optional): If set, returned tensor would be allocated in
- the pinned memory. Works only for CPU tensors. Default: ``False``.
- memory_format (:class:`torch.memory_format`, optional): the desired memory format of
- returned Tensor. Default: ``torch.contiguous_format``.
- process_group (ProcessGroup, optional): The process group to work on. If None,
- the default process group will be used. If specified the ShardedTensor is only
- built on ranks that are part of this process group and the provided ``sharding_spec``
- is applied in the context of this process group.
+ tensor_init_params (:class:`TensorInitParams`): the common parameters used to create the local tensors.
init_rrefs (bool, optional): Whether or not to initialize
:class:`torch.distributed.rpc.RRef`s pointing to remote shards.
Need to initialize the RPC Framework if specified as ``True``.
self,
sharding_spec: ShardingSpec,
*size,
- dtype=None,
- layout=torch.strided,
- requires_grad=False,
- pin_memory=False,
- memory_format=torch.contiguous_format,
+ tensor_init_params: TensorInitParams,
process_group=None,
init_rrefs=False,
):
# _process_group, _local_shards, etc.
self._prepare_init(process_group=process_group, init_rrefs=init_rrefs)
- if dtype is None:
- dtype = torch.get_default_dtype()
+ if tensor_init_params.dtype is None:
+ tensor_init_params.dtype = torch.get_default_dtype()
- if layout != torch.strided:
+ if tensor_init_params.layout != torch.strided:
raise ValueError('Only torch.strided layout is currently supported')
- if memory_format != torch.contiguous_format:
+ if tensor_init_params.memory_format != torch.contiguous_format:
raise ValueError('Only torch.contiguous_format memory_format is currently supported')
if len(size) == 1 and isinstance(size[0], collections.Sequence):
self._sharding_spec = sharding_spec
if isinstance(self._sharding_spec, ChunkShardingSpec):
- self._init_chunked(
- dims,
- dtype,
- layout,
- requires_grad,
- pin_memory,
- memory_format,
- )
+ self._init_chunked(dims, tensor_init_params)
elif isinstance(self._sharding_spec, EnumerableShardingSpec):
- self._init_enumerable(
- dims,
- dtype,
- layout,
- requires_grad,
- pin_memory,
- memory_format,
- )
+ self._init_enumerable(dims, tensor_init_params)
else:
raise ValueError(f'Unsupported sharding_spec: {self._sharding_spec}')
sharded_tensor._post_init()
return sharded_tensor
- def _init_chunked(
- self,
- dims,
- dtype,
- layout,
- requires_grad,
- pin_memory,
- memory_format,
- ):
+ def _init_chunked(self, dims, tensor_init_params: TensorInitParams):
current_rank = dist.get_rank(self._process_group)
sharding_dim = self._sharding_spec.dim # type: ignore[attr-defined]
# Build the local shard for the current rank if it is involved in the sharding spec.
if current_rank == rank:
# Initialize the local shard.
- local_shard = torch.empty(
- *rank_dims,
- dtype=dtype,
- layout=layout,
- device=local_device,
- requires_grad=requires_grad,
- memory_format=memory_format,
- pin_memory=pin_memory,
- )
-
+ local_shard = _create_tensor_from_params(
+ *rank_dims, local_device=local_device, tensor_init_params=tensor_init_params)
self._local_shards.append(Shard(local_shard, shard_metadata))
# Build overall metadata
self._metadata = ShardedTensorMetadata(
shards_metadata,
dims,
- dtype,
- layout,
- requires_grad,
- memory_format,
- pin_memory,
+ tensor_init_params.dtype,
+ tensor_init_params.layout,
+ tensor_init_params.requires_grad,
+ tensor_init_params.memory_format,
+ tensor_init_params.pin_memory,
)
- def _init_enumerable(
- self,
- dims,
- dtype,
- layout,
- requires_grad,
- pin_memory,
- memory_format,
- ):
+ def _init_enumerable(self, dims, tensor_init_params: TensorInitParams):
# Validate the sharding spec is compatible with the tensor.
check_tensor(self._sharding_spec.shards, dims) # type: ignore[attr-defined]
if current_rank == rank:
# Initialize the local shard.
- local_shard = torch.empty(
- *shard_metadata.shard_lengths,
- dtype=dtype,
- layout=layout,
- device=local_device,
- requires_grad=requires_grad,
- memory_format=memory_format,
- pin_memory=pin_memory,
- )
-
+ local_shard = _create_tensor_from_params(
+ *shard_metadata.shard_lengths, local_device=local_device,
+ tensor_init_params=tensor_init_params)
self._local_shards.append(Shard(local_shard, shard_metadata))
# Build overall metadata
self._metadata = ShardedTensorMetadata(
shards_metadata,
dims,
- dtype,
- layout,
- requires_grad,
- memory_format,
- pin_memory,
+ tensor_init_params.dtype,
+ tensor_init_params.layout,
+ tensor_init_params.requires_grad,
+ tensor_init_params.memory_format,
+ tensor_init_params.pin_memory,
)
def _parse_and_validate_remote_device(self, remote_device: torch.distributed._remote_device):
f'but at load time was {global_world_size}')
self._post_init()
+
+
+def _create_tensor_from_params(*size, local_device, tensor_init_params: TensorInitParams):
+ """ Helper to construct tensor from size, device and common params. """
+
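+ # Invoked by ShardedTensor._init_chunked / _init_enumerable with the per-shard dims,
+ # e.g. _create_tensor_from_params(5, 10, local_device=device, tensor_init_params=params).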
+ if tensor_init_params.create_op == CreateOp.ONES:
+ return torch.ones(*size,
+ dtype=tensor_init_params.dtype,
+ layout=tensor_init_params.layout,
+ device=local_device,
+ pin_memory=tensor_init_params.pin_memory,
+ requires_grad=tensor_init_params.requires_grad)
+ elif tensor_init_params.create_op == CreateOp.EMPTY:
+ return torch.empty(*size,
+ dtype=tensor_init_params.dtype,
+ layout=tensor_init_params.layout,
+ device=local_device,
+ requires_grad=tensor_init_params.requires_grad,
+ # memory_format is passed only for EMPTY; torch.ones does not accept a memory_format argument
+ memory_format=tensor_init_params.memory_format,
+ pin_memory=tensor_init_params.pin_memory)
+ else:
+ raise ValueError(f'Unsupported create_op: {tensor_init_params.create_op}')