Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63189
Add a `__main__` entry point that calls `run_tests()` in each test file; this is needed to actually launch the test cases in the OSS flow.
Test Plan:
before:
$ python test/distributed/_sharding_spec/test_sharding_spec.py --v ==> nothing happened
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py --v ==> nothing happened
after:
$ python test/distributed/_sharding_spec/test_sharding_spec.py --v ==>
test_chunked_sharding_spec (__main__.TestShardingSpec) ... ok
test_device_placement (__main__.TestShardingSpec) ... ok
test_enumerable_sharding_spec (__main__.TestShardingSpec) ... ok
$ python test/distributed/_sharded_tensor/test_sharded_tensor.py --v
test_complete_world_size (__main__.TestShardedTensorChunked) ... ok
test_insufficient_sharding_dims (__main__.TestShardedTensorChunked) ... ok
test_invalid_pg_rpc_ranks (__main__.TestShardedTensorChunked) ... [W tensorpipe_agent.cpp:699] RPC agent for worker2 encountered error when reading incoming request from worker0: eof (this error originated at tensorpipe/transport/shm/connection_impl.cc:259)
ok
test_invalid_sharding (__main__.TestShardedTensorChunked) ... ok
test_load_state_dict_errors (__main__.TestShardedTensorChunked) ... ok
test_multiple_local_shards (__main__.TestShardedTensorChunked) ... ok
test_new_group (__main__.TestShardedTensorChunked) ... ok
test_partial_world_size (__main__.TestShardedTensorChunked) ... ok
test_sharded_tensor_metadata (__main__.TestShardedTensorChunked) ... ok
test_sharded_tensor_sizes (__main__.TestShardedTensorChunked) ... ok
test_sharding_columns (__main__.TestShardedTensorChunked) ... ok
test_state_dict (__main__.TestShardedTensorChunked) ... ok
test_state_dict_new_group (__main__.TestShardedTensorChunked) ... ok
test_state_dict_no_sharded_tensors (__main__.TestShardedTensorChunked) ... ok
test_grid_sharding (__main__.TestShardedTensorEnumerable) ... ok
test_multiple_local_shards (__main__.TestShardedTensorEnumerable) ... ok
test_new_group (__main__.TestShardedTensorEnumerable) ... ok
test_partial_world_size (__main__.TestShardedTensorEnumerable) ... ok
test_sharded_tensor_metadata (__main__.TestShardedTensorEnumerable) ... ok
test_uneven_shards (__main__.TestShardedTensorEnumerable) ... ok
test_with_rpc_names (__main__.TestShardedTensorEnumerable) ... ok
test_init_from_local_shards (__main__.TestShardedTensorFromLocalShards) ... ok
test_init_from_local_shards_invalid_shards (__main__.TestShardedTensorFromLocalShards) ... ok
test_init_from_local_shards_invalid_shards_gaps (__main__.TestShardedTensorFromLocalShards) ...
Imported from OSS
Reviewed By: VitalyFedyunin
Differential Revision: D30294094
fbshipit-source-id: 08f0431a12ea854abe00dc920205b10ba43ae6b6
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
+ run_tests,
)
if TEST_WITH_DEV_DBG_ASAN:
with self.assertRaisesRegex(ValueError, "does not match tensor volume"):
sharded_tensor = _sharded_tensor.init_from_local_shards(local_shards, sharded_tensor_metadata, init_rrefs=True)
+
+if __name__ == '__main__':
+ run_tests()
)
from torch.distributed._sharding_spec._internals import check_tensor
+from torch.testing._internal.common_utils import (
+ run_tests,
+ sandcastle_skip_if,
+)
+
class TestShardingSpec(TestCase):
+ @sandcastle_skip_if(torch.cuda.device_count() < 2, '2 CUDA GPUs are needed')
def test_device_placement(self):
# valid devices
DevicePlacementSpec("cuda:0")
with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
DevicePlacementSpec("rank:0/cpu2")
+ @sandcastle_skip_if(torch.cuda.device_count() < 2, '2 CUDA GPUs are needed')
def test_chunked_sharding_spec(self):
# Test valid specs.
ChunkShardingSpec(0, [torch.device(0), torch.device(1)])
with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
ChunkShardingSpec(0, ["rank:0/cuda:foo", "cuda:1"])
+ @sandcastle_skip_if(torch.cuda.device_count() < 2, '2 CUDA GPUs are needed')
def test_enumerable_sharding_spec(self):
# test valid specs
with self.assertRaisesRegex(ValueError, 'does not match tensor volume'):
check_tensor(spec.shards, torch.rand(10, 10).size())
+
+if __name__ == '__main__':
+ run_tests()