Remove flag to toggle CPU fusion in the presence of parallelism (#63514)
authorBert Maher <bertrand@fb.com>
Fri, 20 Aug 2021 18:11:49 +0000 (11:11 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 18:18:19 +0000 (11:18 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63514

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30417127

Pulled By: bertmaher

fbshipit-source-id: b77d7c68364f2af73570740540f3b1152313016e

test/cpp/tensorexpr/test_te_fuser_pass.cpp
test/jit/test_profiler.py
test/test_jit_fuser_te.py
test/test_tensorexpr.py
torch/csrc/jit/passes/tensorexpr_fuser.cpp
torch/csrc/jit/passes/tensorexpr_fuser.h
torch/csrc/jit/python/init.cpp

index 8dd616453362b574f4a56c136c5bf3c934ca16c7..91fb4c2b7582cc05cfd85fed407758f2ab82deda 100644 (file)
@@ -15,19 +15,15 @@ namespace jit {
 using namespace torch::jit::tensorexpr;
 
 struct WithCPUFuser {
-  WithCPUFuser(bool val = true)
-      : cpuFuserEnabled(canFuseOnCPU()), parallel(texprParallelCPUEnabled()) {
+  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
     overrideCanFuseOnCPU(val);
-    setTexprParallelCPUEnabled(true);
   }
 
   ~WithCPUFuser() {
     overrideCanFuseOnCPU(cpuFuserEnabled);
-    setTexprParallelCPUEnabled(parallel);
   }
 
   bool cpuFuserEnabled;
-  bool parallel;
 };
 
 TEST(TEFuserPass, FuserPass_1) {
index aa8be0518385f74ea3a25085c880b4987041fbb7..b9ed9d0b78eb513aa0914652647f7b26284f4bfe 100644 (file)
@@ -29,8 +29,6 @@ class TestProfiler(JitTestCase):
         torch._C._debug_set_fusion_group_inlining(False)
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
 
     def tearDown(self):
         torch._C._jit_set_profiling_executor(self.prev_exec)
@@ -42,7 +40,6 @@ class TestProfiler(JitTestCase):
         torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
         torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def test_tensor_type_not_determined_by_inputs(self):
         @torch.jit.script
index 64c26b7936b549d241f4d460cf6571a2e6990123..614226ff871baa0400f92142ca7a2fe89c7c2f65 100644 (file)
@@ -85,10 +85,6 @@ class TestTEFuser(JitTestCase):
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
 
-        # TODO: CPU fuser currently is disabled when multithreading.
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
-
         self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
         self.int_dtypes = [
             torch.int8,
@@ -116,7 +112,6 @@ class TestTEFuser(JitTestCase):
 
         torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def assertLastGraphAllFused(self):
         self.assertAllFused(torch.jit.last_executed_optimized_graph())
index 6353113a1ec4c4e352b8c04188698f5b9bcc0dd0..47c7e689aa6a4cf9d516a402d617f8babb36517a 100644 (file)
@@ -24,9 +24,6 @@ class BaseTestClass(JitTestCase):
         torch._C._debug_set_fusion_group_inlining(False)
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
-        # TODO: CPU fuser currently is disabled when multithreading.
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
 
         self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
 
@@ -39,7 +36,6 @@ class BaseTestClass(JitTestCase):
         torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
         torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def assertLastGraphAllFused(self):
         self.assertAllFused(torch.jit.last_executed_optimized_graph())
index d4add03506c4f5e20171ee84d0a47417ce3f8e85..52bf4539479dfcc6049617d18394e925bd42799c 100644 (file)
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
 
-#include <ATen/Parallel.h>
 #include <ATen/core/interned_strings.h>
 #include <ATen/record_function.h>
 #include <c10/util/FunctionRef.h>
@@ -250,15 +249,6 @@ bool isSupported(Node* node) {
 } // namespace tensorexpr
 
 static bool texpr_fuser_enabled_ = true;
-static bool texpr_parallel_cpu_enabled = false;
-
-bool texprParallelCPUEnabled() {
-  return texpr_parallel_cpu_enabled;
-}
-
-void setTexprParallelCPUEnabled(bool val) {
-  texpr_parallel_cpu_enabled = val;
-}
 
 void setTensorExprFuserEnabled(bool val) {
   texpr_fuser_enabled_ = val;
@@ -898,14 +888,7 @@ class TensorExprFuser {
       return false;
     }
     if (device->is_cpu()) {
-      // CPU fusion is only supported for single-thread.
-      if (!canFuseOnCPU()) {
-        return false;
-      }
-      if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) {
-        return true;
-      }
-      return false;
+      return canFuseOnCPU();
     } else if (device->is_cuda()) {
       return canFuseOnGPU();
     } else if (device->is_xpu()) {
index 3f6538b7e587a322c546ab96875485d0086b3acc..254aebd91d12f1be579effba0623cf69c198054c 100644 (file)
@@ -24,8 +24,6 @@ TORCH_API void setTensorExprFuserEnabled(bool val);
 TORCH_API bool tensorExprFuserEnabled();
 TORCH_API bool setTexprReductionsEnabled(bool value);
 TORCH_API bool texprReductionsEnabled();
-TORCH_API bool texprParallelCPUEnabled();
-TORCH_API void setTexprParallelCPUEnabled(bool val);
 
 TORCH_API void RemoveProfileNodesAndSpecializeTypes(
     std::shared_ptr<Graph>& graph);
index 5fca5755935510dfefa42a837ecaa3720d5e2b9c..992e60edd7d198207baf89ad60b36a121a02ac73 100644 (file)
@@ -711,8 +711,6 @@ void initJITBindings(PyObject* module) {
       .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
       .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
       .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
-      .def("_jit_set_texpr_parallel_cpu_enabled", &setTexprParallelCPUEnabled)
-      .def("_jit_texpr_parallel_cpu_enabled", &texprParallelCPUEnabled)
       .def(
           "_jit_set_te_generate_block_code",
           [](bool gen_block_code) {