[nnc] Re-enable CPU fusion (#63665)
authorBert Maher <bertrand@fb.com>
Mon, 23 Aug 2021 19:41:32 +0000 (12:41 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Mon, 23 Aug 2021 19:42:42 +0000 (12:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63665

This reverts commit 125e2d02e575612eb427104e7c67f1c28f090db8.

Test Plan: Imported from OSS

Reviewed By: ZolotukhinM

Differential Revision: D30471646

Pulled By: bertmaher

fbshipit-source-id: 4189869566f03b5f9ada78d78830f6a34946eed6

torch/_C/__init__.pyi.in
torch/csrc/jit/codegen/fuser/executor.cpp
torch/csrc/jit/codegen/fuser/interface.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/graph_fuser.h
torch/csrc/jit/python/init.cpp
torch/testing/_internal/jit_utils.py

index 4d0245c..0b6bb6b 100644 (file)
@@ -208,6 +208,7 @@ def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_check_alias_annotation(g: Graph, args: Tuple[Any, ...], unqualified_op_name: str): ...
 def _jit_can_fuse_on_cpu() -> _bool: ...
 def _jit_can_fuse_on_gpu() -> _bool: ...
+def _jit_can_fuse_on_cpu_legacy() -> _bool: ...
 def _debug_get_fusion_group_inlining() -> _bool: ...
 def _debug_set_fusion_group_inlining(enable: _bool): ...
 def _jit_texpr_fuser_enabled() -> _bool: ...
@@ -215,6 +216,7 @@ def _jit_nvfuser_enabled() -> _bool: ...
 def _llvm_enabled() -> _bool: ...
 def _jit_override_can_fuse_on_cpu(override: _bool): ...
 def _jit_override_can_fuse_on_gpu(override: _bool): ...
+def _jit_override_can_fuse_on_cpu_legacy(override: _bool): ...
 def _jit_set_symbolic_shapes_test_mode(override: _bool): ...
 def _jit_symbolic_shapes_test_mode_enabled() -> _bool: ...
 def _jit_set_texpr_fuser_enabled(enable: _bool): ...
index b260e48..46f2f41 100644 (file)
@@ -11,6 +11,7 @@
 #include <torch/csrc/jit/codegen/fuser/kernel_cache.h>
 #include <torch/csrc/jit/codegen/fuser/kernel_spec.h>
 #include <torch/csrc/jit/codegen/fuser/tensor_info.h>
+#include <torch/csrc/jit/passes/graph_fuser.h>
 
 #include <algorithm>
 #include <iostream> // TODO: remove, debugging only
@@ -327,7 +328,7 @@ void launchFusion(
 
 bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
   // Short-circuits if fusion isn't enabled
-  if (!canFuseOnCPU() && !canFuseOnGPU())
+  if (!canFuseOnCPULegacy() && !canFuseOnGPU())
     return false;
 
   // Acquires the FusionSpec
@@ -362,7 +363,7 @@ bool runFusion(const int64_t key, Stack& stack, std::string* code_out) {
   // Attempts to run fallback if device fusion is disabled
   if (device.is_cuda() && !canFuseOnGPU())
     return false;
-  if (device.is_cpu() && !canFuseOnCPU())
+  if (device.is_cpu() && !canFuseOnCPULegacy())
     return false;
   if (device.is_xpu())
     return false;
index ec67c4b..ef7e9e0 100644 (file)
@@ -8,15 +8,12 @@
 #include <c10/util/Flags.h>
 #include <stdexcept>
 
-C10_DEFINE_bool(torch_jit_enable_cpu_fusion, false, "enable cpu fusion");
-
 namespace torch {
 namespace jit {
 
 namespace detail {
 
-// Note: CPU fusion is currently disabled due to test flakiness
-#if defined(FBCODE_CAFFE2)
+#ifdef TORCH_ENABLE_LLVM
 bool cpu_fuser_enabled = true;
 #else
 bool cpu_fuser_enabled = false;
@@ -37,8 +34,7 @@ void runFusion(const int64_t key, Stack& stack) {
 }
 
 bool canFuseOnCPU() {
-  return fuser::hasFusionBackend(DeviceType::CPU) &&
-      (detail::cpu_fuser_enabled || FLAGS_torch_jit_enable_cpu_fusion);
+  return fuser::hasFusionBackend(DeviceType::CPU) && detail::cpu_fuser_enabled;
 }
 
 bool canFuseOnGPU() {
index f7dd466..653f9fe 100644 (file)
@@ -183,7 +183,7 @@ struct GraphFuser {
       return !strict_fuser_check;
     }
     if ((*device).is_cpu()) {
-      return canFuseOnCPU();
+      return canFuseOnCPULegacy();
     } else if ((*device).is_cuda()) {
       return canFuseOnGPU();
     } else if ((*device).is_xpu()) {
@@ -1244,6 +1244,16 @@ void PeepholeOptimizeShapeExpressions(Block* block, AliasDb* db) {
 
 } // anonymous namespace
 
+static bool cpu_fuser_enabled_legacy = false;
+
+bool canFuseOnCPULegacy() {
+  return cpu_fuser_enabled_legacy;
+}
+
+void overrideCanFuseOnCPULegacy(bool value) {
+  cpu_fuser_enabled_legacy = value;
+}
+
 void FuseGraph(std::shared_ptr<Graph>& graph, bool strict_fuser_check) {
   AliasDb db(graph);
   GraphFuser(&db, graph->block(), strict_fuser_check).run();
index 0cdcc2e..aafb442 100644 (file)
@@ -5,6 +5,9 @@
 namespace torch {
 namespace jit {
 
+TORCH_API bool canFuseOnCPULegacy();
+TORCH_API void overrideCanFuseOnCPULegacy(bool value);
+
 // NB: Be sure to run DCE before fusion, because dead instructions
 // can prevent fusion opportunities from being exploited.
 // On Windows will noop, NYI
index c92ab1b..baea47d 100644 (file)
@@ -590,6 +590,8 @@ void initJITBindings(PyObject* module) {
       .def("_jit_override_can_fuse_on_gpu", &overrideCanFuseOnGPU)
       .def("_jit_can_fuse_on_cpu", &canFuseOnCPU)
       .def("_jit_can_fuse_on_gpu", &canFuseOnGPU)
+      .def("_jit_can_fuse_on_cpu_legacy", &canFuseOnCPULegacy)
+      .def("_jit_override_can_fuse_on_cpu_legacy", &overrideCanFuseOnCPULegacy)
       .def(
           "_jit_differentiate",
           [](Graph& g) {
index 6086572..7f9fb97 100644 (file)
@@ -668,11 +668,13 @@ def _trace(*args, **kwargs):
 
 def enable_cpu_fuser(fn):
     def wrapper(*args, **kwargs):
+        torch._C._jit_override_can_fuse_on_cpu_legacy(True)
         torch._C._jit_override_can_fuse_on_cpu(True)
         torch._C._jit_set_te_must_use_llvm_cpu(False)
         try:
             fn(*args, **kwargs)
         finally:
+            torch._C._jit_override_can_fuse_on_cpu_legacy(False)
             torch._C._jit_override_can_fuse_on_cpu(False)
             torch._C._jit_set_te_must_use_llvm_cpu(True)
     return wrapper