Redefine scheduler to set learning rate using recursive formula (#14010)
authorChandler Zuo <chandlerzuo@fb.com>
Wed, 19 Dec 2018 00:40:23 +0000 (16:40 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 00:44:31 +0000 (16:44 -0800)
Summary:
Modified step_lr for StepLR, MultiStepLR, ExponentialLR and CosineAnnealingLR. In this way, multiple schedulers can be used simultaneously to modify the learning rates.

Related issue: https://github.com/pytorch/pytorch/issues/13022

Added unit tests combining multiple schedulers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14010

Reviewed By: ezyang

Differential Revision: D13494941

Pulled By: chandlerzuo

fbshipit-source-id: 7561270245639ba1f2c00748f8e4a5f7dec7160c

test/test_optim.py
torch/optim/lr_scheduler.py

index 1f73996..f21ce9e 100644 (file)
@@ -2,6 +2,7 @@ import math
 import unittest
 import functools
 from copy import deepcopy
+from bisect import bisect_right
 import torch
 from torch._six import inf
 import torch.optim as optim
@@ -9,7 +10,8 @@ import torch.nn.functional as F
 from torch.optim import SGD
 from torch.autograd import Variable
 from torch import sparse
-from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
+from torch.optim.lr_scheduler import LambdaLR, StepLR, MultiStepLR, \
+    ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau, _LRScheduler
 from common_utils import TestCase, run_tests, TEST_WITH_UBSAN, skipIfRocm, load_tests
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -28,11 +30,17 @@ def drosenbrock(tensor):
 
 
 class TestOptim(TestCase):
-    def _test_rosenbrock_sparse(self, constructor, sparse_only=False):
+    def _test_rosenbrock_sparse(self, constructor, scheduler_constructors=None,
+                                sparse_only=False):
+        if scheduler_constructors is None:
+            scheduler_constructors = []
         params_t = torch.Tensor([1.5, 1.5])
 
         params = Variable(params_t, requires_grad=True)
         optimizer = constructor([params])
+        schedulers = []
+        for scheduler_constructor in scheduler_constructors:
+            schedulers.append(scheduler_constructor(optimizer))
 
         if not sparse_only:
             params_c = Variable(params_t.clone(), requires_grad=True)
@@ -68,17 +76,25 @@ class TestOptim(TestCase):
             # Do cyclic coordinate descent
             w = i % 2
             optimizer.step(functools.partial(eval, params, True, w))
+            for scheduler in schedulers:
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    scheduler.step(rosenbrock(params))
+                else:
+                    scheduler.step()
             if not sparse_only:
                 optimizer_c.step(functools.partial(eval, params_c, False, w))
                 self.assertEqual(params.data, params_c.data)
 
         self.assertLessEqual(params.data.dist(solution), initial_dist)
 
-    def _test_basic_cases_template(self, weight, bias, input, constructor):
+    def _test_basic_cases_template(self, weight, bias, input, constructor, scheduler_constructors):
         weight = Variable(weight, requires_grad=True)
         bias = Variable(bias, requires_grad=True)
         input = Variable(input)
         optimizer = constructor(weight, bias)
+        schedulers = []
+        for scheduler_constructor in scheduler_constructors:
+            schedulers.append(scheduler_constructor(optimizer))
 
         # to check if the optimizer can be printed as a string
         optimizer.__repr__()
@@ -93,7 +109,13 @@ class TestOptim(TestCase):
             return loss
 
         initial_value = fn().item()
-        for i in range(200):
+        for _i in range(200):
+            for scheduler in schedulers:
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    val_loss = fn()
+                    scheduler.step(val_loss)
+                else:
+                    scheduler.step()
             optimizer.step(fn)
         self.assertLess(fn().item(), initial_value)
 
@@ -113,7 +135,7 @@ class TestOptim(TestCase):
         fn = functools.partial(fn_base, optimizer, weight, bias)
 
         # Prime the optimizer
-        for i in range(20):
+        for _i in range(20):
             optimizer.step(fn)
         # Clone the weights and construct new optimizer for them
         weight_c = Variable(weight.data.clone(), requires_grad=True)
@@ -125,7 +147,7 @@ class TestOptim(TestCase):
         state_dict_c = deepcopy(optimizer.state_dict())
         optimizer_c.load_state_dict(state_dict_c)
         # Run both optimizations in parallel
-        for i in range(20):
+        for _i in range(20):
             optimizer.step(fn)
             optimizer_c.step(fn_c)
             self.assertEqual(weight, weight_c)
@@ -151,13 +173,16 @@ class TestOptim(TestCase):
         # Make sure state dict wasn't modified
         self.assertEqual(state_dict, state_dict_c)
 
-        for i in range(20):
+        for _i in range(20):
             optimizer.step(fn)
             optimizer_cuda.step(fn_cuda)
             self.assertEqual(weight, weight_cuda)
             self.assertEqual(bias, bias_cuda)
 
-    def _test_basic_cases(self, constructor, ignore_multidevice=False):
+    def _test_basic_cases(self, constructor, scheduler_constructors=None,
+                          ignore_multidevice=False):
+        if scheduler_constructors is None:
+            scheduler_constructors = []
         self._test_state_dict(
             torch.randn(10, 5),
             torch.randn(10),
@@ -168,14 +193,16 @@ class TestOptim(TestCase):
             torch.randn(10, 5),
             torch.randn(10),
             torch.randn(5),
-            constructor
+            constructor,
+            scheduler_constructors
         )
         # non-contiguous parameters
         self._test_basic_cases_template(
             torch.randn(10, 5, 2)[..., 0],
             torch.randn(10, 2)[..., 0],
             torch.randn(5),
-            constructor
+            constructor,
+            scheduler_constructors
         )
         # CUDA
         if not torch.cuda.is_available():
@@ -184,7 +211,8 @@ class TestOptim(TestCase):
             torch.randn(10, 5).cuda(),
             torch.randn(10).cuda(),
             torch.randn(5).cuda(),
-            constructor
+            constructor,
+            scheduler_constructors
         )
         # Multi-GPU
         if not torch.cuda.device_count() > 1 or ignore_multidevice:
@@ -193,11 +221,12 @@ class TestOptim(TestCase):
             torch.randn(10, 5).cuda(0),
             torch.randn(10).cuda(1),
             torch.randn(5).cuda(0),
-            constructor
+            constructor,
+            scheduler_constructors
         )
 
     def _build_params_dict(self, weight, bias, **kwargs):
-        return [dict(params=[weight]), dict(params=[bias], **kwargs)]
+        return [{'params': [weight]}, dict(params=[bias], **kwargs)]
 
     def _build_params_dict_single(self, weight, bias, **kwargs):
         return [dict(params=bias, **kwargs)]
@@ -220,6 +249,21 @@ class TestOptim(TestCase):
             lambda weight, bias: optim.SGD(
                 self._build_params_dict_single(weight, bias, lr=1e-2))
         )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3),
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10)]
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3),
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.SGD([weight, bias], lr=1e-3),
+            [lambda opt: StepLR(opt, gamma=0.99, step_size=10),
+             lambda opt: ExponentialLR(opt, gamma=0.99),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
         with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
             optim.SGD(None, lr=1e-2, momentum=-0.5)
 
@@ -227,6 +271,10 @@ class TestOptim(TestCase):
         self._test_rosenbrock_sparse(
             lambda params: optim.SGD(params, lr=5e-3)
         )
+        self._test_rosenbrock_sparse(
+            lambda params: optim.SGD(params, lr=0.005),
+            [lambda opt: StepLR(opt, gamma=0.99999, step_size=300)]
+        )
 
     def test_adam(self):
         self._test_basic_cases(
@@ -246,12 +294,32 @@ class TestOptim(TestCase):
                 self._build_params_dict(weight, bias, lr=1e-2),
                 lr=1e-3, amsgrad=True)
         )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adam(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3),
+            [lambda opt: ExponentialLR(opt, gamma=0.9)]
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adam([weight, bias], lr=1e-3,
+                                            amsgrad=True),
+            [lambda opt: ExponentialLR(opt, gamma=0.9),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adam(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3, amsgrad=True),
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
         with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
             optim.Adam(None, lr=1e-2, betas=(1.0, 0.0))
 
     def test_sparse_adam(self):
         self._test_rosenbrock_sparse(
             lambda params: optim.SparseAdam(params, lr=4e-2),
+            [],
             True
         )
         with self.assertRaisesRegex(ValueError, "Invalid beta parameter at index 0: 1.0"):
@@ -265,6 +333,12 @@ class TestOptim(TestCase):
             lambda weight, bias: optim.Adadelta(
                 self._build_params_dict(weight, bias, rho=0.95))
         )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adadelta(
+                self._build_params_dict(weight, bias, rho=0.95)),
+            [lambda opt: StepLR(opt, gamma=0.9, step_size=10),
+             lambda opt: ReduceLROnPlateau(opt)]
+        )
         with self.assertRaisesRegex(ValueError, "Invalid rho value: 1.1"):
             optim.Adadelta(None, lr=1e-2, rho=1.1)
 
@@ -281,6 +355,19 @@ class TestOptim(TestCase):
                 self._build_params_dict(weight, bias, lr=1e-2),
                 lr=1e-1)
         )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adagrad(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-1),
+            [lambda opt: ReduceLROnPlateau(opt)]
+        )
+        self._test_basic_cases(
+            lambda weight, bias: optim.Adagrad(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-1),
+            [lambda opt: ReduceLROnPlateau(opt),
+             lambda opt: ExponentialLR(opt, gamma=0.99)]
+        )
         with self.assertRaisesRegex(ValueError, "Invalid lr_decay value: -0.5"):
             optim.Adagrad(None, lr=1e-2, lr_decay=-0.5)
 
@@ -288,6 +375,11 @@ class TestOptim(TestCase):
         self._test_rosenbrock_sparse(
             lambda params: optim.Adagrad(params, lr=1e-1)
         )
+        self._test_rosenbrock_sparse(
+            lambda params: optim.Adagrad(params, lr=0.1),
+            [lambda opt: StepLR(opt, gamma=1 - 1e-5, step_size=500),
+             lambda opt: ReduceLROnPlateau(opt, threshold=1e-4)]
+        )
 
     @skipIfRocm
     def test_adamax(self):
@@ -387,6 +479,36 @@ class LambdaLRTestObject:
             return False
 
 
+class LegacyStepLR(StepLR):
+    def get_lr(self):
+        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
+                for base_lr in self.base_lrs]
+
+
+class LegacyMultiStepLR(MultiStepLR):
+    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
+        self.milestones = sorted(milestones)
+        self.gamma = gamma
+        super(MultiStepLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
+                for base_lr in self.base_lrs]
+
+
+class LegacyExponentialLR(ExponentialLR):
+    def get_lr(self):
+        return [base_lr * self.gamma ** self.last_epoch
+                for base_lr in self.base_lrs]
+
+
+class LegacyCosineAnnealingLR(CosineAnnealingLR):
+    def get_lr(self):
+        return [self.eta_min + (base_lr - self.eta_min) *
+                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
+                for base_lr in self.base_lrs]
+
+
 class TestLRScheduler(TestCase):
     def setUp(self):
         self.net = SchedulerTestNet()
@@ -432,6 +554,28 @@ class TestLRScheduler(TestCase):
         scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
         self._test(scheduler, targets, epochs)
 
+    def test_legacy_step_lr(self):
+        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
+        legacy_scheduler = LegacyStepLR(self.opt, gamma=0.1, step_size=3)
+        self._test_against_legacy(scheduler, legacy_scheduler, 20)
+
+    def test_legacy_multi_step_lr(self):
+        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
+        legacy_scheduler = LegacyMultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
+        self._test_against_legacy(scheduler, legacy_scheduler, 20)
+
+    def test_legacy_exp_lr(self):
+        scheduler = ExponentialLR(self.opt, gamma=0.9)
+        legacy_scheduler = LegacyExponentialLR(self.opt, gamma=0.9)
+        self._test_against_legacy(scheduler, legacy_scheduler, 20)
+
+    def test_legacy_cos_anneal_lr(self):
+        eta_min = 1e-10
+        epochs = 20
+        scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
+        legacy_scheduler = LegacyCosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
+        self._test_against_legacy(scheduler, legacy_scheduler, epochs)
+
     def test_reduce_lr_on_plateau1(self):
         epochs = 10
         for param_group in self.opt.param_groups:
@@ -512,6 +656,141 @@ class TestLRScheduler(TestCase):
                                       threshold=0.1, patience=5, cooldown=5)
         self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)
 
+    def test_compound_step_and_multistep_lr(self):
+        epochs = 10
+        schedulers = [None] * 2
+        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
+        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
+        targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]]
+        self._test(schedulers, targets, epochs)
+
+    def test_compound_step_and_exp_lr(self):
+        epochs = 10
+        schedulers = [None] * 2
+        single_targets = [0.05 * (0.9 ** x) for x in range(3)]
+        single_targets += [0.005 * (0.9 ** x) for x in range(3, 6)]
+        single_targets += [0.0005 * (0.9 ** x) for x in range(6, 9)]
+        single_targets += [0.00005 * (0.9 ** x) for x in range(9, 12)]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
+        self._test(schedulers, targets, epochs)
+
+    def test_compound_exp_and_multistep_lr(self):
+        epochs = 10
+        schedulers = [None] * 2
+        single_targets = [0.05 * (0.9 ** x) for x in range(2)]
+        single_targets += [0.005 * (0.9 ** x) for x in range(2, 5)]
+        single_targets += [0.0005 * (0.9 ** x) for x in range(5, 9)]
+        single_targets += [0.00005 * (0.9 ** x) for x in range(9, 11)]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
+        self._test(schedulers, targets, epochs)
+
+    def test_compound_cosanneal_and_step_lr(self):
+        epochs = 10
+        eta_min = 1e-10
+        single_targets = [eta_min + (0.05 - eta_min) *
+                          (1 + math.cos(math.pi * x / epochs)) / 2
+                          for x in range(epochs)]
+        single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        schedulers = [None] * 2
+        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
+        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
+        self._test(schedulers, targets, epochs)
+
+    def test_compound_cosanneal_and_multistep_lr(self):
+        epochs = 10
+        eta_min = 1e-10
+        single_targets = [eta_min + (0.05 - eta_min) *
+                          (1 + math.cos(math.pi * x / epochs)) / 2
+                          for x in range(epochs)]
+        multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001]
+        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        schedulers = [None] * 2
+        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
+        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
+        self._test(schedulers, targets, epochs)
+
+    def test_compound_cosanneal_and_exp_lr(self):
+        epochs = 10
+        eta_min = 1e-10
+        single_targets = [eta_min + (0.05 - eta_min) *
+                          (1 + math.cos(math.pi * x / epochs)) / 2
+                          for x in range(epochs)]
+        multipliers = [0.1 ** i for i in range(epochs)]
+        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
+        targets = [single_targets, list(map(lambda x: x * epochs, single_targets))]
+        schedulers = [None] * 2
+        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
+        self._test(schedulers, targets, epochs)
+
+    def test_compound_reduce_lr_on_plateau1(self):
+        epochs = 10
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        single_targets = [0.5] * 20
+        multipliers = [0.1 ** (i // 3) for i in range(20)]
+        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
+        targets = [single_targets]
+        metrics = [10 - i * 0.0167 for i in range(20)]
+        schedulers = [None, None]
+        schedulers[0] = ReduceLROnPlateau(self.opt, threshold_mode='abs', mode='min',
+                                          threshold=0.01, patience=5, cooldown=5)
+        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
+        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
+
+    def test_compound_reduce_lr_on_plateau2(self):
+        epochs = 22
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
+        multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10
+        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
+        targets = [single_targets]
+        metrics = [10 - i * 0.0165 for i in range(22)]
+        schedulers = [None] * 2
+        schedulers[0] = ReduceLROnPlateau(self.opt, patience=5, cooldown=0, threshold_mode='abs',
+                                          mode='min', threshold=0.1)
+        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12])
+        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
+
+    def test_compound_reduce_lr_on_plateau3(self):
+        epochs = 22
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.5
+        single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4
+        multipliers = [0.1 ** i for i in range(epochs)]
+        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
+        targets = [single_targets]
+        metrics = [-0.8] * 2 + [-0.234] * 20
+        schedulers = [None, None]
+        schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=5, cooldown=5,
+                                          threshold_mode='abs')
+        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
+        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
+
+    def test_compound_reduce_lr_on_plateau4(self):
+        epochs = 20
+        for param_group in self.opt.param_groups:
+            param_group['lr'] = 0.05
+        epochs = 10
+        eta_min = 1e-10
+        single_targets = [eta_min + (0.05 - eta_min) *
+                          (1 + math.cos(math.pi * x / epochs)) / 2
+                          for x in range(epochs)]
+        targets = [single_targets]
+        metrics = [1.5 * (1.025 ** i) for i in range(20)]  # 1.025 > 1.1**0.25
+        schedulers = [None, None]
+        schedulers[0] = ReduceLROnPlateau(self.opt, mode='max', patience=3,
+                                          threshold_mode='rel', threshold=0.1)
+        schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
+        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)
+
     def test_lambda_lr(self):
         epochs = 10
         self.opt.param_groups[0]['lr'] = 0.05
@@ -587,17 +866,39 @@ class TestLRScheduler(TestCase):
                 self.assertAlmostEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
         self.assertAlmostEqual(scheduler.get_lr(), scheduler_copy.get_lr())
 
-    def _test(self, scheduler, targets, epochs=10):
+    def _test(self, schedulers, targets, epochs=10):
+        if isinstance(schedulers, _LRScheduler):
+            schedulers = [schedulers]
         for epoch in range(epochs):
-            scheduler.step(epoch)
+            [scheduler.step(epoch) for scheduler in schedulers]
             for param_group, target in zip(self.opt.param_groups, targets):
                 self.assertAlmostEqual(target[epoch], param_group['lr'],
                                        msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                            epoch, target[epoch], param_group['lr']), delta=1e-5)
 
-    def _test_reduce_lr_on_plateau(self, scheduler, targets, metrics, epochs=10, verbose=False):
+    def _test_against_legacy(self, scheduler, legacy_scheduler, epochs=10):
+        self.setUp()
+        targets = []
+        for epoch in range(epochs):
+            legacy_scheduler.step(epoch)
+            targets.append([group['lr'] for group in self.opt.param_groups])
+        self.setUp()
         for epoch in range(epochs):
-            scheduler.step(metrics[epoch])
+            scheduler.step(epoch)
+            for i, param_group in enumerate(self.opt.param_groups):
+                self.assertAlmostEqual(targets[epoch][i], param_group['lr'],
+                                       msg='LR is wrong in epoch {}: expected {}, got {}'.format(
+                                           epoch, targets[epoch][i], param_group['lr']), delta=1e-5)
+
+    def _test_reduce_lr_on_plateau(self, schedulers, targets, metrics, epochs=10, verbose=False):
+        if isinstance(schedulers, _LRScheduler) or isinstance(schedulers, ReduceLROnPlateau):
+            schedulers = [schedulers]
+        for epoch in range(epochs):
+            for scheduler in schedulers:
+                if isinstance(scheduler, ReduceLROnPlateau):
+                    scheduler.step(metrics[epoch])
+                else:
+                    scheduler.step(epoch)
             if verbose:
                 print('epoch{}:\tlr={}'.format(epoch, self.opt.param_groups[0]['lr']))
             for param_group, target in zip(self.opt.param_groups, targets):
@@ -605,6 +906,5 @@ class TestLRScheduler(TestCase):
                                        msg='LR is wrong in epoch {}: expected {}, got {}'.format(
                                            epoch, target[epoch], param_group['lr']), delta=1e-5)
 
-
 if __name__ == '__main__':
     run_tests()
index 96cfaff..200e2c6 100644 (file)
@@ -2,7 +2,7 @@ import types
 import math
 import torch
 from torch._six import inf
-from bisect import bisect_right
+from collections import Counter
 from functools import partial
 from .optimizer import Optimizer
 
@@ -124,9 +124,10 @@ class LambdaLR(_LRScheduler):
 
 
 class StepLR(_LRScheduler):
-    """Sets the learning rate of each parameter group to the initial lr
-    decayed by gamma every step_size epochs. When last_epoch=-1, sets
-    initial lr as lr.
+    """Decays the learning rate of each parameter group by gamma every
+    step_size epochs. Notice that such decay can happen simultaneously with
+    other changes to the learning rate from outside this scheduler. When
+    last_epoch=-1, sets initial lr as lr.
 
     Args:
         optimizer (Optimizer): Wrapped optimizer.
@@ -154,14 +155,17 @@ class StepLR(_LRScheduler):
         super(StepLR, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
-        return [base_lr * self.gamma ** (self.last_epoch // self.step_size)
-                for base_lr in self.base_lrs]
+        if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
+            return [group['lr'] for group in self.optimizer.param_groups]
+        return [group['lr'] * self.gamma
+                for group in self.optimizer.param_groups]
 
 
 class MultiStepLR(_LRScheduler):
-    """Set the learning rate of each parameter group to the initial lr decayed
-    by gamma once the number of epoch reaches one of the milestones. When
-    last_epoch=-1, sets initial lr as lr.
+    """Decays the learning rate of each parameter group by gamma once the
+    number of epoch reaches one of the milestones. Notice that such decay can
+    happen simultaneously with other changes to the learning rate from outside
+    this scheduler. When last_epoch=-1, sets initial lr as lr.
 
     Args:
         optimizer (Optimizer): Wrapped optimizer.
@@ -183,21 +187,20 @@ class MultiStepLR(_LRScheduler):
     """
 
     def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
-        if not list(milestones) == sorted(milestones):
-            raise ValueError('Milestones should be a list of'
-                             ' increasing integers. Got {}', milestones)
-        self.milestones = milestones
+        self.milestones = Counter(milestones)
         self.gamma = gamma
         super(MultiStepLR, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
-        return [base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
-                for base_lr in self.base_lrs]
+        if self.last_epoch not in self.milestones:
+            return [group['lr'] for group in self.optimizer.param_groups]
+        return [group['lr'] * self.gamma ** self.milestones[self.last_epoch]
+                for group in self.optimizer.param_groups]
 
 
 class ExponentialLR(_LRScheduler):
-    """Set the learning rate of each parameter group to the initial lr decayed
-    by gamma every epoch. When last_epoch=-1, sets initial lr as lr.
+    """Decays the learning rate of each parameter group by gamma every epoch.
+    When last_epoch=-1, sets initial lr as lr.
 
     Args:
         optimizer (Optimizer): Wrapped optimizer.
@@ -210,8 +213,10 @@ class ExponentialLR(_LRScheduler):
         super(ExponentialLR, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
-        return [base_lr * self.gamma ** self.last_epoch
-                for base_lr in self.base_lrs]
+        if self.last_epoch == 0:
+            return self.base_lrs
+        return [group['lr'] * self.gamma
+                for group in self.optimizer.param_groups]
 
 
 class CosineAnnealingLR(_LRScheduler):
@@ -220,12 +225,18 @@ class CosineAnnealingLR(_LRScheduler):
     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
 
     .. math::
+        \eta_{t+1} = \eta_{min} + (\eta_t - \eta_{min})\frac{1 +
+        \cos(\frac{T_{cur+1}}{T_{max}}\pi)}{1 + \cos(\frac{T_{cur}}{T_{max}}\pi)}
 
+    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
+    is defined recursively, the learning rate can be simultaneously modified
+    outside this scheduler by other operators. If the learning rate is set
+    solely by this scheduler, the learning rate at each step becomes:
+
+    .. math::
         \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
         \cos(\frac{T_{cur}}{T_{max}}\pi))
 
-    When last_epoch=-1, sets initial lr as lr.
-
     It has been proposed in
     `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
     implements the cosine annealing part of SGDR, and not the restarts.
@@ -246,9 +257,12 @@ class CosineAnnealingLR(_LRScheduler):
         super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
 
     def get_lr(self):
-        return [self.eta_min + (base_lr - self.eta_min) *
-                (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2
-                for base_lr in self.base_lrs]
+        if self.last_epoch == 0:
+            return self.base_lrs
+        return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) /
+                (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) *
+                (group['lr'] - self.eta_min) + self.eta_min
+                for group in self.optimizer.param_groups]
 
 
 class ReduceLROnPlateau(object):