From 768014b3e69090347dbecf81e292ae97028067fe Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Wed, 8 Sep 2021 07:45:12 -0700
Subject: [PATCH] Allow disabling cache in autocast (automatic mixed precision) (#63552)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63552

In this PR, we want to exclude these 2 cases from using the `Autocast` weight cache:

- Using `torch.jit.trace` under `Autocast`
  As reported in https://github.com/pytorch/pytorch/issues/50231 and several other discussions, when `torch.jit.trace` is run under `Autocast`, the trace process hits Autocast's weight cache and fails. So we should disable the weight cache during tracing (a short usage sketch is appended after the patch).

- Using `Autocast` with `Grad mode`
  - We usually use `Grad mode` for training. Since the weights change at every step during training, there is no need to cache them.
  - For the recommended `Autocast` training pattern in the [doc](https://pytorch.org/docs/stable/amp.html), `Autocast` already clears the cache every time the context is exited, so disabling the cache also saves those clear operations.

```
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
```

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D30644913

Pulled By: ezyang

fbshipit-source-id: ad7bc87372e554e7aa1aa0795e9676871b3974e7
---
 aten/src/ATen/autocast_mode.cpp | 13 +++++++++++-
 aten/src/ATen/autocast_mode.h   |  2 ++
 test/test_jit.py                | 44 +++++++++++++++++++++++++++++++++++++++++
 test/test_public_bindings.py    |  2 ++
 torch/_C/__init__.pyi.in        |  2 ++
 torch/autocast_mode.py          | 13 +++++++++---
 torch/cpu/amp/autocast_mode.py  |  4 ++--
 torch/csrc/autograd/init.cpp    | 22 +++++++++++++++++++++
 torch/cuda/amp/autocast_mode.py |  4 ++--
 torch/overrides.py              |  2 ++
 10 files changed, 100 insertions(+), 8 deletions(-)

diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 9f5f486..4770d17 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -58,6 +58,9 @@ thread_local int nesting = 0;
 // autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
 thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;
 
+// Should we enable the cache inside autocast.
+thread_local bool cache_enabled = true;
+
 // autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
 thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
 }
@@ -93,6 +96,14 @@ void set_autocast_gpu_dtype(at::ScalarType dtype) {
   autocast_gpu_dtype = dtype;
 }
 
+bool is_autocast_cache_enabled() {
+  return cache_enabled;
+}
+
+void set_autocast_cache_enabled(bool enabled) {
+  cache_enabled = enabled;
+}
+
 // Overload to catch Tensor args
 // TODO (possible optimization):
 // Move cast_cache to an inline function in a header with cached_casts declared as
@@ -103,7 +114,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
   // See cached_casts declaration above for detailed strategy.
   bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                         arg.scalar_type() == at::kFloat && arg.requires_grad() &&
-                        arg.is_leaf() && !arg.is_view());
+                        arg.is_leaf() && !arg.is_view() && cache_enabled);
   if (can_try_cache) {
     auto it = cached_casts.find(arg.unsafeGetTensorImpl());
     if (it != cached_casts.end()) {
diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h
index 2e0c4e5..bede6cd 100644
--- a/aten/src/ATen/autocast_mode.h
+++ b/aten/src/ATen/autocast_mode.h
@@ -14,6 +14,8 @@ TORCH_API at::ScalarType get_autocast_gpu_dtype();
 TORCH_API at::ScalarType get_autocast_cpu_dtype();
 TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
 TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
+TORCH_API bool is_autocast_cache_enabled();
+TORCH_API void set_autocast_cache_enabled(bool enabled);
 
 namespace {
 
diff --git a/test/test_jit.py b/test/test_jit.py
index 7051d66..eb61fdb 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -63,6 +63,7 @@ from jit.test_attr import TestGetDefaultAttr  # noqa: F401
 from jit.test_aten_pow import TestAtenPow  # noqa: F401
 from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo  # noqa: F401
 from jit.test_union import TestUnion  # noqa: F401
+from jit.test_models import MnistNet
 
 # Torch
 from torch import Tensor
@@ -15983,6 +15984,49 @@ class TestJitGeneratedModule(JitTestCase):
 class TestJitGeneratedFunctional(JitTestCase):
     pass
 
+class TestJitAutocast(JitTestCase):
+    def setUp(self):
+        super(TestJitAutocast, self).setUp()
+        self.models = [MnistNet()]
+        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu')]
+
+    def tearDown(self):
+        super(TestJitAutocast, self).tearDown()
+
+    def test_generate_autocast_jit_trace_model(self):
+        def test_generate_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+        for i in range(self.models.__len__()):
+            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nchw_autocast_jit_trace_model(self):
+        def test_nchw_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y = traced_model(x.clone())
+                y2 = model(x.clone())
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
+
+    def test_nhwc_autocast_jit_trace_model(self):
+        def test_nhwc_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y = traced_model(x.clone().to(memory_format=torch.channels_last))
+                y2 = model(x.clone().to(memory_format=torch.channels_last))
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for i in range(self.models.__len__()):
+            if self.inputs[i].size().__len__() == 5:
+                # NHWC 3D case not supported yet
+                continue
+            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
 
 # UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
 # and we have to disable the failing tests here instead.
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 9f8b79d..4df9062 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -139,6 +139,7 @@ class TestPublicBindings(unittest.TestCase):
             "IntType",
             "IODescriptor",
             "is_anomaly_enabled",
+            "is_autocast_cache_enabled",
             "is_autocast_cpu_enabled",
             "is_autocast_enabled",
             "is_grad_enabled",
@@ -190,6 +191,7 @@ class TestPublicBindings(unittest.TestCase):
             "ScriptObjectProperty",
             "SerializationStorageContext",
             "set_anomaly_enabled",
+            "set_autocast_cache_enabled",
             "set_autocast_cpu_dtype",
             "set_autocast_cpu_enabled",
             "set_autocast_enabled",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 091cb09..6f263c6 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -642,6 +642,8 @@ def get_autocast_cpu_dtype() -> _dtype: ...
 def get_autocast_gpu_dtype() -> _dtype: ...
 def autocast_increment_nesting() -> _int: ...
 def autocast_decrement_nesting() -> _int: ...
+def is_autocast_cache_enabled() -> _bool: ...
+def set_autocast_cache_enabled(enabled: _bool) -> None: ...
 def set_anomaly_enabled(enabled: _bool) -> None: ...
 def is_anomaly_enabled() -> _bool: ...
 def _enter_dual_level() -> _int: ...
diff --git a/torch/autocast_mode.py b/torch/autocast_mode.py
index 97d51b8..dcb69a8 100644
--- a/torch/autocast_mode.py
+++ b/torch/autocast_mode.py
@@ -124,8 +124,9 @@ class autocast(object):
     Args:
         device_type(string, required):  Whether to use 'cuda' or 'cpu' device
-        enabled(bool, optional, default=True)":  Whether autocasting should be enabled in the region.
-        dtype(torch_dtype, optional):  Whether to use torch.float16 or torch.bfloat16
+        enabled(bool, optional, default=True):  Whether autocasting should be enabled in the region.
+        dtype(torch_dtype, optional):  Whether to use torch.float16 or torch.bfloat16.
+        cache_enabled(bool, optional, default=True):  Whether the weight cache inside autocast should be enabled.
     """
 
     def __init__(self, device_type, enabled=True, **kwargs):
         self.device = device_type
@@ -135,13 +136,16 @@ class autocast(object):
             self.fast_dtype = torch.get_autocast_cpu_dtype()
         else:
             raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
+        self._cache_enabled = torch.is_autocast_cache_enabled()
         if torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
             warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
             enabled = False
         for key, value in kwargs.items():
             if key == 'dtype':
                 self.fast_dtype = value
-            if not (key == 'dtype'):
+            if key == 'cache_enabled':
+                self._cache_enabled = value
+            if not ((key == 'dtype') or (key == 'cache_enabled')):
                 raise RuntimeError('Unrecognized optional argument supplied to autocast context manager: ' + str(key))
 
         if self.device == 'cpu':
@@ -157,6 +161,7 @@ class autocast(object):
         self._enabled = enabled
 
     def __enter__(self):
+        self.prev_cache_enabled = torch.is_autocast_cache_enabled()
         if self.device == 'cpu':
             self.prev = torch.is_autocast_cpu_enabled()
             self.prev_fastdtype = torch.get_autocast_cpu_dtype()
@@ -169,6 +174,7 @@
             torch.set_autocast_gpu_dtype(self.fast_dtype)
             torch.set_autocast_enabled(self._enabled)
             torch.autocast_increment_nesting()
+        torch.set_autocast_cache_enabled(self._cache_enabled)
 
     def __exit__(self, *args):
         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
@@ -182,6 +188,7 @@ class autocast(object):
             torch.clear_autocast_cache()
             torch.set_autocast_enabled(self.prev)
             torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        torch.set_autocast_cache_enabled(self.prev_cache_enabled)
         return False
 
     def __call__(self, func):
diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py
index 8c65f72..7686928 100644
--- a/torch/cpu/amp/autocast_mode.py
+++ b/torch/cpu/amp/autocast_mode.py
@@ -5,5 +5,5 @@ class autocast(torch.autocast_mode.autocast):
     See :class:`torch.autocast`.
     ``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)``
     """
-    def __init__(self, enabled=True, dtype=torch.bfloat16):
-        super().__init__("cpu", enabled=enabled, dtype=dtype)
+    def __init__(self, enabled=True, dtype=torch.bfloat16, cache_enabled=True):
+        super().__init__("cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 697ca87..d411849 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -419,6 +419,26 @@ static PyObject * autocast_decrement_nesting(PyObject* _unused, PyObject *arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject * is_autocast_cache_enabled(PyObject* _unused, PyObject *arg) {
+  HANDLE_TH_ERRORS
+  if (at::autocast::is_autocast_cache_enabled()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject * set_autocast_cache_enabled(PyObject* _unused, PyObject *arg) {
+  HANDLE_TH_ERRORS
+  if (!PyBool_Check(arg)) {
+    throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+  }
+  at::autocast::set_autocast_cache_enabled(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
   HANDLE_TH_ERRORS
   if (!PyBool_Check(arg)) {
@@ -510,6 +530,8 @@ static PyMethodDef methods[] = { // NOLINT
   {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
   {"autocast_increment_nesting", autocast_increment_nesting, METH_NOARGS, nullptr},
   {"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr},
+  {"is_autocast_cache_enabled", is_autocast_cache_enabled, METH_NOARGS, nullptr},
+  {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr},
   {"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr},
   {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
   {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py
index ca8a2fc..2aad0af 100644
--- a/torch/cuda/amp/autocast_mode.py
+++ b/torch/cuda/amp/autocast_mode.py
@@ -13,8 +13,8 @@ class autocast(torch.autocast_mode.autocast):
     See :class:`torch.autocast`.
     ``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
     """
-    def __init__(self, enabled=True, dtype=torch.float16):
-        super().__init__("cuda", enabled=enabled, dtype=dtype)
+    def __init__(self, enabled=True, dtype=torch.float16, cache_enabled=True):
+        super().__init__("cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
 
 
 # Casts Tensors and containers of Tensors.  Special-cases passthroughs for strings and np.ndarrays, which
diff --git a/torch/overrides.py b/torch/overrides.py
index 049115e..3fd87b8 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -192,6 +192,8 @@ def get_ignored_functions() -> Set[Callable]:
         torch.set_autocast_gpu_dtype,
         torch.autocast_increment_nesting,
         torch.autocast_decrement_nesting,
+        torch.is_autocast_cache_enabled,
+        torch.set_autocast_cache_enabled,
         torch.nn.functional.hardswish,
         torch.is_vulkan_available,
         torch.are_deterministic_algorithms_enabled,
-- 
2.7.4
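
For reference, a minimal usage sketch of the new `cache_enabled` flag on CPU, mirroring the pattern used in the new JIT test above (the `SimpleNet` module and the input shape are illustrative placeholders, not part of the patch):

```
import torch
import torch.nn as nn

# Illustrative placeholder model; any nn.Module traced under autocast works the same way.
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3)

    def forward(self, x):
        return torch.relu(self.conv(x))

model = SimpleNet().eval()
x = torch.randn(5, 1, 28, 28)

# Disable the autocast weight cache while tracing so torch.jit.trace does not
# record cached casted weights (the failure mode this patch works around).
with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
    traced_model = torch.jit.trace(model, x)

# The cache can stay enabled for ordinary autocast inference with the traced module.
with torch.cpu.amp.autocast(), torch.no_grad():
    y = traced_model(x)
```

Tracing is the only step that needs the cache off here; the default `cache_enabled=True` is unchanged for other uses.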