Allow disabling cache in autocast (automatic mixed precision) (#63552)
authorleslie-fang-intel <leslie.fang@intel.com>
Wed, 8 Sep 2021 14:45:12 +0000 (07:45 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 14:47:18 +0000 (07:47 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63552

In this PR, we want to exclude these 2 cases in the `Autocast` weight cache usages:

- Using `torch.jit.trace` under the `Autocast`
As report in https://github.com/pytorch/pytorch/issues/50231 and several other discussions, using `torch.jit.trace` under the `Autocast`, the trace process would hit Autocast's weight cache and fails. So we should disable weight cache under the trace process.
- Using `Autocast` with `Grad mode`

  - Usually we are using `Grad mode` for training. Since in the training phase, the weight will change in every step. So we doesn't need to cache the weight.
  - For the recommended `Autocast` training case in the [doc](https://pytorch.org/docs/stable/amp.html), `Autocast` will clear the cache every step leaving the context. We should disable it to save the clear operations.
    ```
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    ```

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30644913

Pulled By: ezyang

fbshipit-source-id: ad7bc87372e554e7aa1aa0795e9676871b3974e7

aten/src/ATen/autocast_mode.cpp
aten/src/ATen/autocast_mode.h
test/test_jit.py
test/test_public_bindings.py
torch/_C/__init__.pyi.in
torch/autocast_mode.py
torch/cpu/amp/autocast_mode.py
torch/csrc/autograd/init.cpp
torch/cuda/amp/autocast_mode.py
torch/overrides.py

index 9f5f486..4770d17 100644 (file)
@@ -58,6 +58,9 @@ thread_local int nesting = 0;
 // autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
 thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;
 
+// should we enabled the cache inside autocast.
+thread_local bool cache_enabled = true;
+
 // autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
 thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
 }
@@ -93,6 +96,14 @@ void set_autocast_gpu_dtype(at::ScalarType dtype) {
   autocast_gpu_dtype = dtype;
 }
 
+bool is_autocast_cache_enabled() {
+  return cache_enabled;
+}
+
+void set_autocast_cache_enabled(bool enabled) {
+  cache_enabled = enabled;
+}
+
 // Overload to catch Tensor args
 // TODO (possible optimization):
 // Move cast_cache to an inline function in a header with cached_casts declared as
@@ -103,7 +114,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
     // See cached_casts declaration above for detailed strategy.
     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                          arg.scalar_type() == at::kFloat && arg.requires_grad() &&
-                         arg.is_leaf() && !arg.is_view());
+                         arg.is_leaf() && !arg.is_view() && cache_enabled);
     if (can_try_cache) {
       auto it = cached_casts.find(arg.unsafeGetTensorImpl());
       if (it != cached_casts.end()) {
index 2e0c4e5..bede6cd 100644 (file)
@@ -14,6 +14,8 @@ TORCH_API at::ScalarType get_autocast_gpu_dtype();
 TORCH_API at::ScalarType get_autocast_cpu_dtype();
 TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
 TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
+TORCH_API bool is_autocast_cache_enabled();
+TORCH_API void set_autocast_cache_enabled(bool enabled);
 
 
 namespace {
index 7051d66..eb61fdb 100644 (file)
@@ -63,6 +63,7 @@ from jit.test_attr import TestGetDefaultAttr  # noqa: F401
 from jit.test_aten_pow import TestAtenPow  # noqa: F401
 from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
 from jit.test_union import TestUnion  # noqa: F401
+from jit.test_models import MnistNet
 
 # Torch
 from torch import Tensor
@@ -15983,6 +15984,49 @@ class TestJitGeneratedModule(JitTestCase):
 class TestJitGeneratedFunctional(JitTestCase):
     pass
 
+class TestJitAutocast(JitTestCase):
+    def setUp(self):
+        super(TestJitAutocast, self).setUp()
+        self.models = [MnistNet()]
+        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu')]
+
+    def tearDown(self):
+        super(TestJitAutocast, self).tearDown()
+
+    def test_generate_autocast_jit_trace_model(self):
+        def test_generate_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+        for i in range(self.models.__len__()):
+            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nchw_autocast_jit_trace_model(self):
+        def test_nchw_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y = traced_model(x.clone())
+                y2 = model(x.clone())
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nhwc_autocast_jit_trace_model(self):
+        def test_nhwc_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y = traced_model(x.clone().to(memory_format=torch.channels_last))
+                y2 = model(x.clone().to(memory_format=torch.channels_last))
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            if self.inputs[i].size().__len__() == 5:
+                # NHWC 3D case not support yet
+                continue
+            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
 
 # UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
 # and we have to disable the failing tests here instead.
index 9f8b79d..4df9062 100644 (file)
@@ -139,6 +139,7 @@ class TestPublicBindings(unittest.TestCase):
             "IntType",
             "IODescriptor",
             "is_anomaly_enabled",
+            "is_autocast_cache_enabled",
             "is_autocast_cpu_enabled",
             "is_autocast_enabled",
             "is_grad_enabled",
@@ -190,6 +191,7 @@ class TestPublicBindings(unittest.TestCase):
             "ScriptObjectProperty",
             "SerializationStorageContext",
             "set_anomaly_enabled",
+            "set_autocast_cache_enabled",
             "set_autocast_cpu_dtype",
             "set_autocast_cpu_enabled",
             "set_autocast_enabled",
index 091cb09..6f263c6 100644 (file)
@@ -642,6 +642,8 @@ def get_autocast_cpu_dtype() -> _dtype: ...
 def get_autocast_gpu_dtype() -> _dtype: ...
 def autocast_increment_nesting() -> _int: ...
 def autocast_decrement_nesting() -> _int: ...
+def is_autocast_cache_enabled() -> _bool: ...
+def set_autocast_cache_enabled(enabled: _bool) -> None: ...
 def set_anomaly_enabled(enabled: _bool) -> None: ...
 def is_anomaly_enabled() -> _bool: ...
 def _enter_dual_level() -> _int: ...
index 97d51b8..dcb69a8 100644 (file)
@@ -124,8 +124,9 @@ class autocast(object):
 
     Args:
         device_type(string, required):  Whether to use 'cuda' or 'cpu' device
-        enabled(bool, optional, default=True)":  Whether autocasting should be enabled in the region.
-        dtype(torch_dtype, optional):  Whether to use torch.float16 or torch.bfloat16
+        enabled(bool, optional, default=True):  Whether autocasting should be enabled in the region.
+        dtype(torch_dtype, optional):  Whether to use torch.float16 or torch.bfloat16.
+        cache_enabled(bool, optional, default=True):  Whether the weight cache inside autocast should be enabled.
     """
     def __init__(self, device_type, enabled=True, **kwargs):
         self.device = device_type
@@ -135,13 +136,16 @@ class autocast(object):
             self.fast_dtype = torch.get_autocast_cpu_dtype()
         else:
             raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
+        self._cache_enabled = torch.is_autocast_cache_enabled()
         if torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
             warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
             enabled = False
         for key, value in kwargs.items():
             if key == 'dtype':
                 self.fast_dtype = value
-            if not (key == 'dtype'):
+            if key == 'cache_enabled':
+                self._cache_enabled = value
+            if not ((key == 'dtype') or (key == 'cache_enabled')):
                 raise RuntimeError('Unrecognized optional argument supplied to autocast context manager: ' + str(key))
 
         if self.device == 'cpu':
@@ -157,6 +161,7 @@ class autocast(object):
         self._enabled = enabled
 
     def __enter__(self):
+        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
         if self.device == 'cpu':
             self.prev = torch.is_autocast_cpu_enabled()
             self.prev_fastdtype = torch.get_autocast_cpu_dtype()
@@ -169,6 +174,7 @@ class autocast(object):
             torch.set_autocast_gpu_dtype(self.fast_dtype)
             torch.set_autocast_enabled(self._enabled)
             torch.autocast_increment_nesting()
+        torch.set_autocast_cache_enabled(self._cache_enabled)
 
     def __exit__(self, *args):
         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
@@ -182,6 +188,7 @@ class autocast(object):
                 torch.clear_autocast_cache()
             torch.set_autocast_enabled(self.prev)
             torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
         return False
 
     def __call__(self, func):
index 8c65f72..7686928 100644 (file)
@@ -5,5 +5,5 @@ class autocast(torch.autocast_mode.autocast):
     See :class:`torch.autocast`.
     ``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)``
     """
-    def __init__(self, enabled=True, dtype=torch.bfloat16):
-        super().__init__("cpu", enabled=enabled, dtype=dtype)
+    def __init__(self, enabled=True, dtype=torch.bfloat16, cache_enabled=True):
+        super().__init__("cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
index 697ca87..d411849 100644 (file)
@@ -419,6 +419,26 @@ static PyObject * autocast_decrement_nesting(PyObject* _unused, PyObject *arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * is_autocast_cache_enabled(PyObject* _unused, PyObject *arg) {
+  HANDLE_TH_ERRORS
+  if (at::autocast::is_autocast_cache_enabled()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * set_autocast_cache_enabled(PyObject* _unused, PyObject *arg) {
+  HANDLE_TH_ERRORS
+  if (!PyBool_Check(arg)) {
+    throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+  }
+  at::autocast::set_autocast_cache_enabled(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
   HANDLE_TH_ERRORS
   if (!PyBool_Check(arg)) {
@@ -510,6 +530,8 @@ static PyMethodDef methods[] = { // NOLINT
   {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
   {"autocast_increment_nesting", autocast_increment_nesting, METH_NOARGS, nullptr},
   {"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr},
+  {"is_autocast_cache_enabled", is_autocast_cache_enabled, METH_NOARGS, nullptr},
+  {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr},
   {"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr},
   {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
   {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
index ca8a2fc..2aad0af 100644 (file)
@@ -13,8 +13,8 @@ class autocast(torch.autocast_mode.autocast):
     See :class:`torch.autocast`.
     ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
     """
-    def __init__(self, enabled=True, dtype=torch.float16):
-        super().__init__("cuda", enabled=enabled, dtype=dtype)
+    def __init__(self, enabled=True, dtype=torch.float16, cache_enabled=True):
+        super().__init__("cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
 
 
 # Casts Tensors and containers of Tensors.  Special-cases passthroughs for strings and np.ndarrays, which
index 049115e..3fd87b8 100644 (file)
@@ -192,6 +192,8 @@ def get_ignored_functions() -> Set[Callable]:
         torch.set_autocast_gpu_dtype,
         torch.autocast_increment_nesting,
         torch.autocast_decrement_nesting,
+        torch.is_autocast_cache_enabled,
+        torch.set_autocast_cache_enabled,
         torch.nn.functional.hardswish,
         torch.is_vulkan_available,
         torch.are_deterministic_algorithms_enabled,