# Type stubs for JIT fusion toggles exposed from C++ via pybind
# (each matches a .def("_jit_...", &...) registration on the C++ side).
def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
def _jit_can_fuse_on_cpu() -> _bool: ...
def _jit_can_fuse_on_gpu() -> _bool: ...
def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
def _debug_get_fusion_group_inlining() -> _bool: ...
def _debug_set_fusion_group_inlining(enable: _bool): ...
def _jit_texpr_fuser_enabled() -> _bool: ...
def _llvm_enabled() -> _bool: ...
def _jit_override_can_fuse_on_cpu(override: _bool): ...
def _jit_override_can_fuse_on_gpu(override: _bool): ...
def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
def _jit_set_texpr_fuser_enabled(enable: _bool): ...
#include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
#include <torch/csrc/jit/codegen/fuser/kernel_spec.h>
#include <torch/csrc/jit/codegen/fuser/tensor_info.h>
+#include <torch/csrc/jit/passes/graph_fuser.h>
#include <algorithm>
#include <iostream> // TODO: remove, debugging only
bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
// Short-circuits if fusion isn't enabled
- if (!canFuseOnCPU() && !canFuseOnGPU())
+ if (!canFuseOnCPULegacy() && !canFuseOnGPU())
return false;
// Acquires the FusionSpec
// Attempts to run fallback if device fusion is disabled
if (device.is_cuda() && !canFuseOnGPU())
return false;
- if (device.is_cpu() && !canFuseOnCPU())
+ if (device.is_cpu() && !canFuseOnCPULegacy())
return false;
if (device.is_xpu())
return false;
#include <c10/util/Flags.h>
#include <stdexcept>
-C10_DEFINE_bool(torch_jit_enable_cpu_fusion, false, "enable cpu fusion");
-
namespace torch {
namespace jit {
namespace detail {
-// Note: CPU fusion is currently disabled due to test flakiness
-#if defined(FBCODE_CAFFE2)
+#ifdef TORCH_ENABLE_LLVM
bool cpu_fuser_enabled = true;
#else
bool cpu_fuser_enabled = false;
}
bool canFuseOnCPU() {
- return fuser::hasFusionBackend(DeviceType::CPU) &&
- (detail::cpu_fuser_enabled || FLAGS_torch_jit_enable_cpu_fusion);
+ return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled;
}
bool canFuseOnGPU() {
return !strict_fuser_check;
}
if ((*device).is_cpu()) {
- return canFuseOnCPU();
+ return canFuseOnCPULegacy();
} else if ((*device).is_cuda()) {
return canFuseOnGPU();
} else if ((*device).is_xpu()) {
} // anonymous namespace
// Process-wide toggle for the legacy (pre-TensorExpr) CPU graph fuser.
// Off by default; flipped at runtime via overrideCanFuseOnCPULegacy()
// (exposed to Python as _jit_override_can_fuse_on_cpu_legacy).
// NOTE(review): plain bool, not atomic — assumes toggling happens outside
// concurrent fusion passes; confirm if that invariant holds.
static bool cpu_fuser_enabled_legacy = false;

// Query whether the legacy CPU fuser is currently enabled.
bool canFuseOnCPULegacy() {
  return cpu_fuser_enabled_legacy;
}

// Enable or disable the legacy CPU fuser for the whole process.
void overrideCanFuseOnCPULegacy(bool value) {
  cpu_fuser_enabled_legacy = value;
}
+
void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) {
AliasDb db(graph);
GraphFuser(&db, graph->block(), strict_fuser_check).run();
namespace torch {
namespace jit {
+TORCH_API bool canFuseOnCPULegacy();
+TORCH_API void overrideCanFuseOnCPULegacy(bool value);
+
// NB: Be sure to run DCE before fusion, because dead instructions
// can prevent fusion opportunities from being exploited.
// On Windows will noop, NYI
.def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
.def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
.def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
+ .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
+ .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy)
.def(
"_jit_differentiate",
[](Graph& g) {
def enable_cpu_fuser(fn):
    """Decorator: run ``fn`` with both the legacy and TensorExpr JIT CPU
    fusers enabled, then force them back off afterwards.

    NOTE(review): the finally-block resets the flags to hard-coded values
    rather than restoring the pre-call state — confirm no caller nests this
    decorator or relies on the flags already being enabled.
    """
    import functools

    @functools.wraps(fn)  # preserve fn's name/docstring for test discovery
    def wrapper(*args, **kwargs):
        torch._C._jit_override_can_fuse_on_cpu_legacy(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_set_te_must_use_llvm_cpu(False)
        try:
            # Propagate fn's return value (the original discarded it).
            return fn(*args, **kwargs)
        finally:
            torch._C._jit_override_can_fuse_on_cpu_legacy(False)
            torch._C._jit_override_can_fuse_on_cpu(False)
            torch._C._jit_set_te_must_use_llvm_cpu(True)
    return wrapper