Adds Cyclical Learning Rate and Momentum (#18001)
authorSam Pepose <sampepose@fb.com>
Thu, 28 Mar 2019 02:47:43 +0000 (19:47 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Mar 2019 02:56:04 +0000 (19:56 -0700)
Summary:
This implements a cyclical learning rate (CLR) schedule with an optional inverse cyclical momentum. More info about CLR: https://github.com/bckenstler/CLR

This is finishing what #2016 started. Resolves #1909.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18001

Differential Revision: D14451845

Pulled By: sampepose

fbshipit-source-id: 8f682e0c3dee3a73bd2b14cc93fcf5f0e836b8c9

docs/source/optim.rst
test/test_optim.py
torch/optim/lr_scheduler.py

index 9ba6c39..db8af0f 100644 (file)
@@ -145,3 +145,5 @@ allows dynamic learning rate reducing based on some validation measurements.
     :members:
 .. autoclass:: torch.optim.lr_scheduler.ReduceLROnPlateau
     :members:
+.. autoclass:: torch.optim.lr_scheduler.CyclicLR
+    :members:
index 82fee18..c6c51e4 100644 (file)
@@ -11,7 +11,8 @@ from torch.optim import SGD
 from torch.autograd import Variable
 from torch import sparse
 from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
-    ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler
+    ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler, \
+    CyclicLR
 from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -790,6 +791,165 @@ class TestLRScheduler(TestCase):
         schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
         self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
 
+    def test_cycle_lr_invalid_mode(self):
+        with self.assertRaises(ValueError):
+            scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS")
+
+    def test_cycle_lr_triangular_mode_one_lr(self):
+        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
+        momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
+                             cycle_momentum=True, base_momentum=1, max_momentum=5,
+                             mode='triangular')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+    def test_cycle_lr_triangular_mode_one_lr_no_momentum(self):
+        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
+        lr_targets = [lr_target, lr_target]
+        momentum_target = [self.opt.defaults['momentum']] * len(lr_target)
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
+                             cycle_momentum=False, mode='triangular')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+    def test_cycle_lr_triangular2_mode_one_lr(self):
+        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5,
+                     1, 1.25, 1.50, 1.75, 2.00, 1.75]
+        momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0,
+                           3.5, 3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25]
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5, step_size_up=4,
+                             cycle_momentum=True, base_momentum=1, max_momentum=5,
+                             mode='triangular2')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+    def test_cycle_lr_exp_range_mode_one_lr(self):
+        base_lr, max_lr = 1, 5
+        diff_lr = max_lr - base_lr
+        gamma = 0.9
+        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
+        lr_target = list(map(lambda x: base_lr + x[1] * diff_lr * gamma**x[0], enumerate(xs)))
+        momentum_target = list(map(lambda x: max_lr - x[1] * diff_lr * gamma**x[0], enumerate(xs)))
+        lr_targets = [lr_target, lr_target]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = CyclicLR(self.opt, base_lr=base_lr,
+                             max_lr=max_lr, step_size_up=4,
+                             cycle_momentum=True, base_momentum=base_lr, max_momentum=max_lr,
+                             mode='exp_range', gamma=gamma)
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+    def test_cycle_lr_triangular_mode(self):
+        lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
+        lr_target_2 = list(map(lambda x: x + 1, lr_target_1))
+        lr_targets = [lr_target_1, lr_target_2]
+        momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
+        momentum_target_2 = list(map(lambda x: x + 1, momentum_target_1))
+        momentum_targets = [momentum_target_1, momentum_target_2]
+        scheduler = CyclicLR(self.opt, base_lr=[1, 2], max_lr=[5, 6], step_size_up=4,
+                             cycle_momentum=True, base_momentum=[1, 2], max_momentum=[5, 6],
+                             mode='triangular')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
+
+    def test_cycle_lr_triangular2_mode(self):
+        lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 1.5, 2.0, 2.5, 3.0, 2.5, 2.0, 1.5, 1,
+                       1.25, 1.50, 1.75, 2.00, 1.75]
+        lr_target_2 = list(map(lambda x: x + 2, lr_target_1))
+        lr_targets = [lr_target_1, lr_target_2]
+        momentum_target_1 = [5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.5, 4.0, 3.5,
+                             3.0, 3.5, 4.0, 4.5, 5.0, 4.75, 4.5, 4.25, 4.0, 4.25]
+        momentum_target_2 = list(map(lambda x: x + 2, momentum_target_1))
+        momentum_targets = [momentum_target_1, momentum_target_2]
+        scheduler = CyclicLR(self.opt, base_lr=[1, 3], max_lr=[5, 7], step_size_up=4,
+                             cycle_momentum=True, base_momentum=[1, 3], max_momentum=[5, 7],
+                             mode='triangular2')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
+
+    def test_cycle_lr_exp_range_mode(self):
+        base_lr_1, max_lr_1 = 1, 5
+        base_lr_2, max_lr_2 = 5, 12
+
+        diff_lr_1 = max_lr_1 - base_lr_1
+        diff_lr_2 = max_lr_2 - base_lr_2
+
+        gamma = 0.9
+        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
+        lr_target_1 = list(map(lambda x: base_lr_1 + x[1] * diff_lr_1 * gamma**x[0], enumerate(xs)))
+        lr_target_2 = list(map(lambda x: base_lr_2 + x[1] * diff_lr_2 * gamma**x[0], enumerate(xs)))
+        lr_targets = [lr_target_1, lr_target_2]
+        momentum_target_1 = list(map(lambda x: max_lr_1 - x[1] * diff_lr_1 * gamma**x[0], enumerate(xs)))
+        momentum_target_2 = list(map(lambda x: max_lr_2 - x[1] * diff_lr_2 * gamma**x[0], enumerate(xs)))
+        momentum_targets = [momentum_target_1, momentum_target_2]
+        scheduler = CyclicLR(self.opt, base_lr=[base_lr_1, base_lr_2],
+                             max_lr=[max_lr_1, max_lr_2], step_size_up=4,
+                             cycle_momentum=True, base_momentum=[base_lr_1, base_lr_2],
+                             max_momentum=[max_lr_1, max_lr_2],
+                             mode='exp_range', gamma=gamma)
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))
+
+    def test_cycle_lr_triangular_mode_step_size_up_down(self):
+        lr_target = [1.0, 2.0, 3.0, 4.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0]
+        lr_targets = [lr_target, lr_target]
+        momentum_target = [5.0, 4.0, 3.0, 2.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0]
+        momentum_targets = [momentum_target, momentum_target]
+
+        scheduler = CyclicLR(self.opt, base_lr=1, max_lr=5,
+                             step_size_up=4,
+                             step_size_down=6,
+                             cycle_momentum=True,
+                             base_momentum=1, max_momentum=5,
+                             mode='triangular')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
+    def test_cycle_lr_triangular2_mode_step_size_up_down(self):
+        lr_base_target = ([
+            1.0, 3.0, 5.0, 13.0 / 3, 11.0 / 3, 9.0 / 3, 7.0 / 3, 5.0 / 3, 1.0, 2.0, 3.0, 8.0 / 3,
+            7.0 / 3, 6.0 / 3, 5.0 / 3, 4.0 / 3, 1.0, 3.0 / 2, 2.0, 11.0 / 6, 10.0 / 6, 9.0 / 6,
+            8.0 / 6, 7.0 / 6
+        ])
+        momentum_base_target = ([
+            5.0, 3.0, 1.0, 5.0 / 3, 7.0 / 3, 3.0, 11.0 / 3, 13.0 / 3, 5.0, 4.0, 3.0, 10.0 / 3,
+            11.0 / 3, 4.0, 13.0 / 3, 14.0 / 3, 5.0, 4.5, 4.0, 25.0 / 6, 13.0 / 3, 4.5, 14.0 / 3,
+            29.0 / 6
+        ])
+        deltas = [2 * i for i in range(0, 2)]
+        base_lrs = [1 + delta for delta in deltas]
+        max_lrs = [5 + delta for delta in deltas]
+        lr_targets = [[x + delta for x in lr_base_target] for delta in deltas]
+        momentum_targets = [[x + delta for x in momentum_base_target] for delta in deltas]
+        scheduler = CyclicLR(
+            self.opt,
+            base_lr=base_lrs,
+            max_lr=max_lrs,
+            step_size_up=2,
+            step_size_down=6,
+            cycle_momentum=True,
+            base_momentum=base_lrs,
+            max_momentum=max_lrs,
+            mode='triangular2')
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_base_target))
+
+    def test_cycle_lr_exp_range_mode_step_size_up_down(self):
+        base_lr, max_lr = 1, 5
+        diff_lr = max_lr - base_lr
+        gamma = 0.9
+        xs = ([
+            0.0, 0.5, 1.0, 5.0 / 6, 4.0 / 6, 3.0 / 6, 2.0 / 6, 1.0 / 6, 0.0, 0.5, 1.0, 5.0 / 6,
+            4.0 / 6
+        ])
+        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
+        lr_targets = [lr_target, lr_target]
+        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
+        momentum_targets = [momentum_target, momentum_target]
+        scheduler = CyclicLR(self.opt, base_lr=base_lr, max_lr=max_lr,
+                             step_size_up=2, step_size_down=6,
+                             cycle_momentum=True, base_momentum=base_lr,
+                             max_momentum=max_lr,
+                             mode='exp_range', gamma=gamma)
+        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))
+
     def test_lambda_lr(self):
         epochs = 10
         self.opt.param_groups[0]['lr'] = 0.05
@@ -905,5 +1065,21 @@ class TestLRScheduler(TestCase):
                                        msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                            epoch, target[epoch], param_group['lr']), delta=1e-5)
 
+    def _test_cycle_lr(self, scheduler, lr_targets, momentum_targets, batch_iterations, verbose=False):
+        for batch_num in range(batch_iterations):
+            scheduler.step(batch_num)
+            if verbose:
+                print('batch{}:\tlr={},momentum={}'.format(batch_num, self.opt.param_groups[0]['lr'],
+                                                           self.opt.param_groups[0]['momentum']))
+            for param_group, lr_target, momentum_target in zip(self.opt.param_groups, lr_targets, momentum_targets):
+                self.assertAlmostEqual(
+                    lr_target[batch_num], param_group['lr'],
+                    msg='LR is wrong in batch_num {}: expected {}, got {}'.format(
+                        batch_num, lr_target[batch_num], param_group['lr']), delta=1e-5)
+                self.assertAlmostEqual(
+                    momentum_target[batch_num], param_group['momentum'],
+                    msg='Momentum is wrong in batch_num {}: expected {}, got {}'.format(
+                        batch_num, momentum_target[batch_num], param_group['momentum']), delta=1e-5)
+
 if __name__ == '__main__':
     run_tests()
index 200e2c6..3650794 100644 (file)
@@ -4,6 +4,7 @@ import torch
 from torch._six import inf
 from collections import Counter
 from functools import partial
+
 from .optimizer import Optimizer
 
 
@@ -427,3 +428,216 @@ class ReduceLROnPlateau(object):
     def load_state_dict(self, state_dict):
         self.__dict__.update(state_dict)
         self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
+
+
+class CyclicLR(_LRScheduler):
+    """Sets the learning rate of each parameter group according to
+    cyclical learning rate policy (CLR). The policy cycles the learning
+    rate between two boundaries with a constant frequency, as detailed in
+    the paper `Cyclical Learning Rates for Training Neural Networks`_.
+    The distance between the two boundaries can be scaled on a per-iteration
+    or per-cycle basis.
+
+    Cyclical learning rate policy changes the learning rate after every batch.
+    `step` should be called after a batch has been used for training.
+
+    This class has three built-in policies, as put forth in the paper:
+    "triangular":
+        A basic triangular cycle w/ no amplitude scaling.
+    "triangular2":
+        A basic triangular cycle that scales initial amplitude by half each cycle.
+    "exp_range":
+        A cycle that scales initial amplitude by gamma**(cycle iterations) at each
+        cycle iteration.
+
+    This implementation was adapted from the github repo: `bckenstler/CLR`_
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        base_lr (float or list): Initial learning rate which is the
+            lower boundary in the cycle for each parameter group.
+        max_lr (float or list): Upper learning rate boundaries in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_lr - base_lr).
+            The lr at any cycle is the sum of base_lr
+            and some scaling of the amplitude; therefore
+            max_lr may not actually be reached depending on
+            scaling function.
+        step_size_up (int): Number of training iterations in the
+            increasing half of a cycle. Default: 2000
+        step_size_down (int): Number of training iterations in the
+            decreasing half of a cycle. If step_size_down is None,
+            it is set to step_size_up. Default: None
+        mode (str): One of {triangular, triangular2, exp_range}.
+            Values correspond to policies detailed above.
+            If scale_fn is not None, this argument is ignored.
+            Default: 'triangular'
+        gamma (float): Constant in 'exp_range' scaling function:
+            gamma**(cycle iterations)
+            Default: 1.0
+        scale_fn (function): Custom scaling policy defined by a single
+            argument lambda function, where
+            0 <= scale_fn(x) <= 1 for all x >= 0.
+            If specified, then 'mode' is ignored.
+            Default: None
+        scale_mode (str): {'cycle', 'iterations'}.
+            Defines whether scale_fn is evaluated on
+            cycle number or cycle iterations (training
+            iterations since start of cycle).
+            Default: 'cycle'
+        cycle_momentum (bool): If ``True``, momentum is cycled inversely
+            to learning rate between 'base_momentum' and 'max_momentum'.
+            Default: True
+        base_momentum (float or list): Initial momentum which is the
+            lower boundary in the cycle for each parameter group.
+            Default: 0.8
+        max_momentum (float or list): Upper momentum boundaries in the cycle
+            for each parameter group. Functionally,
+            it defines the cycle amplitude (max_momentum - base_momentum).
+            The momentum at any cycle is the difference of max_momentum
+            and some scaling of the amplitude; therefore
+            base_momentum may not actually be reached depending on
+            scaling function. Default: 0.9
+        last_epoch (int): The index of the last batch. This parameter is used when
+            resuming a training job. Since `step()` should be invoked after each
+            batch instead of after each epoch, this number represents the total
+            number of *batches* computed, not the total number of epochs computed.
+            When last_epoch=-1, the schedule is started from the beginning.
+            Default: -1
+
+    Example:
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        >>> scheduler = torch.optim.CyclicLR(optimizer)
+        >>> data_loader = torch.utils.data.DataLoader(...)
+        >>> for epoch in range(10):
+        >>>     for batch in data_loader:
+        >>>         train_batch(...)
+        >>>         scheduler.step()
+
+
+    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
+    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
+    """
+
+    def __init__(self,
+                 optimizer,
+                 base_lr,
+                 max_lr,
+                 step_size_up=2000,
+                 step_size_down=None,
+                 mode='triangular',
+                 gamma=1.,
+                 scale_fn=None,
+                 scale_mode='cycle',
+                 cycle_momentum=True,
+                 base_momentum=0.8,
+                 max_momentum=0.9,
+                 last_epoch=-1):
+
+        if not isinstance(optimizer, Optimizer):
+            raise TypeError('{} is not an Optimizer'.format(
+                type(optimizer).__name__))
+        self.optimizer = optimizer
+
+        base_lrs = self._format_param('base_lr', optimizer, base_lr)
+        if last_epoch == -1:
+            for lr, group in zip(base_lrs, optimizer.param_groups):
+                group['lr'] = lr
+
+        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)
+
+        step_size_up = float(step_size_up)
+        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
+        self.total_size = step_size_up + step_size_down
+        self.step_ratio = step_size_up / self.total_size
+
+        if mode not in ['triangular', 'triangular2', 'exp_range'] \
+                and scale_fn is None:
+            raise ValueError('mode is invalid and scale_fn is None')
+
+        self.mode = mode
+        self.gamma = gamma
+
+        if scale_fn is None:
+            if self.mode == 'triangular':
+                self.scale_fn = self._triangular_scale_fn
+                self.scale_mode = 'cycle'
+            elif self.mode == 'triangular2':
+                self.scale_fn = self._triangular2_scale_fn
+                self.scale_mode = 'cycle'
+            elif self.mode == 'exp_range':
+                self.scale_fn = self._exp_range_scale_fn
+                self.scale_mode = 'iterations'
+        else:
+            self.scale_fn = scale_fn
+            self.scale_mode = scale_mode
+
+        self.cycle_momentum = cycle_momentum
+        if cycle_momentum:
+            if 'momentum' not in optimizer.defaults:
+                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
+
+            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
+            if last_epoch == -1:
+                for momentum, group in zip(base_momentums, optimizer.param_groups):
+                    group['momentum'] = momentum
+        self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
+        self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
+
+        super(CyclicLR, self).__init__(optimizer, last_epoch)
+
+    def _format_param(self, name, optimizer, param):
+        """Return correctly formatted lr/momentum for each param group."""
+        if isinstance(param, (list, tuple)):
+            if len(param) != len(optimizer.param_groups):
+                raise ValueError("expected {} values for {}, got {}".format(
+                    len(optimizer.param_groups), name, len(param)))
+            return param
+        else:
+            return [param] * len(optimizer.param_groups)
+
+    def _triangular_scale_fn(self, x):
+        return 1.
+
+    def _triangular2_scale_fn(self, x):
+        return 1 / (2. ** (x - 1))
+
+    def _exp_range_scale_fn(self, x):
+        return self.gamma**(x)
+
+    def get_lr(self):
+        """Calculates the learning rate at batch index. This function treats
+        `self.last_epoch` as the last batch index.
+
+        If `self.cycle_momentum` is ``True``, this function has a side effect of
+        updating the optimizer's momentum.
+        """
+        cycle = math.floor(1 + self.last_epoch / self.total_size)
+        x = 1. + self.last_epoch / self.total_size - cycle
+        if x <= self.step_ratio:
+            scale_factor = x / self.step_ratio
+        else:
+            scale_factor = (x - 1) / (self.step_ratio - 1)
+
+        lrs = []
+        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
+            base_height = (max_lr - base_lr) * scale_factor
+            if self.scale_mode == 'cycle':
+                lr = base_lr + base_height * self.scale_fn(cycle)
+            else:
+                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
+            lrs.append(lr)
+
+        if self.cycle_momentum:
+            momentums = []
+            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
+                base_height = (max_momentum - base_momentum) * scale_factor
+                if self.scale_mode == 'cycle':
+                    momentum = max_momentum - base_height * self.scale_fn(cycle)
+                else:
+                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
+                momentums.append(momentum)
+            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
+                param_group['momentum'] = momentum
+
+        return lrs