To add SequentialLR to PyTorch Core Schedulers (#64037)
authorIlqar Ramazanli <iramazanli@fb.com>
Thu, 9 Sep 2021 16:32:36 +0000 (09:32 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 16:36:32 +0000 (09:36 -0700)
commit2b41bf40c5888c5f44fe2b1d20fdb25e9c7502a8
tree42500b0f58a3fe54cfd4137175ec47c29203bbce
parentc3203efe808c5d042e1acd20d19467bcff597487
To add SequentialLR to PyTorch Core Schedulers (#64037)

Summary:
Partially resolves https://github.com/pytorch/vision/issues/4281

In this PR we are proposing a new scheduler --SequentialLR-- which enables list of different schedulers called in different periods of the training process.

The main motivation of this scheduler is recently gained popularity of warming up phase in the training time. It has been shown that having a small steps in initial stages of training can help convergence procedure get faster.

With the help of SequentialLR we mainly enable to call a small constant (or linearly increasing) learning rate followed by actual target learning rate scheduler.

```PyThon
scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=2)
scheduler2 = ExponentialLR(optimizer, gamma=0.9)
scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[5])

for epoch in range(100):
    train(...)
    validate(...)
    scheduler.step()
```

which this code snippet will call `ConstantLR` in the first 5 epochs and will follow up with `ExponentialLR` in the following epochs.

This scheduler could be used to provide call of any group of schedulers next to each other. The main consideration we should make is every time we switch to a new scheduler we assume that new scheduler starts from the beginning- zeroth epoch.

We also add Chained Scheduler to `optim.rst` and `lr_scheduler.pyi` files here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64037

Reviewed By: albanD

Differential Revision: D30841099

Pulled By: iramazanli

fbshipit-source-id: 94f7d352066ee108eef8cda5f0dcb07f4d371751
docs/source/optim.rst
test/test_optim.py
torch/optim/lr_scheduler.py
torch/optim/lr_scheduler.pyi