Allow disabling cache in autocast (automatic mixed precision) (#63552)
authorleslie-fang-intel <leslie.fang@intel.com>
Wed, 8 Sep 2021 14:45:12 +0000 (07:45 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 14:47:18 +0000 (07:47 -0700)
commit768014b3e69090347dbecf81e292ae97028067fe
tree0cee4f56811ba23a004fc925b54e1d46442f0320
parentb616132403697a043fd9691693e40f407e77063a
Allow disabling cache in autocast (automatic mixed precision) (#63552)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63552

In this PR, we want to exclude these 2 cases in the `Autocast` weight cache usages:

- Using `torch.jit.trace` under the `Autocast`
As report in https://github.com/pytorch/pytorch/issues/50231 and several other discussions, using `torch.jit.trace` under the `Autocast`, the trace process would hit Autocast's weight cache and fails. So we should disable weight cache under the trace process.
- Using `Autocast` with `Grad mode`

  - Usually we are using `Grad mode` for training. Since in the training phase, the weight will change in every step. So we doesn't need to cache the weight.
  - For the recommended `Autocast` training case in the [doc](https://pytorch.org/docs/stable/amp.html), `Autocast` will clear the cache every step leaving the context. We should disable it to save the clear operations.
    ```
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    ```

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30644913

Pulled By: ezyang

fbshipit-source-id: ad7bc87372e554e7aa1aa0795e9676871b3974e7
aten/src/ATen/autocast_mode.cpp
aten/src/ATen/autocast_mode.h
test/test_jit.py
test/test_public_bindings.py
torch/_C/__init__.pyi.in
torch/autocast_mode.py
torch/cpu/amp/autocast_mode.py
torch/csrc/autograd/init.cpp
torch/cuda/amp/autocast_mode.py
torch/overrides.py