Add driver function to run test_sharded_tensor.py and test_sharding_spec.py (#63189)
author    Bo Wang <bowangbj@fb.com>
Mon, 16 Aug 2021 22:18:01 +0000 (15:18 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 16 Aug 2021 22:25:32 +0000 (15:25 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63189

Add a __main__ -> run_tests() entry point to each test file; this is needed to actually launch the test cases in the OSS flow (without it, running the files directly does nothing).
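The pattern, condensed into a standalone sketch (the ExampleTest class below is a hypothetical placeholder; the real test classes live in the two files touched by this diff):

    from torch.testing._internal.common_utils import TestCase, run_tests

    class ExampleTest(TestCase):  # hypothetical placeholder for the real test classes
        def test_noop(self):
            self.assertTrue(True)

    if __name__ == '__main__':
        # Without this guard, running the file directly exits silently.
        # run_tests() runs every TestCase subclass defined in the module via
        # unittest, so `python <test_file>.py -v` now actually executes the tests.
        run_tests()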

Test Plan:
before:
$ python test/distributed/_sharding_spec/test_sharding_spec.py --v   ==> nothing happened
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py --v ==> nothing happened

after:

$ python test/distributed/_sharding_spec/test_sharding_spec.py --v   ==>

test_chunked_sharding_spec (__main__.TestShardingSpec) ... ok
test_device_placement (__main__.TestShardingSpec) ... ok
test_enumerable_sharding_spec (__main__.TestShardingSpec) ... ok

$ python test/distributed/_sharded_tensor/test_sharded_tensor.py --v ==>

test_complete_world_size (__main__.TestShardedTensorChunked) ... ok
test_insufficient_sharding_dims (__main__.TestShardedTensorChunked) ... ok
test_invalid_pg_rpc_ranks (__main__.TestShardedTensorChunked) ... [W tensorpipe_agent.cpp:699] RPC agent for worker2 encountered error when reading incoming request from worker0: eof (this error originated at tensorpipe/transport/shm/connection_impl.cc:259)
ok
test_invalid_sharding (__main__.TestShardedTensorChunked) ... ok
test_load_state_dict_errors (__main__.TestShardedTensorChunked) ... ok
test_multiple_local_shards (__main__.TestShardedTensorChunked) ... ok
test_new_group (__main__.TestShardedTensorChunked) ... ok
test_partial_world_size (__main__.TestShardedTensorChunked) ... ok
test_sharded_tensor_metadata (__main__.TestShardedTensorChunked) ... ok
test_sharded_tensor_sizes (__main__.TestShardedTensorChunked) ... ok
test_sharding_columns (__main__.TestShardedTensorChunked) ... ok
test_state_dict (__main__.TestShardedTensorChunked) ... ok
test_state_dict_new_group (__main__.TestShardedTensorChunked) ... ok
test_state_dict_no_sharded_tensors (__main__.TestShardedTensorChunked) ... ok
test_grid_sharding (__main__.TestShardedTensorEnumerable) ... ok
test_multiple_local_shards (__main__.TestShardedTensorEnumerable) ... ok
test_new_group (__main__.TestShardedTensorEnumerable) ... ok
test_partial_world_size (__main__.TestShardedTensorEnumerable) ... ok
test_sharded_tensor_metadata (__main__.TestShardedTensorEnumerable) ... ok
test_uneven_shards (__main__.TestShardedTensorEnumerable) ... ok
test_with_rpc_names (__main__.TestShardedTensorEnumerable) ... ok
test_init_from_local_shards (__main__.TestShardedTensorFromLocalShards) ... ok
test_init_from_local_shards_invalid_shards (__main__.TestShardedTensorFromLocalShards) ... ok
test_init_from_local_shards_invalid_shards_gaps (__main__.TestShardedTensorFromLocalShards) ...

Imported from OSS

Reviewed By: VitalyFedyunin

Differential Revision: D30294094

fbshipit-source-id: 08f0431a12ea854abe00dc920205b10ba43ae6b6

test/distributed/_sharded_tensor/test_sharded_tensor.py
test/distributed/_sharding_spec/test_sharding_spec.py

diff --git a/test/distributed/_sharded_tensor/test_sharded_tensor.py b/test/distributed/_sharded_tensor/test_sharded_tensor.py
index 593ceba..829855f 100644
--- a/test/distributed/_sharded_tensor/test_sharded_tensor.py
+++ b/test/distributed/_sharded_tensor/test_sharded_tensor.py
@@ -23,6 +23,7 @@ from torch.testing._internal.common_distributed import (
 )
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
+    run_tests,
 )
 
 if TEST_WITH_DEV_DBG_ASAN:
@@ -1440,3 +1441,6 @@ class TestShardedTensorFromLocalShards(ShardedTensorTestBase, MultiProcessTestCa
 
         with self.assertRaisesRegex(ValueError, "does not match tensor volume"):
             sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, sharded_tensor_metadata, init_rrefs=True)
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/test/distributed/_sharding_spec/test_sharding_spec.py b/test/distributed/_sharding_spec/test_sharding_spec.py
index 652fcb0..77d862b 100644
--- a/test/distributed/_sharding_spec/test_sharding_spec.py
+++ b/test/distributed/_sharding_spec/test_sharding_spec.py
@@ -8,8 +8,14 @@ from torch.distributed._sharding_spec import (
 )
 from torch.distributed._sharding_spec._internals import check_tensor
 
+from torch.testing._internal.common_utils import (
+    run_tests,
+    sandcastle_skip_if,
+)
+
 class TestShardingSpec(TestCase):
 
+    @sandcastle_skip_if(torch.cuda.device_count() < 2, '2 CUDA GPUs are needed')
     def test_device_placement(self):
         # valid devices
         DevicePlacementSpec("cuda:0")
@@ -29,6 +35,7 @@ class TestShardingSpec(TestCase):
         with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
             DevicePlacementSpec("rank:0/cpu2")
 
+    @sandcastle_skip_if(torch.cuda.device_count() < 2, '2 CUDA GPUs are needed')
     def test_chunked_sharding_spec(self):
         # Test valid specs.
         ChunkShardingSpec(0, [torch.device(0), torch.device(1)])
@@ -58,6 +65,7 @@ class TestShardingSpec(TestCase):
         with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
             ChunkShardingSpec(0, ["rank:0/cuda:foo", "cuda:1"])
 
+    @sandcastle_skip_if(torch.cuda.device_count() < 2, '2 CUDA GPUs are needed')
     def test_enumerable_sharding_spec(self):
         # test valid specs
 
@@ -217,3 +225,6 @@ class TestShardingSpec(TestCase):
 
         with self.assertRaisesRegex(ValueError, 'does not match tensor volume'):
             check_tensor(spec.shards, torch.rand(10, 10).size())
+
+if __name__ == '__main__':
+    run_tests()