[BE] remove _SUPPORTED_OPTIM_MAP from tests (#63383)
author Rohan Varma <rvarm1@fb.com>
Wed, 18 Aug 2021 00:12:32 +0000 (17:12 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 18 Aug 2021 00:17:25 +0000 (17:17 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63383

Per title: remove the duplicated _SUPPORTED_OPTIM_MAPPING dicts from the tests and use the shared functional_optim_map from torch.distributed.optim instead, inverting it where a test needs the functional-to-standard direction.
ghstack-source-id: 135966157
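
For context (not part of the change), a minimal sketch of how the tests consume the shared map after this diff, assuming functional_optim_map maps standard torch.optim classes to their _Functional* counterparts as the hunks below show:

    import torch
    from torch.distributed.optim import functional_optim_map

    # Forward direction (test_functional_optim.py): standard optimizer
    # class -> functional optimizer class.
    functional_sgd_cls = functional_optim_map.get(torch.optim.SGD, None)

    # Reverse direction (DDP comm-hook tests): invert the shared map
    # instead of keeping a second hand-maintained dict.
    reverse_map = {v: k for k, v in functional_optim_map.items()}
    optim_cls = reverse_map.get(functional_sgd_cls)  # torch.optim.SGD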

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D30358921

fbshipit-source-id: 965e054e525194b1ee55980340df275bab355c9b

test/distributed/test_c10d_nccl.py
test/test_functional_optim.py
torch/testing/_internal/distributed/distributed_test.py

diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py
index 285053d..f7f6681 100644
--- a/test/distributed/test_c10d_nccl.py
+++ b/test/distributed/test_c10d_nccl.py
@@ -49,16 +49,12 @@ from torch.testing._internal.common_utils import (
     sandcastle_skip_if,
 )
 from torch.utils.checkpoint import checkpoint
+from torch.distributed.optim import functional_optim_map
 
 if not IS_WINDOWS:
     from torch.distributed.optim.functional_sgd import _FunctionalSGD
     from torch.distributed.optim.functional_adam import _FunctionalAdam
     from torch.distributed.optim.functional_adamw import _FunctionalAdamW
-    _SUPPORTED_OPTIM_MAPPING = {
-        _FunctionalSGD: torch.optim.SGD,
-        _FunctionalAdam: torch.optim.Adam,
-        _FunctionalAdamW: torch.optim.AdamW,
-    }
 
 if TEST_WITH_TSAN:
     print(
@@ -1639,7 +1635,8 @@ class DistributedDataParallelTest(
         gpu_model_allreduce = self._gpu_model_with_ddp_comm_hook(
             process_group, default.allreduce_hook, gradient_as_bucket_view, hook_state
         )
-        sgd = _SUPPORTED_OPTIM_MAPPING.get(functional_optim_cls)(
+        mapping = {v: k for k, v in functional_optim_map.items()}
+        sgd = mapping.get(functional_optim_cls)(
             gpu_model_allreduce.parameters(),
             *functional_optim_args,
             **functional_optim_kwargs,
diff --git a/test/test_functional_optim.py b/test/test_functional_optim.py
index 59af691..98a3f06 100644
--- a/test/test_functional_optim.py
+++ b/test/test_functional_optim.py
@@ -5,17 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import SGD, Adam, AdamW
 from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
-
-if not IS_WINDOWS:
-    from torch.distributed.optim.functional_sgd import _FunctionalSGD
-    from torch.distributed.optim.functional_adam import _FunctionalAdam
-    from torch.distributed.optim.functional_adamw import _FunctionalAdamW
-    _SUPPORTED_OPTIM_MAPPING = {
-        SGD: _FunctionalSGD,
-        Adam: _FunctionalAdam,
-        AdamW: _FunctionalAdamW,
-    }
-
+from torch.distributed.optim import functional_optim_map
 
 class MyModule(torch.nn.Module):
     def __init__(self):
@@ -39,7 +29,7 @@ class TestFunctionalOptimParity(TestCase):
         optim_params = module_optim.parameters()
         functional_params = module_functional.parameters()
         optim = optim_cls(optim_params, *args, **kwargs)
-        functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
+        functional_optim_cls = functional_optim_map.get(optim_cls, None)
         if not functional_optim_cls:
             raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
         optim_functional = functional_optim_cls(
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 6ef94c9..2a126ab 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -66,16 +66,13 @@ from torch.testing._internal.common_utils import (
     sandcastle_skip_if,
 )
 
+from torch.distributed.optim import functional_optim_map
+
 if not IS_WINDOWS:
     import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
     from torch.distributed.optim.functional_sgd import _FunctionalSGD
     from torch.distributed.optim.functional_adam import _FunctionalAdam
     from torch.distributed.optim.functional_adamw import _FunctionalAdamW
-    _SUPPORTED_OPTIM_MAPPING = {
-        _FunctionalSGD: torch.optim.SGD,
-        _FunctionalAdam: torch.optim.Adam,
-        _FunctionalAdamW: torch.optim.AdamW,
-    }
 
 from torch.utils.data.distributed import DistributedSampler
 
@@ -3949,7 +3946,8 @@ class DistributedTest:
                     if static_graph:
                         ddp_model_with_no_hook._set_static_graph()
 
-                    optimizer_no_hook = _SUPPORTED_OPTIM_MAPPING.get(functional_optim_cls)(
+                    mapping = {v: k for k, v in functional_optim_map.items()}
+                    optimizer_no_hook = mapping.get(functional_optim_cls)(
                         ddp_model_with_no_hook.parameters(),
                         *functional_optim_args,
                         **functional_optim_kwargs,