// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;
+// Whether the weight cache inside autocast is enabled.
+thread_local bool cache_enabled = true;
+
// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
}
autocast_gpu_dtype = dtype;
}
+bool is_autocast_cache_enabled() {
+ return cache_enabled;
+}
+
+void set_autocast_cache_enabled(bool enabled) {
+ cache_enabled = enabled;
+}
+
// Overload to catch Tensor args
// TODO (possible optimization):
// Move cast_cache to an inline function in a header with cached_casts declared as
// See cached_casts declaration above for detailed strategy.
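+ // Caching is also skipped when the user has disabled the autocast cache
+ // (cache_enabled, declared above), e.g. while tracing a model with
+ // torch.jit.trace under autocast.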
bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
arg.scalar_type() == at::kFloat && arg.requires_grad() &&
- arg.is_leaf() && !arg.is_view());
+ arg.is_leaf() && !arg.is_view() && cache_enabled);
if (can_try_cache) {
auto it = cached_casts.find(arg.unsafeGetTensorImpl());
if (it != cached_casts.end()) {
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
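+// is_autocast_cache_enabled/set_autocast_cache_enabled control whether autocast
+// caches lower-precision casts of fp32 leaf tensors (typically model weights)
+// between uses; see the cached_casts cache in the implementation.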
+TORCH_API bool is_autocast_cache_enabled();
+TORCH_API void set_autocast_cache_enabled(bool enabled);
namespace {
from jit.test_aten_pow import TestAtenPow # noqa: F401
from jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401
from jit.test_union import TestUnion # noqa: F401
+from jit.test_models import MnistNet
# Torch
from torch import Tensor
class TestJitGeneratedFunctional(JitTestCase):
pass
+class TestJitAutocast(JitTestCase):
+    def setUp(self):
+        super(TestJitAutocast, self).setUp()
+        self.models = [MnistNet()]
+        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu')]
+
+    def tearDown(self):
+        super(TestJitAutocast, self).tearDown()
+
+    def test_generate_autocast_jit_trace_model(self):
+        def test_generate_autocast_jit_trace_model(model, x):
+            model.eval()
+            # Tracing under autocast with the weight cache disabled should succeed.
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+        for model, x in zip(self.models, self.inputs):
+            test_generate_autocast_jit_trace_model(model, x)
+
+    def test_nchw_autocast_jit_trace_model(self):
+        def test_nchw_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x)
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y = traced_model(x.clone())
+                y2 = model(x.clone())
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for model, x in zip(self.models, self.inputs):
+            test_nchw_autocast_jit_trace_model(model, x)
+
+    def test_nhwc_autocast_jit_trace_model(self):
+        def test_nhwc_autocast_jit_trace_model(model, x):
+            model.eval()
+            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
+                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
+            with torch.cpu.amp.autocast(), torch.no_grad():
+                y = traced_model(x.clone().to(memory_format=torch.channels_last))
+                y2 = model(x.clone().to(memory_format=torch.channels_last))
+            torch.testing.assert_allclose(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
+        for model, x in zip(self.models, self.inputs):
+            if x.dim() == 5:
+                # NHWC for 3D models (5D inputs) is not supported yet
+                continue
+            test_nhwc_autocast_jit_trace_model(model, x)
# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
# and we have to disable the failing tests here instead.
"IntType",
"IODescriptor",
"is_anomaly_enabled",
+ "is_autocast_cache_enabled",
"is_autocast_cpu_enabled",
"is_autocast_enabled",
"is_grad_enabled",
"ScriptObjectProperty",
"SerializationStorageContext",
"set_anomaly_enabled",
+ "set_autocast_cache_enabled",
"set_autocast_cpu_dtype",
"set_autocast_cpu_enabled",
"set_autocast_enabled",
def get_autocast_gpu_dtype() -> _dtype: ...
def autocast_increment_nesting() -> _int: ...
def autocast_decrement_nesting() -> _int: ...
+def is_autocast_cache_enabled() -> _bool: ...
+def set_autocast_cache_enabled(enabled: _bool) -> None: ...
def set_anomaly_enabled(enabled: _bool) -> None: ...
def is_anomaly_enabled() -> _bool: ...
def _enter_dual_level() -> _int: ...
Args:
device_type(string, required): Whether to use 'cuda' or 'cpu' device
- enabled(bool, optional, default=True)": Whether autocasting should be enabled in the region.
- dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16
+ enabled(bool, optional, default=True): Whether autocasting should be enabled in the region.
+ dtype(torch_dtype, optional): Whether to use torch.float16 or torch.bfloat16.
+ cache_enabled(bool, optional, default=True): Whether the weight cache inside autocast should be enabled.
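+
+    A minimal sketch of the new ``cache_enabled`` flag (``MyModel`` here is just a
+    placeholder for any module); disabling the weight cache can be useful, for
+    example, when tracing a model under autocast::
+
+        model = MyModel().eval()
+        data = torch.randn(1, 3, 224, 224)
+        with torch.autocast(device_type="cpu", cache_enabled=False), torch.no_grad():
+            traced_model = torch.jit.trace(model, data)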
"""
def __init__(self, device_type, enabled=True, **kwargs):
self.device = device_type
self.fast_dtype = torch.get_autocast_cpu_dtype()
else:
raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
+ self._cache_enabled = torch.is_autocast_cache_enabled()
if torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
enabled = False
for key, value in kwargs.items():
if key == 'dtype':
self.fast_dtype = value
- if not (key == 'dtype'):
+ if key == 'cache_enabled':
+ self._cache_enabled = value
+ if not ((key == 'dtype') or (key == 'cache_enabled')):
raise RuntimeError('Unrecognized optional argument supplied to autocast context manager: ' + str(key))
if self.device == 'cpu':
self._enabled = enabled
def __enter__(self):
+ self.prev_cache_enabled = torch.is_autocast_cache_enabled()
if self.device == 'cpu':
self.prev = torch.is_autocast_cpu_enabled()
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
torch.set_autocast_gpu_dtype(self.fast_dtype)
torch.set_autocast_enabled(self._enabled)
torch.autocast_increment_nesting()
+ torch.set_autocast_cache_enabled(self._cache_enabled)
def __exit__(self, *args):
# Drop the cache when we exit to a nesting level that's outside any instance of autocast.
torch.clear_autocast_cache()
torch.set_autocast_enabled(self.prev)
torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+ torch.set_autocast_cache_enabled(self.prev_cache_enabled)
return False
def __call__(self, func):
See :class:`torch.autocast`.
``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)``
"""
- def __init__(self, enabled=True, dtype=torch.bfloat16):
- super().__init__("cpu", enabled=enabled, dtype=dtype)
+ def __init__(self, enabled=True, dtype=torch.bfloat16, cache_enabled=True):
+ super().__init__("cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
END_HANDLE_TH_ERRORS
}
+static PyObject * is_autocast_cache_enabled(PyObject* _unused, PyObject *arg) {
+ HANDLE_TH_ERRORS
+ if (at::autocast::is_autocast_cache_enabled()) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject * set_autocast_cache_enabled(PyObject* _unused, PyObject *arg) {
+ HANDLE_TH_ERRORS
+ if (!PyBool_Check(arg)) {
+ throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+ }
+ at::autocast::set_autocast_cache_enabled(arg == Py_True);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
static PyObject * set_grad_enabled(PyObject* _unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (!PyBool_Check(arg)) {
{"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
{"autocast_increment_nesting", autocast_increment_nesting, METH_NOARGS, nullptr},
{"autocast_decrement_nesting", autocast_decrement_nesting, METH_NOARGS, nullptr},
+ {"is_autocast_cache_enabled", is_autocast_cache_enabled, METH_NOARGS, nullptr},
+ {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr},
{"set_anomaly_enabled", set_anomaly_mode_enabled, METH_O, nullptr},
{"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
{"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
See :class:`torch.autocast`.
``torch.cuda.amp.autocast(args...)`` is equivalent to ``torch.autocast("cuda", args...)``
"""
- def __init__(self, enabled=True, dtype=torch.float16):
- super().__init__("cuda", enabled=enabled, dtype=dtype)
+ def __init__(self, enabled=True, dtype=torch.float16, cache_enabled=True):
+ super().__init__("cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
torch.set_autocast_gpu_dtype,
torch.autocast_increment_nesting,
torch.autocast_decrement_nesting,
+ torch.is_autocast_cache_enabled,
+ torch.set_autocast_cache_enabled,
torch.nn.functional.hardswish,
torch.is_vulkan_available,
torch.are_deterministic_algorithms_enabled,