To add state_dict and load_state_dict to SequentialLR (#65035)
authorIlqar Ramazanli <iramazanli@fb.com>
Wed, 15 Sep 2021 18:55:53 +0000 (11:55 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 15 Sep 2021 19:01:51 +0000 (12:01 -0700)
Summary:
To add state_dict() and load_state_dict() methods to SequentialLR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/65035

Reviewed By: prabhat00155, nateanl

Differential Revision: D30958204

Pulled By: datumbox

fbshipit-source-id: 65114e1b07146526ae2680233f5cd42b2534d67a

torch/optim/lr_scheduler.py

index 2043401..8bf880e 100644 (file)
@@ -632,6 +632,37 @@ class SequentialLR(_LRScheduler):
         else:
             self._schedulers[idx].step()
 
+    def state_dict(self):
+        """Returns the state of the scheduler as a :class:`dict`.
+
+        It contains an entry for every variable in self.__dict__ which
+        is not the optimizer.
+        The wrapped scheduler states will also be saved.
+        """
+        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
+        state_dict['_schedulers'] = [None] * len(self._schedulers)
+
+        for idx, s in enumerate(self._schedulers):
+            state_dict['_schedulers'][idx] = s.state_dict()
+
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        """Loads the schedulers state.
+
+        Args:
+            state_dict (dict): scheduler state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        _schedulers = state_dict.pop('_schedulers')
+        self.__dict__.update(state_dict)
+        # Restore state_dict keys in order to prevent side effects
+        # https://github.com/pytorch/pytorch/issues/32756
+        state_dict['_schedulers'] = _schedulers
+
+        for idx, s in enumerate(_schedulers):
+            self._schedulers[idx].load_state_dict(s)
+
 
 class CosineAnnealingLR(_LRScheduler):
     r"""Set the learning rate of each parameter group using a cosine annealing