class TestCreateTensorFromParams(TestCase):
@sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
def test_empty(self):
+ expected_dtype = torch.double
tensor_properties = TensorProperties(
- dtype=torch.double,
+ dtype=expected_dtype,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
local_tensor = _create_tensor_from_params(
5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
self.assertEqual(local_device, local_tensor.device)
- self.assertEqual(torch.double, local_tensor.dtype)
+ self.assertEqual(expected_dtype, local_tensor.dtype)
self.assertEqual(torch.strided, local_tensor.layout)
self.assertEqual(False, local_tensor.requires_grad)
@sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
def test_ones(self):
+ expected_dtype = torch.double
tensor_properties = TensorProperties(
- dtype=torch.double,
+ dtype=expected_dtype,
layout=torch.strided,
requires_grad=False,
pin_memory=False,
tensor_init_params = TensorInitParams(
create_op=CreateOp.ONES, tensor_properties=tensor_properties)
local_device = torch.device('cuda:0')
+ h, w = 5, 10
local_tensor = _create_tensor_from_params(
- 5, 10, local_device=local_device, tensor_init_params=tensor_init_params)
- expected_tensor = torch.ones(5, 10, device=local_device, dtype=torch.double)
+ h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+ expected_tensor = torch.ones(h, w, device=local_device, dtype=expected_dtype)
+ self.assertEqual(expected_tensor, local_tensor)
+
+ @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+ def test_zeros(self):
+ """CreateOp.ZEROS should yield a tensor equal to torch.zeros of the same shape, device and dtype."""
+ expected_dtype = torch.int32
+ tensor_properties = TensorProperties(
+ dtype=expected_dtype,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ )
+ tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, )
+ local_device = torch.device('cuda:0')
+ h, w = 5, 10
+ local_tensor = _create_tensor_from_params(
+ h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+ # reference tensor built directly with torch.zeros using the same shape, device and dtype
+ expected_tensor = torch.zeros(h, w, device=local_device, dtype=expected_dtype)
+ self.assertEqual(expected_tensor, local_tensor)
+
+ @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+ def test_rand(self):
+ """CreateOp.RAND should yield the same values as torch.rand when both start from the same CUDA seed."""
+ expected_dtype = torch.double
+ tensor_properties = TensorProperties(
+ dtype=expected_dtype,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ )
+ tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, )
+ local_device = torch.device('cuda:0')
+ h, w = 5, 10
+ seed = 13
+ # seed the CUDA generator before the draw under test so it can be replayed below
+ torch.cuda.manual_seed(seed)
+ local_tensor = _create_tensor_from_params(
+ h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+ # reset seed to ensure same random numbers are generated
+ torch.cuda.manual_seed(seed)
+ expected_tensor = torch.rand(h, w, device=local_device, dtype=expected_dtype)
+ self.assertEqual(expected_tensor, local_tensor)
+
+ @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+ def test_full_with_dtype_inferred(self):
+ """CreateOp.FULL with dtype=None should yield a float32 tensor for a Python float fill_value."""
+ fill_value = 23.5
+ tensor_properties = TensorProperties(
+ # tensor's dtype can be inferred from fill_value
+ dtype=None,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ )
+ tensor_init_params = TensorInitParams(
+ create_op=CreateOp.FULL,
+ fill_value=fill_value,
+ tensor_properties=tensor_properties, )
+ local_device = torch.device('cuda:0')
+ h, w = 5, 10
+ local_tensor = _create_tensor_from_params(
+ h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+ # local_tensor.dtype is inferred from fill_value (float32).
+ self.assertEqual(torch.float32, local_tensor.dtype)
+ # the reference tensor also omits dtype, so both sides rely on the same inference
+ expected_tensor = torch.full((h, w), fill_value=fill_value, device=local_device)
+ self.assertEqual(expected_tensor, local_tensor)
+
+ @sandcastle_skip_if(torch.cuda.device_count() < 1, 'CUDA GPU is needed')
+ def test_full_with_dtype_overridden(self):
+ """CreateOp.FULL with an explicit dtype should use that dtype instead of the one inferred from fill_value."""
+ fill_value = 23.5
+ tensor_properties = TensorProperties(
+ # an explicit dtype is given here, so it must win over the dtype inferred from fill_value
+ dtype=torch.double,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ )
+ tensor_init_params = TensorInitParams(
+ create_op=CreateOp.FULL,
+ fill_value=fill_value,
+ tensor_properties=tensor_properties, )
+ local_device = torch.device('cuda:0')
+ h, w = 5, 10
+ local_tensor = _create_tensor_from_params(
+ h, w, local_device=local_device, tensor_init_params=tensor_init_params)
+ # local_tensor.dtype is overridden.
+ self.assertEqual(torch.double, local_tensor.dtype)
+ expected_tensor = torch.full((h, w), fill_value=fill_value, device=local_device, dtype=torch.double)
self.assertEqual(expected_tensor, local_tensor)
class TestShardedTensorChunked(ShardedTensorTestBase, MultiProcessTestCase):
@with_comms
@skip_if_lt_x_gpu(4)
@requires_nccl()
+ def test_create_sharded_tensor_with_zeros(self):
+ """ Test _sharded_tensor.zeros(...) """
+
+ spec = ChunkShardingSpec(
+ dim=0,
+ placements=[
+ "rank:0/cuda:0",
+ "rank:1/cuda:1",
+ "rank:2/cuda:2",
+ "rank:3/cuda:3",
+ ],
+ )
+ h, w = 10, 20
+ sharded_tensor = _sharded_tensor.zeros(spec, h, w)
+
+ # Validate local shard is initialized with torch.zeros
+ local_shards = sharded_tensor.local_shards()
+ self.assertEqual(1, len(local_shards))
+ local_shard = local_shards[0].tensor
+ self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+ # h=10 rows are chunked along dim 0 over 4 ranks: ranks 0-2 are expected to
+ # hold ceil(10/4)=3 rows each and rank 3 the remaining 1 row.
+ expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
+ self.assertEqual((expected_h, w), local_shard.size())
+ # NOTE(review): the reference tensor is created without a device while the shard
+ # lives on CUDA; this presumably relies on assertEqual not requiring equal devices - confirm.
+ self.assertEqual(local_shard, torch.zeros(expected_h, w))
+
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
+ def test_create_sharded_tensor_with_rand(self):
+ """ Test _sharded_tensor.rand(...) """
+
+ spec = ChunkShardingSpec(
+ dim=0,
+ placements=[
+ "rank:0/cuda:0",
+ "rank:1/cuda:1",
+ "rank:2/cuda:2",
+ "rank:3/cuda:3",
+ ],
+ )
+ h, w = 8, 2
+ seed = 1234
+
+ # h=8 rows chunked along dim 0 over the 4 placements above: 2 rows per rank
+ expected_h = 2
+ expected_device = torch.device(f"cuda:{self.rank}")
+ dtype = torch.double
+ # draw the reference values first, then replay the same seed for the call under test
+ torch.manual_seed(seed)
+ expected = torch.rand(expected_h, w, device=expected_device, dtype=dtype)
+ # reset seed to ensure the same random numbers are generated
+ torch.manual_seed(seed)
+ sharded_tensor = _sharded_tensor.rand(spec, h, w, dtype=dtype)
+
+ # Validate local shard is initialized with torch.rand
+ local_shards = sharded_tensor.local_shards()
+ self.assertEqual(1, len(local_shards))
+ local_shard = local_shards[0].tensor
+ self.assertEqual(expected_device, local_shard.device)
+ self.assertEqual((expected_h, w), local_shard.size())
+ self.assertEqual(expected, local_shard)
+
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
+ def test_create_sharded_tensor_with_full(self):
+ """ Test _sharded_tensor.full(...) """
+
+ spec = ChunkShardingSpec(
+ dim=0,
+ placements=[
+ "rank:0/cuda:0",
+ "rank:1/cuda:1",
+ "rank:2/cuda:2",
+ "rank:3/cuda:3",
+ ],
+ )
+ h, w = 10, 20
+ fill_value = 1234
+ sharded_tensor = _sharded_tensor.full(spec, size=(h, w), fill_value=fill_value, dtype=torch.int32)
+
+ # Validate local shard is initialized with torch.full
+ local_shards = sharded_tensor.local_shards()
+ self.assertEqual(1, len(local_shards))
+ local_shard = local_shards[0].tensor
+ self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
+ # h=10 rows are chunked along dim 0 over 4 ranks: ranks 0-2 are expected to
+ # hold ceil(10/4)=3 rows each and rank 3 the remaining 1 row.
+ expected_h = 1 if self.rank == 3 else math.ceil(h / 4)
+ self.assertEqual((expected_h, w), local_shard.size())
+ # NOTE(review): the reference tensor is created without a device while the shard
+ # lives on CUDA; this presumably relies on assertEqual not requiring equal devices - confirm.
+ self.assertEqual(local_shard,
+ torch.full(size=(expected_h, w), fill_value=fill_value, dtype=torch.int32))
+
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
def test_partial_world_size(self):
spec = ChunkShardingSpec(
-from typing import List
+# coding=utf-8
-import torch
-from torch.distributed._sharding_spec import ShardingSpec
from .api import (
CreateOp,
Shard,
TensorProperties,
load_with_process_group,
)
+from torch.distributed._sharding_spec import ShardingSpec
+from typing import List
+import torch
def empty(sharding_spec: ShardingSpec,
process_group=None,
init_rrefs=False):
"""
- Creates an empty :class:`ShardedTensor`. Needs to be called on all ranks in an SPMD fashion.
+ Returns a :class:`ShardedTensor` filled with uninitialized data.
+ Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
process_group=None,
init_rrefs=False):
"""
- Creates a ones :class:`ShardedTensor`. Needs to be called on all ranks in an SPMD fashion.
+ Returns a :class:`ShardedTensor` with the scalar value 1.
+ Needs to be called on all ranks in an SPMD fashion.
Args:
sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
init_rrefs=init_rrefs,
)
-def init_from_local_shards(local_shards: List[Shard],
- sharded_tensor_metadata: ShardedTensorMetadata,
- process_group=None,
- init_rrefs=False):
+
+ def rand(sharding_spec: ShardingSpec,
+ *size,
+ dtype=None,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ process_group=None,
+ init_rrefs=False):
+ """
+ Returns a :class:`ShardedTensor` filled with random numbers from a uniform distribution on the
+ interval :math:`[0, 1)`. Needs to be called on all ranks in an SPMD fashion.
+
+ Args:
+ sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+ describing how to shard the Tensor.
+ size (int...): a sequence of integers defining the shape of the output
+ tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+ Keyword args:
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+ Default: ``torch.strided``.
+ requires_grad (bool, optional): If autograd should record operations on the
+ returned tensor. Default: ``False``.
+ pin_memory (bool, optional): If set, returned tensor would be allocated in
+ the pinned memory. Works only for CPU tensors. Default: ``False``.
+ memory_format (:class:`torch.memory_format`, optional): recorded in the
+ :class:`TensorProperties` passed to the :class:`ShardedTensor`.
+ Default: ``torch.contiguous_format``.
+ process_group (ProcessGroup, optional): The process group to work on. If None,
+ the default process group will be used.
+ init_rrefs (bool, optional): Whether or not to initialize
+ :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+ Need to initialize the RPC Framework if specified as ``True``.
+ Default: ``False``.
+
+ Returns:
+ A :class:`ShardedTensor` object on each rank
+ """
+ # bundle the per-tensor options so ShardedTensor can create each local shard with CreateOp.RAND
+ tensor_properties = TensorProperties(
+ dtype=dtype, layout=layout, requires_grad=requires_grad,
+ pin_memory=pin_memory, memory_format=memory_format
+ )
+ tensor_init_params = TensorInitParams(create_op=CreateOp.RAND, tensor_properties=tensor_properties, )
+ return ShardedTensor(
+ sharding_spec,
+ *size,
+ tensor_init_params=tensor_init_params,
+ process_group=process_group,
+ init_rrefs=init_rrefs,
+ )
+
+
+ def zeros(sharding_spec: ShardingSpec,
+ *size,
+ dtype=None,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ process_group=None,
+ init_rrefs=False):
+ """
+ Returns a :class:`ShardedTensor` filled with the scalar value 0.
+ Needs to be called on all ranks in an SPMD fashion.
+
+ Args:
+ sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+ describing how to shard the Tensor.
+ size (int...): a sequence of integers defining the shape of the output
+ tensor. Can be a variable number of arguments or a collection like a list or tuple.
+
+ Keyword args:
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ Default: if ``None``, uses a global default (see :func:`torch.set_default_tensor_type`).
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+ Default: ``torch.strided``.
+ requires_grad (bool, optional): If autograd should record operations on the
+ returned tensor. Default: ``False``.
+ pin_memory (bool, optional): If set, returned tensor would be allocated in
+ the pinned memory. Works only for CPU tensors. Default: ``False``.
+ memory_format (:class:`torch.memory_format`, optional): recorded in the
+ :class:`TensorProperties` passed to the :class:`ShardedTensor`.
+ Default: ``torch.contiguous_format``.
+ process_group (ProcessGroup, optional): The process group to work on. If None,
+ the default process group will be used.
+ init_rrefs (bool, optional): Whether or not to initialize
+ :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+ Need to initialize the RPC Framework if specified as ``True``.
+ Default: ``False``.
+
+ Returns:
+ A :class:`ShardedTensor` object on each rank
+ """
+ # bundle the per-tensor options so ShardedTensor can create each local shard with CreateOp.ZEROS
+ tensor_properties = TensorProperties(
+ dtype=dtype, layout=layout, requires_grad=requires_grad,
+ pin_memory=pin_memory, memory_format=memory_format,
+ )
+ tensor_init_params = TensorInitParams(create_op=CreateOp.ZEROS, tensor_properties=tensor_properties, )
+ return ShardedTensor(
+ sharding_spec,
+ *size,
+ tensor_init_params=tensor_init_params,
+ process_group=process_group,
+ init_rrefs=init_rrefs,
+ )
+
+
+ def full(sharding_spec: ShardingSpec,
+ size,
+ fill_value: torch.types.Number,
+ dtype=None,
+ layout=torch.strided,
+ requires_grad=False,
+ pin_memory=False,
+ memory_format=torch.contiguous_format,
+ process_group=None,
+ init_rrefs=False):
+ """
+ Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype
+ is inferred from fill_value. If dtype is specified, it will override the
+ inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion.
+
+ Args:
+ sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
+ describing how to shard the Tensor.
+ size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the
+ output tensor.
+ fill_value (Scalar): the value to fill the output tensor with. Required.
+
+ Keyword args:
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ Default: ``None``, in which case the dtype is inferred from ``fill_value``.
+ layout (:class:`torch.layout`, optional): the desired layout of returned Tensor.
+ Default: ``torch.strided``.
+ requires_grad (bool, optional): If autograd should record operations on the
+ returned tensor. Default: ``False``.
+ pin_memory (bool, optional): If set, returned tensor would be allocated in
+ the pinned memory. Works only for CPU tensors. Default: ``False``.
+ memory_format (:class:`torch.memory_format`, optional): recorded in the
+ :class:`TensorProperties` passed to the :class:`ShardedTensor`.
+ Default: ``torch.contiguous_format``.
+ process_group (ProcessGroup, optional): The process group to work on. If None,
+ the default process group will be used.
+ init_rrefs (bool, optional): Whether or not to initialize
+ :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
+ Need to initialize the RPC Framework if specified as ``True``.
+ Default: ``False``.
+
+ Returns:
+ A :class:`ShardedTensor` object on each rank
+ """
+ # fill_value is annotated, not defaulted: a default of torch.types.Number would hand
+ # the type alias itself to torch.full whenever the caller omitted the argument.
+ tensor_properties = TensorProperties(
+ dtype=dtype, layout=layout, requires_grad=requires_grad,
+ pin_memory=pin_memory, memory_format=memory_format,
+ )
+ tensor_init_params = TensorInitParams(
+ create_op=CreateOp.FULL, fill_value=fill_value, tensor_properties=tensor_properties)
+ # size arrives as one collection here (rand/zeros take *size), so unpack it for ShardedTensor
+ return ShardedTensor(
+ sharding_spec,
+ *size,
+ tensor_init_params=tensor_init_params,
+ process_group=process_group,
+ init_rrefs=init_rrefs,
+ )
+
+
+def init_from_local_shards(
+ local_shards: List[Shard],
+ sharded_tensor_metadata: ShardedTensorMetadata,
+ process_group=None,
+ init_rrefs=False):
"""
Creates an :class:`ShardedTensor` from local shards and the global metadata.
Needs to be called on all ranks in an SPMD fashion.
check_tensor,
validate_non_overlapping_shards_metadata
)
-
+from torch.types import Number
# Tracking for sharded tensor objects.
_sharded_tensor_lock = threading.Lock()
class CreateOp(Enum):
EMPTY = 0
+ # NOTE(review): inserting FULL at 1 renumbers ONES from 1 to 2; confirm that nothing
+ # stores or exchanges the raw integer values of this enum.
- ONES = 1
+ FULL = 1
+ ONES = 2
+ RAND = 3
+ ZEROS = 4
@dataclass
class TensorInitParams(object):
""" Container for list of common params to create new local tensor. """
- __slots__ = ['create_op', 'tensor_properties']
-
create_op: CreateOp
- tensor_properties: TensorProperties
+
+ # needed when create_op is FULL
+ # default set to False (not None) since None is incompatible with Number.
+ fill_value: Number = field(default=False)
+
+ # Use default_factory so each instance gets its own TensorProperties and
+ # torch.get_default_dtype() is read when the instance is created, not once at import.
+ # NOTE(review): if TensorProperties is a regular (unhashable) dataclass, a plain
+ # default instance is also rejected by dataclasses on Python 3.11+ - confirm.
+ tensor_properties: TensorProperties = field(
+ default_factory=lambda: TensorProperties(
+ dtype=torch.get_default_dtype(),
+ layout=torch.strided,
+ requires_grad=False,
+ memory_format=torch.contiguous_format,
+ pin_memory=False))
class ShardedTensor(object):
device=local_device, requires_grad=requires_grad,
# NB: memory_format param is not accepted by torch.ones
memory_format=memory_format, pin_memory=pin_memory,)
+ elif tensor_init_params.create_op == CreateOp.ZEROS:
+ return torch.zeros(*size,
+ dtype=dtype,
+ layout=layout,
+ device=local_device,
+ pin_memory=pin_memory,
+ requires_grad=requires_grad,)
+ elif tensor_init_params.create_op == CreateOp.RAND:
+ return torch.rand(*size,
+ dtype=dtype,
+ layout=layout,
+ device=local_device,
+ pin_memory=pin_memory,
+ requires_grad=requires_grad,)
+ elif tensor_init_params.create_op == CreateOp.FULL:
+ return torch.full(size=size,
+ fill_value=tensor_init_params.fill_value,
+ layout=layout,
+ dtype=dtype,
+ requires_grad=requires_grad,
+ device=local_device, )
else:
raise ValueError(f'Unsupported create_op: {tensor_init_params.create_op}')