From 28f9e108b10cde1979b29547ccd87fd11a411bce Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Fri, 13 Aug 2021 08:19:23 -0700
Subject: [PATCH] Pass `_allow_empty_param_list` into func opt ctor (#63163)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63163

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30284615

Pulled By: andwgu

fbshipit-source-id: 4857f5b618ec5b007648737ab532ce605e5d70dc
---
 test/distributed/optim/test_zero_redundancy_optimizer.py |  1 -
 torch/distributed/optim/zero_redundancy_optimizer.py     | 12 +++++++++++-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index c60c7de..e6259ad 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -967,7 +967,6 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer):
             lr=SGD_LR,
             momentum=SGD_MOMENTUM,
             weight_decay=SGD_WEIGHT_DECAY,
-            _allow_empty_param_list=True
         )
         ddp_model_overlap.register_comm_hook(
             None,
diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py
index bba71e4..2454e75 100644
--- a/torch/distributed/optim/zero_redundancy_optimizer.py
+++ b/torch/distributed/optim/zero_redundancy_optimizer.py
@@ -6,6 +6,7 @@
 import collections
 import copy
 import enum
+import inspect
 import io
 import logging
 from itertools import chain
@@ -1375,7 +1376,16 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
         assert len(param_groups) == 1, "Initializing the local " \
             "functional optimizer with more than one parameter group"
         params = param_groups[0]["params"]
-        self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
+        # Try to pass `_allow_empty_param_list=True` to avoid erroring
+        if "_allow_empty_param_list" in inspect.signature(self._optim_constructor).parameters:
+            self.optim: Any = self._optim_constructor(params, **self._optim_defaults, _allow_empty_param_list=True)
+        else:
+            logging.warning(
+                f"{self._optim_constructor} does not support the argument "
+                "`_allow_empty_param_list`; ZeroRedundancyOptimizer may "
+                "error due to an empty parameter list"
+            )
+            self.optim: Any = self._optim_constructor(params, **self._optim_defaults)
 
         # Log information about the DDP and ZeRO bucketing
         if dist._get_debug_mode() != dist._DistributedDebugLevel.OFF:
-- 
2.7.4