From cec08e7032cca31c643e826a88733c95d066ace9 Mon Sep 17 00:00:00 2001
From: Ilqar Ramazanli
Date: Sun, 15 Aug 2021 12:30:18 -0700
Subject: [PATCH] To add warm-up scheduler to optim (#60836)

Summary:
Warm-up of learning rate scheduling was initially discussed by Priya Goyal et al. in the paper https://arxiv.org/pdf/1706.02677.pdf . In Section 2.2 of the paper they propose warming up the learning rate schedule in order to prevent large variance / noise in the early stages of training. The idea has been further discussed in the following papers:

* Akhilesh Gotmare et al.: https://arxiv.org/abs/1810.13243
* Bernstein et al.: http://proceedings.mlr.press/v80/bernstein18a/bernstein18a.pdf
* Liyuan Liu et al.: https://arxiv.org/pdf/1908.03265.pdf

Two types of learning rate warm-up are in popular use:

* Constant warmup (start with a very small constant learning rate)
* Linear warmup (start with a small learning rate and gradually increase it)

In this PR we add warm-up as a learning rate scheduler. Note that learning rate schedulers are chainable, so the warmup scheduler can be combined with any other learning rate scheduler to build a more sophisticated schedule.

## Linear Warmup

Linear warmup multiplies the learning rate by a pre-defined constant, warmup_factor, in the first epoch (epoch 0), and then increases that multiplication constant linearly so that it reaches 1 at epoch warmup_iters. Hence at the i-th epoch the multiplication constant is

warmup_factor + (1 - warmup_factor) * i / warmup_iters

Moreover, the ratio of this quantity at step i to its value at step i-1 is

1 + (1 - warmup_factor) / [warmup_iters * warmup_factor + (i - 1) * (1 - warmup_factor)]

which is the factor applied in the get_lr() method of our implementation. The short sketch right below verifies this identity numerically; after it, we show how to use the linear warmup scheduler and the learning rates it produces.
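As a quick sanity check on the algebra above, here is a small standalone sketch (plain Python, not part of the patch below and not using the new scheduler class) that computes the warmup multiplication constants both from the closed form and from the step-to-step ratio used by get_lr(), and confirms they agree for the same settings as the usage example that follows (base lr 0.1, warmup_factor 0.1, warmup_iters 10):

```python
# Standalone sketch: check that the recursive ratio used by get_lr()
# reproduces the closed-form linear warmup factor.
base_lr = 0.1
warmup_factor = 0.1
warmup_iters = 10

# Closed form: factor_i = warmup_factor + (1 - warmup_factor) * i / warmup_iters
closed_form = [warmup_factor + (1 - warmup_factor) * i / warmup_iters
               for i in range(warmup_iters + 1)]

# Recursive form: start at warmup_factor, then at step i multiply by
# 1 + (1 - warmup_factor) / (warmup_iters * warmup_factor + (i - 1) * (1 - warmup_factor))
recursive = [warmup_factor]
for i in range(1, warmup_iters + 1):
    ratio = 1 + (1 - warmup_factor) / (
        warmup_iters * warmup_factor + (i - 1) * (1 - warmup_factor))
    recursive.append(recursive[-1] * ratio)

for i, (c, r) in enumerate(zip(closed_form, recursive)):
    assert abs(c - r) < 1e-9
    # Same learning rates (up to floating point rounding) as printed by the example below.
    print(i, base_lr * c)
```

Because get_lr() applies this ratio to whatever learning rate is currently in the optimizer, the warmup can be chained with other schedulers, as the compound tests in this patch exercise. The example below shows the scheduler itself in use: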
```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=10, warmup_method="linear")

for epoch in range(15):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
```

```
0 0.010000000000000002
1 0.019000000000000003
2 0.028000000000000008
3 0.03700000000000001
4 0.04600000000000001
5 0.055000000000000014
6 0.06400000000000002
7 0.07300000000000002
8 0.08200000000000003
9 0.09100000000000004
10 0.10000000000000005
11 0.10000000000000005
12 0.10000000000000005
13 0.10000000000000005
14 0.10000000000000005
```

## Constant Warmup

The idea of constant warmup is straightforward: multiply the learning rate by warmup_factor for the first warmup_iters epochs, then use the unscaled learning rate from epoch warmup_iters on.

```python
import torch
from torch.nn import Parameter
from torch.optim import SGD
from torch.optim.lr_scheduler import WarmUpLR

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = WarmUpLR(optimizer, warmup_factor=0.1, warmup_iters=5, warmup_method="constant")

for epoch in range(10):
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
```

```
0 0.010000000000000002
1 0.010000000000000002
2 0.010000000000000002
3 0.010000000000000002
4 0.010000000000000002
5 0.10000000000000002
6 0.10000000000000002
7 0.10000000000000002
8 0.10000000000000002
9 0.10000000000000002
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60836
Reviewed By: saketh-are
Differential Revision: D29537615
Pulled By: iramazanli
fbshipit-source-id: d910946027acc52663b301f9c56ade686e62cb69
---
 docs/source/optim.rst        |   1 +
 test/test_optim.py           | 195 ++++++++++++++++++++++++++++++++++++++++++-
 torch/optim/lr_scheduler.py  |  77 +++++++++++++++++
 torch/optim/lr_scheduler.pyi |   3 +
 4 files changed, 275 insertions(+), 1 deletion(-)

diff --git a/docs/source/optim.rst b/docs/source/optim.rst
index b7252f0..2ded57f 100644
--- a/docs/source/optim.rst
+++ b/docs/source/optim.rst
@@ -210,6 +210,7 @@ algorithms.
lr_scheduler.MultiplicativeLR lr_scheduler.StepLR lr_scheduler.MultiStepLR + lr_scheduler.WarmUpLR lr_scheduler.ExponentialLR lr_scheduler.CosineAnnealingLR lr_scheduler.ReduceLROnPlateau diff --git a/test/test_optim.py b/test/test_optim.py index a3c1374..20b8e5c 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -12,7 +12,7 @@ from torch.optim import SGD from torch.autograd import Variable from torch import sparse from torch.optim.lr_scheduler import LambdaLR, MultiplicativeLR, StepLR, \ - MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \ + MultiStepLR, WarmUpLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, \ _LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR from torch.optim.swa_utils import AveragedModel, SWALR, update_bn from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests, \ @@ -274,6 +274,19 @@ class TestOptim(TestCase): ) self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3), + [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3), + [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3), + [lambda opt: StepLR(opt, gamma=0.9, step_size=10), + lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4)] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3), [lambda opt: StepLR(opt, gamma=0.9, step_size=10), lambda opt: ReduceLROnPlateau(opt)] ) @@ -414,6 +427,23 @@ class TestOptim(TestCase): [lambda opt: ExponentialLR(opt, gamma=0.9)] ) self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3), + [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="linear")] + ) + self._test_basic_cases( + lambda weight, bias: optimizer( + self._build_params_dict(weight, bias, lr=1e-2), + lr=1e-3), + [lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")] + ) + self._test_basic_cases( + lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True), + [lambda opt: ExponentialLR(opt, gamma=0.9), + lambda opt: WarmUpLR(opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant")] + ) + self._test_basic_cases( lambda weight, bias: optimizer([weight, bias], lr=1e-3, amsgrad=True), [lambda opt: ExponentialLR(opt, gamma=0.9), lambda opt: ReduceLROnPlateau(opt)] @@ -962,6 +992,14 @@ class TestLRScheduler(TestCase): scheduler = ExponentialLR(self.opt, gamma=0.9) self._test_lr_is_constant_for_constant_epoch(scheduler) + def test_constant_warmup_lr_is_constant_for_constant_epoch(self): + scheduler = WarmUpLR(self.opt, warmup_method="constant") + self._test_lr_is_constant_for_constant_epoch(scheduler) + + def test_linear_warmup_lr_is_constant_for_constant_epoch(self): + scheduler = WarmUpLR(self.opt, warmup_method="linear") + self._test_lr_is_constant_for_constant_epoch(scheduler) + def test_step_lr(self): # lr = 0.05 if epoch < 3 # lr = 0.005 if 30 <= epoch < 6 @@ -1013,6 +1051,78 @@ class TestLRScheduler(TestCase): scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) self._test_with_epoch(scheduler, targets, epochs) + def test__get_last_lr_constant_warmup_lr(self): + # lr = 0.025 if epoch < 5 + # lr = 0.05 if 5 <= epoch + epochs = 10 + single_targets = [0.025] * 5 + [0.05] * 5
+ targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant") + self._test_get_last_lr(scheduler, targets, epochs) + + def test__get_last_lr_linear_warmup_lr(self): + # lr = 0.025 if epoch == 0 + # lr = 0.03125 if epoch == 1 + # lr = 0.0375 if epoch == 2 + # lr = 0.04375 if epoch == 3 + # lr = 0.05 if 4 <= epoch + epochs = 10 + factor = 1.0 / 2 + iters = 4 + interpolation = [factor + i * (1 - factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + self._test_get_last_lr(scheduler, targets, epochs) + + def test__constant_warmup_lr(self): + # lr = 0.025 if epoch < 5 + # lr = 0.05 if 5 <= epoch + epochs = 10 + single_targets = [0.025] * 5 + [0.05] * 5 + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant") + self._test(scheduler, targets, epochs) + + def test__linear_warmup_lr(self): + # lr = 0.025 if epoch == 0 + # lr = 0.03125 if epoch == 1 + # lr = 0.0375 if epoch == 2 + # lr = 0.04375 if epoch == 3 + # lr = 0.05 if 4 <= epoch + epochs = 10 + factor = 1.0 / 2 + iters = 4 + interpolation = [factor + i * (1 - factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + self._test(scheduler, targets, epochs) + + def test_constant_warmup_with_epoch(self): + # lr = 0.025 if epoch < 5 + # lr = 0.05 if 5 <= epoch + epochs = 10 + single_targets = [0.025] * 5 + [0.05] * 5 + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 2, warmup_iters=5, warmup_method="constant") + self._test_with_epoch(scheduler, targets, epochs) + + def test_linear_warmup_with_epoch(self): + # lr = 0.025 if epoch == 0 + # lr = 0.03125 if epoch == 1 + # lr = 0.0375 if epoch == 2 + # lr = 0.04375 if epoch == 3 + # lr = 0.05 if 4 <= epoch + epochs = 10 + factor = 1.0 / 2 + iters = 4 + interpolation = [factor + i * (1 - factor) / iters for i in range(iters)] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + self._test_with_epoch(scheduler, targets, epochs) + def test_exp_lr(self): epochs = 10 single_targets = [0.05 * (0.9 ** x) for x in range(epochs)] @@ -1035,6 +1145,16 @@ class TestLRScheduler(TestCase): closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3) self._test_against_closed_form(scheduler, closed_form_scheduler, 20) + def test_closed_form_linear_warmup_lr(self): + scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="linear") + closed_form_scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="linear") + self._test_against_closed_form(scheduler, closed_form_scheduler, 20) + + def test_closed_form_constant_warmup_lr(self): + scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="constant") +
closed_form_scheduler = WarmUpLR(self.opt, warmup_factor=1.0 / 3, warmup_iters=4, warmup_method="constant") + self._test_against_closed_form(scheduler, closed_form_scheduler, 20) + def test_closed_form_multi_step_lr(self): scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) @@ -1165,6 +1285,43 @@ class TestLRScheduler(TestCase): schedulers[1] = ExponentialLR(self.opt, gamma=0.9) self._test(schedulers, targets, epochs) + def test_compound_exp_and_linear_warmup_lr(self): + epochs = 10 + iters = 4 + factor = 0.4 + schedulers = [None] * 2 + single_targets = [0.05 * (0.9 ** x) for x in range(11)] + for i in range(iters): + single_targets[i] *= factor + i / iters * (1 - factor) + targets = [single_targets, [x * epochs for x in single_targets]] + schedulers[0] = ExponentialLR(self.opt, gamma=0.9) + schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + self._test(schedulers, targets, epochs) + + def test_compound_step_and_constant_warmup(self): + epochs = 10 + iters = 4 + factor = 0.4 + schedulers = [None] * 2 + single_targets = [0.05 * 0.4] * 3 + [0.005 * 0.4] + [0.005] * 2 + [0.0005] * 3 + [0.00005] * 3 + targets = [single_targets, [x * epochs for x in single_targets]] + schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3) + schedulers[1] = WarmUpLR(self.opt, warmup_factor=0.4, warmup_iters=4, warmup_method="constant") + self._test(schedulers, targets, epochs) + + def test_compound_linear_warmup_and_multistep_lr(self): + epochs = 10 + iters = 4 + factor = 0.4 + schedulers = [None] * 2 + single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2 + for i in range(iters): + single_targets[i] *= factor + i / iters * (1 - factor) + targets = [single_targets, [x * epochs for x in single_targets]] + schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) + schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + self._test(schedulers, targets, epochs) + def test_compound_cosanneal_and_step_lr(self): epochs = 10 eta_min = 1e-10 @@ -1192,6 +1349,22 @@ class TestLRScheduler(TestCase): schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]) self._test(schedulers, targets, epochs) + def test_compound_cosanneal_and_linear_warmup_lr(self): + epochs = 10 + iters = 4 + factor = 0.4 + eta_min = 1e-10 + schedulers = [None] * 2 + single_targets = [eta_min + (0.05 - eta_min) * + (1 + math.cos(math.pi * x / epochs)) / 2 + for x in range(epochs)] + for i in range(iters): + single_targets[i] *= factor + i / iters * (1 - factor) + targets = [single_targets, [x * epochs for x in single_targets]] + schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min) + schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + self._test(schedulers, targets, epochs) + def test_compound_cosanneal_and_exp_lr(self): epochs = 10 eta_min = 1e-10 @@ -1272,6 +1445,26 @@ class TestLRScheduler(TestCase): schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min) self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) + def test_compound_reduce_lr_on_plateau5(self): + iters = 4 + factor = 0.4 + epochs = 22 + for param_group in self.opt.param_groups: + param_group['lr'] = 0.5 + single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2 + multipliers = [1] * 22 + for i in range(iters): + multipliers[i] *= factor + i / iters * 
(1 - factor) + single_targets = [x * y for x, y in zip(single_targets, multipliers)] + targets = [single_targets] + targets = targets[1:] # test runs step before checking lr + metrics = [10 - i * 0.0165 for i in range(22)] + schedulers = [None] * 2 + schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs', + mode='min', threshold=0.1) + schedulers[1] = WarmUpLR(self.opt, warmup_factor=factor, warmup_iters=iters, warmup_method="linear") + self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs) + def test_cycle_lr_invalid_mode(self): with self.assertRaises(ValueError): scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS") diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 043a821..78a8cfa 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -427,6 +427,83 @@ class MultiStepLR(_LRScheduler): for base_lr in self.base_lrs] +class WarmUpLR(_LRScheduler): + """Decays the learning rate of each parameter group by either a small constant + or linearly increasing small warmup factor until the number of epochs reaches a + pre-defined milestone: warmup_iters. Notice that such decay can happen + simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_factor (float): The number we multiply the learning rate by in the first epoch. + If the warming up method is constant, the multiplication factor of the + learning rate stays the same in all epochs, but, in the linear case, it + starts increasing in the following epochs. Default: 1./3. + warmup_iters (int): The number of warming up steps. Default: 5. + warmup_method (str): One of `constant` and `linear`. In `constant` mode, the + learning rate will be multiplied with a small constant until a milestone + defined in warmup_iters. In the `linear` case, the multiplication factor + starts with warmup_factor in the first epoch then linearly increases to + reach 1. in the epoch number warmup_iters. Default: `linear`. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = WarmUpLR(optimizer, warmup_factor=0.5, warmup_iters=4, warmup_method="linear") + >>> for epoch in range(100): + >>> train(...) + >>> validate(...)
+ >>> scheduler.step() + """ + + def __init__(self, optimizer, warmup_factor=1.0 / 3, warmup_iters=5, warmup_method="linear", + last_epoch=-1, verbose=False): + if warmup_method not in ("constant", "linear"): + raise ValueError( + "Only 'constant' or 'linear' warmup_method accepted, but " + "got {}".format(warmup_method) + ) + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super(WarmUpLR, self).__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.warmup_factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.warmup_iters or + (self.warmup_method == "constant" and self.last_epoch != self.warmup_iters)): + return [group['lr'] for group in self.optimizer.param_groups] + + if (self.warmup_method == "constant" and self.last_epoch == self.warmup_iters): + return [group['lr'] * (1.0 / self.warmup_factor) for group in self.optimizer.param_groups] + + return [group['lr'] * (1. + (1.0 - self.warmup_factor) / + (self.warmup_iters * self.warmup_factor + (self.last_epoch - 1) * (1 - self.warmup_factor))) + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.warmup_factor + + (1 - self.warmup_factor) * min(self.warmup_iters, self.last_epoch) / + self.warmup_iters * (self.warmup_method == "linear") + + (self.last_epoch >= self.warmup_iters) * (1 - self.warmup_factor) * + (self.warmup_method == "constant")) + for base_lr in self.base_lrs] + + class ExponentialLR(_LRScheduler): """Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr. diff --git a/torch/optim/lr_scheduler.pyi b/torch/optim/lr_scheduler.pyi index 1c49c6e..821407e 100644 --- a/torch/optim/lr_scheduler.pyi +++ b/torch/optim/lr_scheduler.pyi @@ -18,6 +18,9 @@ class StepLR(_LRScheduler): class MultiStepLR(_LRScheduler): def __init__(self, optimizer: Optimizer, milestones: Iterable[int], gamma: float=..., last_epoch: int=...) -> None: ... +class WarmUpLR(_LRScheduler): + def __init__(self, optimizer: Optimizer, warmup_factor: float=..., warmup_iters: int=..., warmup_method: str=..., last_epoch: int=...) -> None: ... + class ExponentialLR(_LRScheduler): def __init__(self, optimizer: Optimizer, gamma: float, last_epoch: int=...) -> None: ... -- 2.7.4