To add Chained Scheduler to the list of PyTorch schedulers. (#63491)
authorIlqar Ramazanli <iramazanli@fb.com>
Thu, 26 Aug 2021 20:29:03 +0000 (13:29 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 26 Aug 2021 20:30:21 +0000 (13:30 -0700)
Summary:
In this PR we are introducing ChainedScheduler which initially proposed in the discussion https://github.com/pytorch/pytorch/pull/26423#discussion_r329976246 .

The idea is to provide a user friendly chaining method for schedulers, especially for the cases many of them are involved and we want to have a clean and easy to read interface for schedulers. This method will be even more crucial once CompositeSchedulers and Schedulers for different type of parameters are involved.

The immediate application of Chained Scheduler is expected to happen in TorchVision Library to combine WarmUpLR and  MultiStepLR https://github.com/pytorch/vision/blob/master/references/video_classification/scheduler.py#L5 . However, it can be expected that in many other use cases also this method could be applied.

### Example
The usage is as simple as below:

```python
sched=ChainedScheduler([ExponentialLR(self.opt, gamma=0.9),
                        WarmUpLR(self.opt, warmup_factor=0.2, warmup_iters=4, warmup_method="constant"),
                        StepLR(self.opt, gamma=0.1, step_size=3)])
```

Then calling
```python
sched.step()
```
would trigger step function for all three schedulers consecutively

Partially resolves https://github.com/pytorch/vision/issues/4281

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63491

Reviewed By: datumbox, mruberry

Differential Revision: D30576180

Pulled By: iramazanli

fbshipit-source-id: b43f0749f55faab25079641b7d91c21a891a87e4

test/test_optim.py
torch/optim/lr_scheduler.py

index 01ec43b..fe282ef 100644 (file)
@@ -13,7 +13,7 @@ from torch.autograd import Variable
 from torch import sparse
 from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \
     MultiStepLR, WarmUpLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
-    _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR
+    _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler
 from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
 from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
     skipIfRocm
@@ -1253,6 +1253,44 @@ class TestLRScheduler(TestCase):
                                       threshold=0.1, patience=5, cooldown=5)
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
 
+    def test_chained_lr1(self):
+        epochs = 10
+        schedulers = [None] * 1
+        targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3]
+        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
+        scheduler = ChainedScheduler(schedulers)
+        self._test([scheduler], targets, epochs)
+
+    def test_chained_lr2(self):
+        epochs = 10
+        schedulers = [None] * 1
+        targets = [[0.02, 0.03, 0.04] + [0.05] * 9]
+        schedulers[0] = WarmUpLR(self.opt, warmup_factor=0.4, warmup_iters=3, warmup_method="linear")
+        scheduler = ChainedScheduler(schedulers)
+        self._test([scheduler], targets, epochs)
+
+    def test_chained_lr3(self):
+        epochs = 10
+        schedulers = [None] * 2
+        targets = [[0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3]
+        schedulers[0] = WarmUpLR(self.opt, warmup_factor=0.4, warmup_iters=3, warmup_method="linear")
+        schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1)
+        scheduler = ChainedScheduler(schedulers)
+        self._test([scheduler], targets, epochs)
+
+    def test_chained_lr4(self):
+        epochs = 9
+        schedulers = [None] * 3
+        targets = [[0.05 * 0.2 * 0.9 ** x for x in range(3)]
+                   + [0.05 * 0.2 * 0.9 ** 3 * 0.1]
+                   + [0.05 * 0.9 ** x * 0.1 for x in range(4, 6)]
+                   + [0.05 * 0.9 ** x * 0.01 for x in range(6, 9)]]
+        schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
+        schedulers[1] = WarmUpLR(self.opt, warmup_factor=0.2, warmup_iters=4, warmup_method="constant")
+        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3)
+        scheduler = ChainedScheduler(schedulers)
+        self._test([scheduler], targets, epochs)
+
     def test_compound_step_and_multistep_lr(self):
         epochs = 10
         schedulers = [None] * 2
index 657a35a..761a404 100644 (file)
@@ -603,6 +603,44 @@ class CosineAnnealingLR(_LRScheduler):
                 for base_lr in self.base_lrs]
 
 
+class ChainedScheduler(_LRScheduler):
+    """Chains list of learning rate schedulers. It takes a list of chainable learning
+    rate schedulers and performs consecutive step() functions belong to them by just
+    one call.
+
+    Args:
+        schedulers (list): List of chained schedulers.
+
+    Example:
+        >>> # Assuming optimizer uses lr = 1. for all groups
+        >>> # lr = 0.09     if epoch == 0
+        >>> # lr = 0.081    if epoch == 1
+        >>> # lr = 0.729    if epoch == 2
+        >>> # lr = 0.6561   if epoch == 3
+        >>> # lr = 0.59049  if epoch >= 4
+        >>> scheduler1 = WarmUpLR(self.opt, warmup_factor=0.1, warmup_iters=2, warmup_method="constant")
+        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
+        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+    """
+
+    def __init__(self, schedulers):
+        for scheduler_idx in range(1, len(schedulers)):
+            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
+                raise ValueError(
+                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
+                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
+                )
+        self.schedulers = list(schedulers)
+
+    def step(self):
+        for scheduler in self.schedulers:
+            scheduler.step()
+
+
 class ReduceLROnPlateau(object):
     """Reduce learning rate when a metric has stopped improving.
     Models often benefit from reducing the learning rate by a factor