sandcastle_skip_if,
)
from torch.utils.checkpoint import checkpoint
+from torch.distributed.optim import functional_optim_map
if not IS_WINDOWS:
from torch.distributed.optim.functional_sgd import _FunctionalSGD
from torch.distributed.optim.functional_adam import _FunctionalAdam
from torch.distributed.optim.functional_adamw import _FunctionalAdamW
- _SUPPORTED_OPTIM_MAPPING = {
- _FunctionalSGD: torch.optim.SGD,
- _FunctionalAdam: torch.optim.Adam,
- _FunctionalAdamW: torch.optim.AdamW,
- }
if TEST_WITH_TSAN:
print(
gpu_model_allreduce = self._gpu_model_with_ddp_comm_hook(
process_group, default.allreduce_hook, gradient_as_bucket_view, hook_state
)
- sgd = _SUPPORTED_OPTIM_MAPPING.get(functional_optim_cls)(
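+ # functional_optim_map maps each torch.optim class to its functional
+ # counterpart; invert it to recover the standard optimizer class from
+ # the functional class used by the comm hook.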
+ mapping = {v: k for k, v in functional_optim_map.items()}
+ sgd = mapping.get(functional_optim_cls)(
gpu_model_allreduce.parameters(),
*functional_optim_args,
**functional_optim_kwargs,
import torch.nn.functional as F
from torch.optim import SGD, Adam, AdamW
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
-
-if not IS_WINDOWS:
- from torch.distributed.optim.functional_sgd import _FunctionalSGD
- from torch.distributed.optim.functional_adam import _FunctionalAdam
- from torch.distributed.optim.functional_adamw import _FunctionalAdamW
- _SUPPORTED_OPTIM_MAPPING = {
- SGD: _FunctionalSGD,
- Adam: _FunctionalAdam,
- AdamW: _FunctionalAdamW,
- }
-
+from torch.distributed.optim import functional_optim_map
+
class MyModule(torch.nn.Module):
def __init__(self):
optim_params = module_optim.parameters()
functional_params = module_functional.parameters()
optim = optim_cls(optim_params, *args, **kwargs)
- functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
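+ # functional_optim_map maps a torch.optim class (e.g. SGD) to its
+ # functional counterpart (e.g. _FunctionalSGD).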
+ functional_optim_cls = functional_optim_map.get(optim_cls, None)
if not functional_optim_cls:
raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
optim_functional = functional_optim_cls(
sandcastle_skip_if,
)
+from torch.distributed.optim import functional_optim_map
+
if not IS_WINDOWS:
import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer
from torch.distributed.optim.functional_sgd import _FunctionalSGD
from torch.distributed.optim.functional_adam import _FunctionalAdam
from torch.distributed.optim.functional_adamw import _FunctionalAdamW
- _SUPPORTED_OPTIM_MAPPING = {
- _FunctionalSGD: torch.optim.SGD,
- _FunctionalAdam: torch.optim.Adam,
- _FunctionalAdamW: torch.optim.AdamW,
- }
from torch.utils.data.distributed import DistributedSampler
if static_graph:
ddp_model_with_no_hook._set_static_graph()
- optimizer_no_hook = _SUPPORTED_OPTIM_MAPPING.get(functional_optim_cls)(
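+ # Invert functional_optim_map (torch.optim class -> functional class) so the
+ # no-hook baseline uses the standard optimizer matching functional_optim_cls.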
+ mapping = {v: k for k, v in functional_optim_map.items()}
+ optimizer_no_hook = mapping.get(functional_optim_cls)(
ddp_model_with_no_hook.parameters(),
*functional_optim_args,
**functional_optim_kwargs,