[PT/ShardedTensor]Allow zero size local shard (#65007)
authorXing Liu <xingl@fb.com>
Tue, 21 Sep 2021 16:38:04 +0000 (09:38 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 21 Sep 2021 16:58:54 +0000 (09:58 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65007

Relax shard size check in ShardMetadata to allow zero size local shard.

When sharding a tensor on N ranks, some ranks may have empty shard allocated. As we are assuming SPMD, the ranks w/ empty shard still need to participate in all collectives, and we need to allow this in ShardMetadata.

Test Plan: Unit tests and CLI

Reviewed By: jiaqizhai, wanchaol

Differential Revision: D30926566

fbshipit-source-id: afa562c94ffa8f8d91d65ddb4c348156d871dc36

test/distributed/_sharding_spec/test_sharding_spec.py
torch/distributed/_sharding_spec/_internals.py

index 409e7bd..4709ff3 100644 (file)
@@ -148,8 +148,8 @@ class TestShardingSpec(TestCase):
         with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
             ShardMetadata(shard_offsets=[-1, 0], shard_lengths=[1, 1], placement="cuda:0")
 
-        with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'):
-            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0")
+        with self.assertRaisesRegex(ValueError, 'shard_lengths should be >= 0'):
+            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[-1, 1], placement="cuda:0")
 
         with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
             EnumerableShardingSpec([])
index 568d11c..afeeaeb 100644 (file)
@@ -40,8 +40,8 @@ class ShardMetadata(object):
         for i in range(len(self.shard_offsets)):
             if self.shard_offsets[i] < 0:
                 raise ValueError('shard_offsets should be >=0')
-            if self.shard_lengths[i] <= 0:
-                raise ValueError('shard_lengths should be > 0')
+            if self.shard_lengths[i] < 0:
+                raise ValueError('shard_lengths should be >= 0')