Replace the WarmUpLR scheduler with ConstantLR and LinearLR (#64395)
author     Ilqar Ramazanli <iramazanli@fb.com>
Tue, 7 Sep 2021 15:41:09 +0000 (08:41 -0700)
committer  Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 7 Sep 2021 15:42:31 +0000 (08:42 -0700)
Summary:
Partially unblocks https://github.com/pytorch/vision/issues/4281

Previously, in https://github.com/pytorch/pytorch/pull/60836, we added the WarmUpLR scheduler to PyTorch core. It had two modes of execution, linear and constant, depending on the warm-up method.

In this PR we make the interface more direct by splitting the linear and constant modes into separate schedulers. In particular,

```Python
scheduler1 = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")
scheduler2 = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="linear")
```

will now be written as

```Python
scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=5)
scheduler2 = LinearLR(optimizer, start_factor=0.1, total_iters=5)
```

respectively. Note that the keyword arguments change as well: `ConstantLR` takes `factor` and `total_iters`, while `LinearLR` takes `start_factor`, an optional `end_factor`, and `total_iters`.
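For reference, a minimal, self-contained sketch of the new usage (it assumes a PyTorch build that already contains this PR, where both schedulers live in `torch.optim.lr_scheduler`; the toy `nn.Linear` model and the bare loop are illustrative stand-ins only):

```Python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ConstantLR, LinearLR

# Toy model/optimizer; base learning rate 0.05, matching the docstring examples below.
model = torch.nn.Linear(2, 1)
optimizer = SGD(model.parameters(), lr=0.05)

# LinearLR ramps the lr from start_factor * base_lr to end_factor * base_lr
# over total_iters epochs (end_factor defaults to 1.0).
scheduler = LinearLR(optimizer, start_factor=0.5, total_iters=4)

for epoch in range(6):
    print(epoch, optimizer.param_groups[0]['lr'])
    optimizer.step()      # a real training step would go here
    scheduler.step()
# Prints (up to float rounding) 0.025, 0.03125, 0.0375, 0.04375, 0.05, 0.05.

# ConstantLR instead holds the lr at factor * base_lr for total_iters epochs,
# then restores the base lr:
#   ConstantLR(optimizer, factor=0.5, total_iters=4)  ->  0.025, 0.025, 0.025, 0.025, 0.05, ...
```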

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64395

Reviewed By: datumbox

Differential Revision: D30753688

Pulled By: iramazanli

fbshipit-source-id: e47f86d12033f80982ddf1faf5b46873adb4f324

docs/source/optim.rst
test/test_optim.py
torch/optim/lr_scheduler.py
torch/optim/lr_scheduler.pyi

diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index 2ded57f..695f0a2 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -210,7 +210,8 @@ algorithms.
     lr_scheduler.MultiplicativeLR
     lr_scheduler.StepLR
     lr_scheduler.MultiStepLR
-    lr_scheduler.WarmUpLR
+    lr_scheduler.ConstantLR
+    lr_scheduler.LinearLR
     lr_scheduler.ExponentialLR
     lr_scheduler.CosineAnnealingLR
     lr_scheduler.ReduceLROnPlateau
diff --git a/test/test_optim.py b/test/test_optim.py
index fe282ef..d69e935 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -12,7 +12,7 @@ from torch.optim import SGD
 from torch.autograd import Variable
 from torch import sparse
 from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \
-    MultiStepLR, WarmUpLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
+    MultiStepLR, ConstantLR, LinearLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \
     _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler
 from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
 from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \
@@ -274,16 +274,16 @@ class TestOptim(TestCase):
             )
             self._test_basic_cases(
                 lambda weight, bias: optimizer([weight, bias], lr=1e-3),
-                [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")]
+                [lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.8, total_iters=4)]
             )
             self._test_basic_cases(
                 lambda weight, bias: optimizer([weight, bias], lr=1e-3),
-                [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
+                [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)]
             )
             self._test_basic_cases(
                 lambda weight, bias: optimizer([weight, bias], lr=1e-3),
                 [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
-                 lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)]
+                 lambda opt: LinearLR(opt, start_factor=0.4, end_factor=0.6, total_iters=4)]
             )
             self._test_basic_cases(
                 lambda weight, bias: optimizer([weight, bias], lr=1e-3),
@@ -430,17 +430,17 @@ class TestOptim(TestCase):
                 lambda weight, bias: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
                     lr=1e-3),
-                [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")]
+                [lambda opt: LinearLR(opt, start_factor=0.4, total_iters=4)]
             )
             self._test_basic_cases(
                 lambda weight, bias: optimizer(
                     self._build_params_dict(weight, bias, lr=1e-2),
                     lr=1e-3),
-                [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")]
+                [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4)]
             )
             self._test_basic_cases(
                 lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True),
-                [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant"),
+                [lambda opt: ConstantLR(opt, factor=0.4, total_iters=4),
                  lambda opt: ExponentialLR(opt, gamma=0.9)]
             )
             self._test_basic_cases(
@@ -992,12 +992,12 @@ class TestLRScheduler(TestCase):
         scheduler = ExponentialLR(self.opt, gamma=0.9)
         self._test_lr_is_constant_for_constant_epoch(scheduler)
 
-    def test_constant_warmup_lr_is_constant_for_constant_epoch(self):
-        scheduler = WarmUpLR(self.opt, warmup_method="constant")
+    def test_constantlr_is_constant_for_constant_epoch(self):
+        scheduler = ConstantLR(self.opt)
         self._test_lr_is_constant_for_constant_epoch(scheduler)
 
-    def test_linear_warmup_lr_is_constant_for_constant_epoch(self):
-        scheduler = WarmUpLR(self.opt, warmup_method="linear")
+    def test_linear_linearlr_is_constant_for_constant_epoch(self):
+        scheduler = LinearLR(self.opt)
         self._test_lr_is_constant_for_constant_epoch(scheduler)
 
     def test_step_lr(self):
@@ -1051,76 +1051,78 @@ class TestLRScheduler(TestCase):
         scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
         self._test_with_epoch(scheduler, targets, epochs)
 
-    def test__get_last_lr_constant_warmup_lr(self):
+    def test_get_last_lr_constantlr(self):
         # lr = 0.025     if epoch < 5
         # lr = 0.005    if 5 <= epoch
         epochs = 10
         single_targets = [0.025] * 5 + [0.05] * 5
         targets = [single_targets, [x * epochs for x in single_targets]]
-        scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant")
+        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
         self._test_get_last_lr(scheduler, targets, epochs)
 
-    def test__get_last_lr_linear_warmup_lr(self):
+    def test_get_last_lr_linearlr(self):
         # lr = 0.025     if epoch == 0
         # lr = 0.03125   if epoch == 1
         # lr = 0.0375    if epoch == 2
         # lr = 0.04375   if epoch == 3
         # lr = 0.005     if 4 <= epoch
         epochs = 10
-        factor = 1.0 / 2
+        start_factor = 1.0 / 4
+        end_factor = 3. / 5
         iters = 4
-        interpolation = [factor + i * (1 - factor) / iters for i in range(iters)]
-        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
+        interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)]
+        single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (epochs - iters)
         targets = [single_targets, [x * epochs for x in single_targets]]
-        scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        scheduler = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters)
         self._test_get_last_lr(scheduler, targets, epochs)
 
-    def test__constant_warmup_lr(self):
+    def test_constantlr(self):
         # lr = 0.025     if epoch < 5
         # lr = 0.005    if 5 <= epoch
         epochs = 10
         single_targets = [0.025] * 5 + [0.05] * 5
         targets = [single_targets, [x * epochs for x in single_targets]]
-        scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant")
+        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
         self._test(scheduler, targets, epochs)
 
-    def test__linear_warmup_lr(self):
+    def test_linearlr(self):
         # lr = 0.025     if epoch == 0
         # lr = 0.03125   if epoch == 1
         # lr = 0.0375    if epoch == 2
         # lr = 0.04375   if epoch == 3
         # lr = 0.005     if 4 <= epoch
         epochs = 10
-        factor = 1.0 / 2
+        start_factor = 1.0 / 2
         iters = 4
-        interpolation = [factor + i * (1 - factor) / iters for i in range(iters)]
+        interpolation = [start_factor + i * (1 - start_factor) / iters for i in range(iters)]
         single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
         targets = [single_targets, [x * epochs for x in single_targets]]
-        scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
         self._test(scheduler, targets, epochs)
 
-    def test_constant_warmup_with_epoch(self):
+    def test_constantlr_with_epoch(self):
         # lr = 0.025     if epoch < 5
         # lr = 0.005    if 5 <= epoch
         epochs = 10
         single_targets = [0.025] * 5 + [0.05] * 5
         targets = [single_targets, [x * epochs for x in single_targets]]
-        scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant")
+        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
         self._test_with_epoch(scheduler, targets, epochs)
 
-    def test_linear_warmup_with_epoch(self):
+    def test_linearlr_with_epoch(self):
         # lr = 0.025     if epoch == 0
         # lr = 0.03125   if epoch == 1
         # lr = 0.0375    if epoch == 2
         # lr = 0.04375   if epoch == 3
         # lr = 0.005     if 4 <= epoch
         epochs = 10
-        factor = 1.0 / 2
+        start_factor = 1.0 / 2
+        end_factor = 1.
         iters = 4
-        interpolation = [factor + i * (1 - factor) / iters for i in range(iters)]
+        interpolation = [start_factor + i * (end_factor - start_factor) / iters for i in range(iters)]
         single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
         targets = [single_targets, [x * epochs for x in single_targets]]
-        scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
         self._test_with_epoch(scheduler, targets, epochs)
 
     def test_exp_lr(self):
@@ -1145,14 +1147,14 @@ class TestLRScheduler(TestCase):
         closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
         self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
 
-    def test_closed_form_linear_warmup_lr(self):
-        scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="linear")
-        closed_form_scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="linear")
+    def test_closed_form_linearlr(self):
+        scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4)
+        closed_form_scheduler = LinearLR(self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4)
         self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
 
-    def test_closed_form_constant_warmup_lr(self):
-        scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="constant")
-        closed_form_scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="constant")
+    def test_closed_form_constantlr(self):
+        scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
+        closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
         self._test_against_closed_form(scheduler, closed_form_scheduler, 20)
 
     def test_closed_form_multi_step_lr(self):
@@ -1265,7 +1267,7 @@ class TestLRScheduler(TestCase):
         epochs = 10
         schedulers = [None] * 1
         targets = [[0.02, 0.03, 0.04] + [0.05] * 9]
-        schedulers[0] = WarmUpLR(self.opt, warmup_factor=0.4, warmup_iters=3, warmup_method="linear")
+        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
         scheduler = ChainedScheduler(schedulers)
         self._test([scheduler], targets, epochs)
 
@@ -1273,7 +1275,7 @@ class TestLRScheduler(TestCase):
         epochs = 10
         schedulers = [None] * 2
         targets = [[0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3]
-        schedulers[0] = WarmUpLR(self.opt, warmup_factor=0.4, warmup_iters=3, warmup_method="linear")
+        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
         schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1)
         scheduler = ChainedScheduler(schedulers)
         self._test([scheduler], targets, epochs)
@@ -1286,7 +1288,7 @@ class TestLRScheduler(TestCase):
                    + [0.05 * 0.9 ** x * 0.1 for x in range(4, 6)]
                    + [0.05 * 0.9 ** x * 0.01 for x in range(6, 9)]]
         schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
-        schedulers[1] = WarmUpLR(self.opt, warmup_factor=0.2, warmup_iters=4, warmup_method="constant")
+        schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4)
         schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3)
         scheduler = ChainedScheduler(schedulers)
         self._test([scheduler], targets, epochs)
@@ -1323,20 +1325,23 @@ class TestLRScheduler(TestCase):
         schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
         self._test(schedulers, targets, epochs)
 
-    def test_compound_exp_and_linear_warmup_lr(self):
+    def test_compound_exp_and_linearlr(self):
         epochs = 10
         iters = 4
-        factor = 0.4
+        start_factor = 0.4
+        end_factor = 0.9
         schedulers = [None] * 2
         single_targets = [0.05 * (0.9 ** x) for x in range(11)]
         for i in range(iters):
-            single_targets[i] *= factor + i / iters * (1 - factor)
+            single_targets[i] *= start_factor + i / iters * (end_factor - start_factor)
+        for i in range(iters, 11):
+            single_targets[i] *= end_factor
         targets = [single_targets, [x * epochs for x in single_targets]]
-        schedulers[0] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[0] = LinearLR(self.opt, start_factor=start_factor, end_factor=end_factor, total_iters=iters)
         schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
         self._test(schedulers, targets, epochs)
 
-    def test_compound_step_and_constant_warmup(self):
+    def test_compound_step_and_constantlr(self):
         epochs = 10
         iters = 4
         factor = 0.4
@@ -1344,20 +1349,20 @@ class TestLRScheduler(TestCase):
         single_targets = [0.05 * 0.4] * 3 + [0.005 * 0.4] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 3
         targets = [single_targets, [x * epochs for x in single_targets]]
         schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
-        schedulers[1] = WarmUpLR(self.opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")
+        schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4)
         self._test(schedulers, targets, epochs)
 
-    def test_compound_linear_warmup_and_multistep_lr(self):
+    def test_compound_linearlr_and_multistep_lr(self):
         epochs = 10
         iters = 4
-        factor = 0.4
+        start_factor = 0.4
         schedulers = [None] * 2
         single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2
         for i in range(iters):
-            single_targets[i] *= factor + i / iters * (1 - factor)
+            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
         targets = [single_targets, [x * epochs for x in single_targets]]
         schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
-        schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
         self._test(schedulers, targets, epochs)
 
     def test_compound_cosanneal_and_step_lr(self):
@@ -1387,19 +1392,19 @@ class TestLRScheduler(TestCase):
         schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
         self._test(schedulers, targets, epochs)
 
-    def test_compound_cosanneal_and_linear_warmup_lr(self):
+    def test_compound_cosanneal_and_linearlr(self):
         epochs = 10
         iters = 4
-        factor = 0.4
+        start_factor = 0.4
         eta_min = 1e-10
         schedulers = [None] * 2
         single_targets = [eta_min + (0.05 - eta_min) *
                           (1 + math.cos(math.pi * x / epochs)) / 2
                           for x in range(epochs)]
         for i in range(iters):
-            single_targets[i] *= factor + i / iters * (1 - factor)
+            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
         targets = [single_targets, [x * epochs for x in single_targets]]
-        schedulers[0] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
         schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         self._test(schedulers, targets, epochs)
 
@@ -1485,14 +1490,14 @@ class TestLRScheduler(TestCase):
 
     def test_compound_reduce_lr_on_plateau5(self):
         iters = 4
-        factor = 0.4
+        start_factor = 0.4
         epochs = 22
         for param_group in self.opt.param_groups:
             param_group['lr'] = 0.5
         single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
         multipliers = [1] * 22
         for i in range(iters):
-            multipliers[i] *= factor + i / iters * (1 - factor)
+            multipliers[i] *= start_factor + i / iters * (1 - start_factor)
         single_targets = [x * y for x, y in zip(single_targets, multipliers)]
         targets = [single_targets]
         targets = targets[1:]  # test runs step before checking lr
@@ -1500,7 +1505,7 @@ class TestLRScheduler(TestCase):
         schedulers = [None] * 2
         schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
                                           mode='min', threshold=0.1)
-        schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear")
+        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
         self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
 
     def test_cycle_lr_invalid_mode(self):
diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py
index 761a404..42f7b51 100644
--- a/torch/optim/lr_scheduler.py
+++ b/torch/optim/lr_scheduler.py
@@ -427,25 +427,78 @@ class MultiStepLR(_LRScheduler):
                 for base_lr in self.base_lrs]
 
 
-class WarmUpLR(_LRScheduler):
-    """Decays the learning rate of each parameter group by either a small constant
-    or linearly increasing small warmup factor until the number of epoch reaches a
-    pre-defined milestone: warmup_iters. Notice that such decay can happen
-    simultaneously with other changes to the learning rate from outside this scheduler.
+class ConstantLR(_LRScheduler):
+    """Decays the learning rate of each parameter group by a small constant factor until the
+    number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can
+    happen simultaneously with other changes to the learning rate from outside this scheduler.
     When last_epoch=-1, sets initial lr as lr.
 
     Args:
         optimizer (Optimizer): Wrapped optimizer.
-        warmup_factor (float): The number we multiply learning rate in the first epoch.
-            If the warming up method is constant, the multiplication factor of the
-            learning rate stays the same in all epochs, but, in the linear case, it
-            starts increasing in the following epochs. Default: 1./3.
-        warmup_iters (int): The number of warming up steps. Default: 5.
-        warmup_method (str): One of `constant` and `linear`. In `constant` mode, the
-            learning rate will be multiplied with a small constant until a milestone
-            defined in warmup_iters. In the `linear` case, the multiplication factor
-            starts with warmup_factor in the first epoch then linearly increases to
-            reach 1. in the epoch number warmup_iters. Default: `linear`.
+        factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
+        total_iters (int): The number of steps that the scheduler decays the learning rate.
+            Default: 5.
+        last_epoch (int): The index of the last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+
+    Example:
+        >>> # Assuming optimizer uses lr = 0.05 for all groups
+        >>> # lr = 0.025   if epoch == 0
+        >>> # lr = 0.025   if epoch == 1
+        >>> # lr = 0.025   if epoch == 2
+        >>> # lr = 0.025   if epoch == 3
+        >>> # lr = 0.05    if epoch >= 4
+        >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
+        >>> for epoch in range(100):
+        >>>     train(...)
+        >>>     validate(...)
+        >>>     scheduler.step()
+    """
+
+    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
+        if factor > 1.0 or factor < 0:
+            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')
+
+        self.factor = factor
+        self.total_iters = total_iters
+        super(ConstantLR, self).__init__(optimizer, last_epoch, verbose)
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn("To get the last learning rate computed by the scheduler, "
+                          "please use `get_last_lr()`.", UserWarning)
+
+        if self.last_epoch == 0:
+            return [group['lr'] * self.factor for group in self.optimizer.param_groups]
+
+        if (self.last_epoch > self.total_iters or
+                (self.last_epoch != self.total_iters)):
+            return [group['lr'] for group in self.optimizer.param_groups]
+
+        if (self.last_epoch == self.total_iters):
+            return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups]
+
+    def _get_closed_form_lr(self):
+        return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
+                for base_lr in self.base_lrs]
+
+
+class LinearLR(_LRScheduler):
+    """Decays the learning rate of each parameter group by linearly changing small
+    multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
+    Notice that such decay can happen simultaneously with other changes to the learning rate
+    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        start_factor (float): The number we multiply learning rate in the first epoch.
+            The multiplication factor changes towards end_factor in the following epochs.
+            Default: 1./3.
+        end_factor (float): The number we multiply learning rate at the end of linear changing
+            process. Default: 1.0.
+        total_iters (int): The number of iterations that multiplicative factor reaches to 1.
+            Default: 5.
         last_epoch (int): The index of the last epoch. Default: -1.
         verbose (bool): If ``True``, prints a message to stdout for
             each update. Default: ``False``.
@@ -457,24 +510,25 @@ class WarmUpLR(_LRScheduler):
         >>> # lr = 0.0375   if epoch == 2
         >>> # lr = 0.04375  if epoch == 3
         >>> # lr = 0.05     if epoch >= 4
-        >>> scheduler = WarmUpLR(self.opt, warmup_factor=0.5, warmup_iters=4, warmup_method="linear")
+        >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
         >>> for epoch in range(100):
         >>>     train(...)
         >>>     validate(...)
         >>>     scheduler.step()
     """
 
-    def __init__(self, optimizer, warmup_factor=1.0 / 3, warmup_iters=5, warmup_method="linear",
-                 last_epoch=-1, verbose=False):
-        if warmup_method not in ("constant", "linear"):
-            raise ValueError(
-                "Only 'constant' or 'linear' warmup_method accepted, but "
-                "got {}".format(warmup_method)
-            )
-        self.warmup_factor = warmup_factor
-        self.warmup_iters = warmup_iters
-        self.warmup_method = warmup_method
-        super(WarmUpLR, self).__init__(optimizer, last_epoch, verbose)
+    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
+                 verbose=False):
+        if start_factor > 1.0 or start_factor < 0:
+            raise ValueError('Starting multiplicative factor expected to be between 0 and 1.')
+
+        if end_factor > 1.0 or end_factor < 0:
+            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')
+
+        self.start_factor = start_factor
+        self.end_factor = end_factor
+        self.total_iters = total_iters
+        super(LinearLR, self).__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self):
         if not self._get_lr_called_within_step:
@@ -482,25 +536,18 @@ class WarmUpLR(_LRScheduler):
                           "please use `get_last_lr()`.", UserWarning)
 
         if self.last_epoch == 0:
-            return [group['lr'] * self.warmup_factor for group in self.optimizer.param_groups]
+            return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]
 
-        if (self.last_epoch > self.warmup_iters or
-                (self.warmup_method == "constant" and self.last_epoch != self.warmup_iters)):
+        if (self.last_epoch > self.total_iters):
             return [group['lr'] for group in self.optimizer.param_groups]
 
-        if (self.warmup_method == "constant" and self.last_epoch == self.warmup_iters):
-            return [group['lr'] * (1.0 / self.warmup_factor) for group in self.optimizer.param_groups]
-
-        return [group['lr'] * (1. + (1.0 - self.warmup_factor) /
-                (self.warmup_iters * self.warmup_factor + (self.last_epoch - 1) * (1 - self.warmup_factor)))
+        return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
+                (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
                 for group in self.optimizer.param_groups]
 
     def _get_closed_form_lr(self):
-        return [base_lr * (self.warmup_factor +
-                (1 - self.warmup_factor) * min(self.warmup_iters, self.last_epoch) /
-                self.warmup_iters * (self.warmup_method == "linear") +
-                (self.last_epoch >= self.warmup_iters) * (1 - self.warmup_factor) *
-                (self.warmup_method == "constant"))
+        return [base_lr * (self.start_factor +
+                (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
                 for base_lr in self.base_lrs]
 
 
@@ -618,7 +665,7 @@ class ChainedScheduler(_LRScheduler):
         >>> # lr = 0.729    if epoch == 2
         >>> # lr = 0.6561   if epoch == 3
         >>> # lr = 0.59049  if epoch >= 4
-        >>> scheduler1 = WarmUpLR(self.opt, warmup_factor=0.1, warmup_iters=2, warmup_method="constant")
+        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
         >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
         >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
         >>> for epoch in range(100):
diff --git a/torch/optim/lr_scheduler.pyi b/torch/optim/lr_scheduler.pyi
index 821407e..9b1b8ea 100644
--- a/torch/optim/lr_scheduler.pyi
+++ b/torch/optim/lr_scheduler.pyi
@@ -18,8 +18,11 @@ class StepLR(_LRScheduler):
 class MultiStepLR(_LRScheduler):
     def __init__(self, optimizer: Optimizer, milestones: Iterable[int], gamma: float=..., last_epoch: int=...) -> None: ...
 
-class WarmUpLR(_LRScheduler):
-    def __init__(self, optimizer: Optimizer, warmup_factor: float=..., warmup_iters: int=..., warmup_method: str=..., last_epoch: int=...) -> None: ...
+class ConstantLR(_LRScheduler):
+    def __init__(self, optimizer: Optimizer, factor: float=..., total_iters: int=..., last_epoch: int=...) -> None: ...
+
+class LinearLR(_LRScheduler):
+    def __init__(self, optimizer: Optimizer, start_factor: float=..., end_factor: float=..., total_iters: int=..., last_epoch: int=...) -> None: ...
 
 class ExponentialLR(_LRScheduler):
     def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ...