else:
self._schedulers[idx].step()
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer. The wrapped schedulers are stored under the
+ ``_schedulers`` key as a list of their own state dicts.
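+
+ Example (a minimal persistence sketch; the ``scheduler`` variable
+ and the file path are illustrative, not part of this API):
+     >>> # Save the wrapper state together with the wrapped states.
+     >>> torch.save(scheduler.state_dict(), 'scheduler.pt')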
+ """
+ state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
+ state_dict['_schedulers'] = [None] * len(self._schedulers)
+
+ for idx, s in enumerate(self._schedulers):
+ state_dict['_schedulers'][idx] = s.state_dict()
+
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
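+
+ Example (illustrative; ``scheduler`` and the checkpoint path are
+ assumed from the :meth:`state_dict` sketch above):
+     >>> # Restore both the wrapper state and the wrapped states.
+     >>> scheduler.load_state_dict(torch.load('scheduler.pt'))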
+ """
+ _schedulers = state_dict.pop('_schedulers')
+ self.__dict__.update(state_dict)
+ # Restore state_dict keys in order to prevent side effects
+ # https://github.com/pytorch/pytorch/issues/32756
+ state_dict['_schedulers'] = _schedulers
+
+ for idx, s in enumerate(_schedulers):
+ self._schedulers[idx].load_state_dict(s)
+
class CosineAnnealingLR(_LRScheduler):
r"""Set the learning rate of each parameter group using a cosine annealing