To add SequentialLR to PyTorch Core Schedulers (#64037)
author Ilqar Ramazanli <iramazanli@fb.com>
Thu, 9 Sep 2021 16:32:36 +0000 (09:32 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 16:36:32 +0000 (09:36 -0700)
Summary:
Partially resolves https://github.com/pytorch/vision/issues/4281

In this PR we propose a new scheduler, `SequentialLR`, which lets a list of different schedulers be called during different periods of the training process.

The main motivation for this scheduler is the recently gained popularity of a warm-up phase at the beginning of training. It has been shown that using a small learning rate in the initial stages of training can speed up convergence.

With the help of `SequentialLR` we can start training with a small constant (or linearly increasing) learning rate and then switch to the actual target learning rate scheduler:

```python
scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)
scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[5])

for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
```

This code snippet calls `ConstantLR` for the first 5 epochs (the 0.1 factor is only applied for the first 2 of them, since `total_iters=2`) and switches to `ExponentialLR` for the remaining epochs.

This scheduler can be used to call any group of schedulers one after another. The main consideration is that every time we switch to a new scheduler, the new scheduler is assumed to start from the beginning, i.e. from its zeroth epoch.
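To illustrate this hand-off, here is a minimal, self-contained sketch (the toy parameter, the `warmup`/`decay` names, and the specific hyperparameters are illustrative choices, not part of this PR) that combines a linearly increasing warm-up with an exponential decay and prints the learning rate per epoch:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ExponentialLR, LinearLR, SequentialLR

# A single toy parameter so the optimizer has something to manage (illustrative only).
param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.05)

# Warm-up scheduler: LinearLR ramps the lr from 10% of the base lr up to the base lr
# over `total_iters` epochs.
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=3)
# Target scheduler: ExponentialLR multiplies the lr by gamma every epoch.
decay = ExponentialLR(optimizer, gamma=0.9)

# Use `warmup` until the milestone epoch, then hand off to `decay`,
# which starts from its own zeroth epoch.
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[3])

for epoch in range(8):
    optimizer.step()   # the real training step would go here
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])
```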

We also add `ChainedScheduler` to the `optim.rst` and `lr_scheduler.pyi` files here.
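
For contrast with the milestone-based switching above, `ChainedScheduler` steps every scheduler in its list on each call, so their effects compose at every epoch. A minimal sketch (the toy optimizer and the chosen hyperparameters are illustrative, not taken from this PR):

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ChainedScheduler, ConstantLR, ExponentialLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=0.05)

# Both schedulers are stepped every epoch and their factors compose,
# unlike SequentialLR, which activates only one scheduler at a time.
scheduler = ChainedScheduler([
    ConstantLR(optimizer, factor=0.1, total_iters=2),
    ExponentialLR(optimizer, gamma=0.9),
])

for epoch in range(5):
    optimizer.step()   # the real training step would go here
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])
```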

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64037

Reviewed By: albanD

Differential Revision: D30841099

Pulled By: iramazanli

fbshipit-source-id: 94f7d352066ee108eef8cda5f0dcb07f4d371751

docs/source/optim.rst
test/test_optim.py
torch/optim/lr_scheduler.py
torch/optim/lr_scheduler.pyi

docs/source/optim.rst
index 695f0a2..62a293d 100644
@@ -214,6 +214,8 @@ algorithms.
     lr_scheduler.LinearLR
     lr_scheduler.ExponentialLR
     lr_scheduler.CosineAnnealingLR
+    lr_scheduler.ChainedScheduler
+    lr_scheduler.SequentialLR
     lr_scheduler.ReduceLROnPlateau
     lr_scheduler.CyclicLR
     lr_scheduler.OneCycleLR
test/test_optim.py
index d69e935..2d88d6f 100644
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from torch.optim import SGD
 from torch.autograd import Variable
 from torch import sparse
-from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \
+from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, SequentialLR, StepLR, \
     MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
     _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler
 from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
@@ -1255,6 +1255,41 @@ class TestLRScheduler(TestCase):
                                       threshold=0.1, patience=5, cooldown=5)
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
 
+    def test_sequentiallr1(self):
+        epochs = 19
+        schedulers = [None] * 2
+        targets = [[0.05, 0.04, 0.032] + [0.05 for x in range(4)]
+                                       + [0.05 * 0.1 for x in range(4)]
+                                       + [0.05 * 0.01 for x in range(4)]
+                                       + [0.05 * 0.001 for x in range(4)]]
+        milestones = [3]
+        schedulers[0] = ExponentialLR(self.opt, gamma=0.8)
+        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4)
+        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
+        self._test(scheduler, targets, epochs)
+
+    def test_sequentiallr2(self):
+        epochs = 13
+        schedulers = [None] * 2
+        targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9 ** x for x in range(10)]]
+        milestones = [3]
+        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
+        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
+        self._test(scheduler, targets, epochs)
+
+    def test_sequentiallr3(self):
+        epochs = 12
+        schedulers = [None] * 3
+        targets = [[0.005, 0.005, 0.005] + [0.05, 0.04, 0.032]
+                                         + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]]
+        milestones = [3, 6]
+        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
+        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
+        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
+        self._test(scheduler, targets, epochs)
+
     def test_chained_lr1(self):
         epochs = 10
         schedulers = [None] * 1
torch/optim/lr_scheduler.py
index 42f7b51..2043401 100644
@@ -582,6 +582,57 @@ class ExponentialLR(_LRScheduler):
                 for base_lr in self.base_lrs]
 
 
+class SequentialLR(_LRScheduler):
+    """Receives the list of schedulers that is expected to be called sequentially during
+    optimization process and milestone points that provides exact intervals to reflect
+    which scheduler is supposed to be called at a given epoch.
+
+    Args:
+        schedulers (list): List of chained schedulers.
+        milestones (list): List of integers that reflects milestone points.
+
+    Example:
+        >>> # Assuming optimizer uses lr = 1. for all groups
+        >>> # lr = 0.1     if epoch == 0
+        >>> # lr = 0.1     if epoch == 1
+        >>> # lr = 0.9     if epoch == 2
+        >>> # lr = 0.81    if epoch == 3
+        >>> # lr = 0.729   if epoch == 4
+        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
+        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
+        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+    """
+
+    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
+        for scheduler_idx in range(1, len(schedulers)):
+            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
+                raise ValueError(
+                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
+                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
+                )
+        if (len(milestones) != len(schedulers) - 1):
+            raise ValueError(
+                "Sequential Schedulers expects number of schedulers provided to be one more "
+                "than the number of milestone points, but got number of schedulers {} and the "
+                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
+            )
+        self._schedulers = schedulers
+        self._milestones = milestones
+        self.last_epoch = last_epoch + 1
+
+    def step(self):
+        self.last_epoch += 1
+        idx = bisect_right(self._milestones, self.last_epoch)
+        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
+            self._schedulers[idx].step(0)
+        else:
+            self._schedulers[idx].step()
+
+
 class CosineAnnealingLR(_LRScheduler):
     r"""Set the learning rate of each parameter group using a cosine annealing
     schedule, where :math:`\eta_{max}` is set to the initial lr and
torch/optim/lr_scheduler.pyi
index 9b1b8ea..9552e8e 100644
@@ -27,6 +27,12 @@ class LinearLR(_LRScheduler):
 class ExponentialLR(_LRScheduler):
     def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ...
 
+class ChainedScheduler(_LRScheduler):
+    def __init__(self, schedulers: List[_LRScheduler]) -> None: ...
+
+class SequentialLR(_LRScheduler):
+    def __init__(self, schedulers: List[_LRScheduler], milestones: List[int], last_epoch: int=...) -> None: ...
+
 class CosineAnnealingLR(_LRScheduler):
     def __init__(self, optimizer: Optimizer, T_max: int, eta_min: float=..., last_epoch: int=...) -> None: ...