Discover new tests in run_tests.py (#64246)
authorNikita Shulga <nshulga@fb.com>
Wed, 1 Sep 2021 00:19:11 +0000 (17:19 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 00:32:55 +0000 (17:32 -0700)
Summary:
Introduce `discover_tests` function that globs for all Python files
starting with `test_` in test folder excluding subfolders which are
executed differently

Fixes https://github.com/pytorch/pytorch/issues/64178

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64246

Reviewed By: walterddr, seemethere

Differential Revision: D30661652

Pulled By: malfet

fbshipit-source-id: a52e78ec717b6846add267579dd8d9ae75326bf9

test/run_test.py [changed mode: 0755->0644]

old mode 100755 (executable)
new mode 100644 (file)
index 55b2f38..5953919
@@ -50,145 +50,87 @@ except ImportError:
     )
 
 
-TESTS = [
-    "test_import_time",
-    "test_public_bindings",
-    "test_type_hints",
-    "test_ao_sparsity",
-    "test_autograd",
-    "benchmark_utils/test_benchmark_utils",
-    "test_binary_ufuncs",
-    "test_buffer_protocol",
-    "test_bundled_inputs",
-    "test_complex",
-    "test_cpp_api_parity",
-    "test_cpp_extensions_aot_no_ninja",
-    "test_cpp_extensions_aot_ninja",
-    "test_cpp_extensions_jit",
-    "distributed/test_c10d_common",
-    "distributed/test_c10d_gloo",
-    "distributed/test_c10d_nccl",
-    "distributed/test_jit_c10d",
-    "distributed/test_c10d_spawn_gloo",
-    "distributed/test_c10d_spawn_nccl",
-    "distributed/test_store",
-    "distributed/test_pg_wrapper",
-    "distributed/algorithms/test_join",
-    "test_cuda",
-    "test_autocast",
-    "test_jit_cuda_fuser",
-    "test_cuda_primary_ctx",
-    "test_dataloader",
-    "test_datapipe",
-    "distributed/test_data_parallel",
-    "distributed/test_distributed_spawn",
-    "distributions/test_constraints",
-    "distributions/test_distributions",
-    "test_dispatch",
-    "test_foreach",
-    "test_indexing",
-    "test_jit",
-    "test_linalg",
-    "test_logging",
-    "test_mkldnn",
-    "test_model_dump",
-    "test_module_init",
-    "test_modules",
-    "test_multiprocessing",
-    "test_multiprocessing_spawn",
-    "distributed/test_nccl",
-    "test_native_functions",
-    "test_numba_integration",
-    "test_nn",
-    "test_ops",
-    "test_optim",
-    "test_functional_optim",
-    "test_pytree",
-    "test_mobile_optimizer",
-    "test_set_default_mobile_cpu_allocator",
-    "test_xnnpack_integration",
-    "test_vulkan",
-    "test_sparse",
-    "test_sparse_csr",
-    "test_quantization",
-    "test_pruning_op",
-    "test_spectral_ops",
-    "test_serialization",
-    "test_shape_ops",
-    "test_show_pickle",
-    "test_sort_and_select",
-    "test_tensor_creation_ops",
-    "test_testing",
-    "test_torch",
-    "test_type_info",
-    "test_unary_ufuncs",
-    "test_utils",
-    "test_view_ops",
-    "test_vmap",
-    "test_namedtuple_return_api",
-    "test_numpy_interop",
-    "test_jit_profiling",
-    "test_jit_legacy",
-    "test_jit_fuser_legacy",
-    "test_tensorboard",
-    "test_namedtensor",
-    "test_reductions",
-    "test_type_promotion",
-    "test_jit_disabled",
-    "test_function_schema",
-    "test_overrides",
-    "test_jit_fuser_te",
-    "test_tensorexpr",
-    "test_tensorexpr_pybind",
-    "test_openmp",
-    "test_profiler",
-    "distributed/test_launcher",
-    "distributed/nn/jit/test_instantiator",
-    "distributed/rpc/test_faulty_agent",
-    "distributed/rpc/test_tensorpipe_agent",
-    "distributed/rpc/cuda/test_tensorpipe_agent",
-    "test_determination",
-    "test_futures",
-    "test_fx",
-    "test_fx_experimental",
-    "test_functional_autograd_benchmark",
-    "test_package",
-    "test_license",
-    "distributed/pipeline/sync/skip/test_api",
-    "distributed/pipeline/sync/skip/test_gpipe",
-    "distributed/pipeline/sync/skip/test_inspect_skip_layout",
-    "distributed/pipeline/sync/skip/test_leak",
-    "distributed/pipeline/sync/skip/test_portal",
-    "distributed/pipeline/sync/skip/test_stash_pop",
-    "distributed/pipeline/sync/skip/test_tracker",
-    "distributed/pipeline/sync/skip/test_verify_skippables",
-    "distributed/pipeline/sync/test_balance",
-    "distributed/pipeline/sync/test_bugs",
-    "distributed/pipeline/sync/test_checkpoint",
-    "distributed/pipeline/sync/test_copy",
-    "distributed/pipeline/sync/test_deferred_batch_norm",
-    "distributed/pipeline/sync/test_dependency",
-    "distributed/pipeline/sync/test_inplace",
-    "distributed/pipeline/sync/test_microbatch",
-    "distributed/pipeline/sync/test_phony",
-    "distributed/pipeline/sync/test_pipe",
-    "distributed/pipeline/sync/test_pipeline",
-    "distributed/pipeline/sync/test_stream",
-    "distributed/pipeline/sync/test_transparency",
-    "distributed/pipeline/sync/test_worker",
-    "distributed/optim/test_zero_redundancy_optimizer",
-    "distributed/elastic/timer/api_test",
-    "distributed/elastic/timer/local_timer_example",
-    "distributed/elastic/timer/local_timer_test",
-    "distributed/elastic/events/lib_test",
-    "distributed/elastic/metrics/api_test",
-    "distributed/elastic/utils/logging_test",
-    "distributed/elastic/utils/util_test",
-    "distributed/elastic/utils/distributed_test",
-    "distributed/elastic/multiprocessing/api_test",
-    "distributed/_sharding_spec/test_sharding_spec",
-    "distributed/_sharded_tensor/test_sharded_tensor",
-]
+def discover_tests(
+        base_dir: Optional[pathlib.Path] = None,
+        blocklisted_patterns: Optional[List[str]] = None,
+        blocklisted_tests: Optional[List[str]] = None,
+        extra_tests: Optional[List[str]] = None) -> List[str]:
+    """
+    Searches for all python files starting with test_ excluding one specified by patterns
+    """
+    def skip_test_p(name: str) -> bool:
+        rc = False
+        if blocklisted_patterns is not None:
+            rc |= any(name.startswith(pattern) for pattern in blocklisted_patterns)
+        if blocklisted_tests is not None:
+            rc |= name in blocklisted_tests
+        return rc
+    cwd = pathlib.Path(__file__).resolve().parent if base_dir is None else base_dir
+    all_py_files = list(cwd.glob('**/test_*.py'))
+    rc = [str(fname.relative_to(cwd))[:-3] for fname in all_py_files]
+    # Invert slashes on Windows
+    if sys.platform == "win32":
+        rc = [name.replace('\\', '/') for name in rc]
+    rc = [test for test in rc if not skip_test_p(test)]
+    if extra_tests is not None:
+        rc += extra_tests
+    return sorted(rc)
+
+
+TESTS = discover_tests(
+    blocklisted_patterns=[
+        'ao',
+        'bottleneck_test',
+        'custom_backend',
+        'custom_operator',
+        'fx',        # executed by test_fx.py
+        'jit',      # executed by test_jit.py
+        'mobile',
+        'onnx',
+        'package',  # executed by test_package.py
+        'quantization',  # executed by test_quantization.py
+    ],
+    blocklisted_tests=[
+        'test_bundled_images',
+        'test_cpp_extensions_aot',
+        'test_gen_backend_stubs',
+        'test_jit_fuser',
+        'test_jit_simple',
+        'test_jit_string',
+        'test_kernel_launch_checks',
+        'test_metal',
+        'test_nnapi',
+        'test_python_dispatch',
+        'test_segment_reductions',
+        'test_static_runtime',
+        'test_throughput_benchmark',
+        'test_typing',
+        "distributed/algorithms/ddp_comm_hooks/test_ddp_hooks",
+        "distributed/algorithms/quantization/test_quantization",
+        "distributed/bin/test_script",
+        "distributed/elastic/multiprocessing/bin/test_script",
+        "distributed/launcher/bin/test_script",
+        "distributed/launcher/bin/test_script_init_method",
+        "distributed/launcher/bin/test_script_is_torchelastic_launched",
+        "distributed/launcher/bin/test_script_local_rank",
+        "distributed/test_c10d_spawn",
+        'distributions/test_transforms',
+        'distributions/test_utils',
+    ],
+    extra_tests=[
+        "test_cpp_extensions_aot_ninja",
+        "test_cpp_extensions_aot_no_ninja",
+        "distributed/elastic/timer/api_test",
+        "distributed/elastic/timer/local_timer_example",
+        "distributed/elastic/timer/local_timer_test",
+        "distributed/elastic/events/lib_test",
+        "distributed/elastic/metrics/api_test",
+        "distributed/elastic/utils/logging_test",
+        "distributed/elastic/utils/util_test",
+        "distributed/elastic/utils/distributed_test",
+        "distributed/elastic/multiprocessing/api_test",
+    ]
+)
 
 # Tests need to be run with pytest.
 USE_PYTEST_LIST = [