Re-apply: [nnc] Support thread level parallelism in fused kernels (#63776)
author    Bert Maher <bertrand@fb.com>
          Wed, 25 Aug 2021 01:52:29 +0000 (18:52 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
          Wed, 25 Aug 2021 01:56:55 +0000 (18:56 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63776

I reverted this out of an abundance of caution because some test
failures occurred, but they were all due to precision issues fixed lower in
this stack.  Let's try again.

I've rolled the elimination of the allow-parallelism-in-fusions toggle into
this diff, since the two changes are pretty tightly coupled.
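
The toggle in question is the `_jit_texpr_parallel_cpu_enabled` /
`_jit_set_texpr_parallel_cpu_enabled` pair of bindings removed from
torch/csrc/jit/python/init.cpp below; with it gone, CPU fusion is gated only
by canFuseOnCPU(). A rough sketch of exercising a multi-threaded pointwise
fusion after this change (shapes, thread count, and warm-up count are
illustrative only, assuming a CPU build with the LLVM backend):

    import torch

    torch.set_num_threads(4)                      # fused CPU kernels may now use thread-level parallelism
    torch._C._jit_override_can_fuse_on_cpu(True)  # no _jit_set_texpr_parallel_cpu_enabled call needed anymore

    @torch.jit.script
    def f(a, b):
        return a * (a * b)   # same pattern as the ParallelStrided test below

    a = torch.rand(5, 3, 40005)
    b = torch.rand(5, 3, 40005)
    for _ in range(10):      # warm up the profiling executor so the TE fuser kicks in
        out = f(a, b)
    assert torch.allclose(out, a * (a * b))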
ghstack-source-id: 136529847

Test Plan: CI

Reviewed By: huiguoo

Differential Revision: D30484555

fbshipit-source-id: 38fd33520f710585d1130c365a8c60c9ce794a59

12 files changed:
test/cpp/tensorexpr/test_kernel.cpp
test/cpp/tensorexpr/test_te_fuser_pass.cpp
test/jit/test_profiler.py
test/test_jit_fuser_te.py
test/test_tensorexpr.py
torch/csrc/jit/passes/tensorexpr_fuser.cpp
torch/csrc/jit/passes/tensorexpr_fuser.h
torch/csrc/jit/python/init.cpp
torch/csrc/jit/tensorexpr/kernel.cpp
torch/csrc/jit/tensorexpr/llvm_codegen.cpp
torch/csrc/jit/tensorexpr/llvm_jit.h
torch/csrc/jit/tensorexpr/loopnest.cpp

diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp
index e14282f..8cdf2ef 100644
@@ -198,6 +198,34 @@ TEST_F(Kernel, _3) {
   }
 }
 
+TEST_F(Kernel, ParallelStrided) {
+  const auto graph_string = R"IR(
+      graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
+            %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
+        %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
+        %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
+        return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
+               .index(
+                   {Slice(None, None, 2),
+                    Slice(None, None, 2),
+                    Slice(None, None, 2)});
+  auto ref = a * (a * b);
+  auto o = at::zeros_like(ref);
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {a, b};
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  for (size_t i = 0; i < 5 * 3; i++) {
+    CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+  }
+}
+
 TEST_F(Kernel, DISABLED_Shape_Inference) {
   // disabled: doesn't do stride propagation, and isn't being used currently
 
diff --git a/test/cpp/tensorexpr/test_te_fuser_pass.cpp b/test/cpp/tensorexpr/test_te_fuser_pass.cpp
index 723a8fe..b82d383 100644
@@ -14,19 +14,15 @@ namespace jit {
 using namespace torch::jit::tensorexpr;
 
 struct WithCPUFuser {
-  WithCPUFuser(bool val = true)
-      : cpuFuserEnabled(canFuseOnCPU()), parallel(texprParallelCPUEnabled()) {
+  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
     overrideCanFuseOnCPU(val);
-    setTexprParallelCPUEnabled(true);
   }
 
   ~WithCPUFuser() {
     overrideCanFuseOnCPU(cpuFuserEnabled);
-    setTexprParallelCPUEnabled(parallel);
   }
 
   bool cpuFuserEnabled;
-  bool parallel;
 };
 
 TEST(TEFuserPass, FuserPass_1) {
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index aa8be05..b9ed9d0 100644
@@ -29,8 +29,6 @@ class TestProfiler(JitTestCase):
         torch._C._debug_set_fusion_group_inlining(False)
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
 
     def tearDown(self):
         torch._C._jit_set_profiling_executor(self.prev_exec)
@@ -42,7 +40,6 @@ class TestProfiler(JitTestCase):
         torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
         torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def test_tensor_type_not_determined_by_inputs(self):
         @torch.jit.script
diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py
index f2dce12..014f142 100644
@@ -85,10 +85,6 @@ class TestTEFuser(JitTestCase):
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
 
-        # TODO: CPU fuser currently is disabled when multithreading.
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
-
         self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
         self.int_dtypes = [
             torch.int8,
@@ -116,7 +112,6 @@ class TestTEFuser(JitTestCase):
 
         torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def assertLastGraphAllFused(self):
         self.assertAllFused(torch.jit.last_executed_optimized_graph())
diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py
index 6353113..47c7e68 100644
@@ -24,9 +24,6 @@ class BaseTestClass(JitTestCase):
         torch._C._debug_set_fusion_group_inlining(False)
         self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
         torch._C._jit_set_te_must_use_llvm_cpu(False)
-        # TODO: CPU fuser currently is disabled when multithreading.
-        self.old_fuse_parallel = torch._C._jit_texpr_parallel_cpu_enabled()
-        torch._C._jit_set_texpr_parallel_cpu_enabled(True)
 
         self.devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
 
@@ -39,7 +36,6 @@ class BaseTestClass(JitTestCase):
         torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
         torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
         torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
-        torch._C._jit_set_texpr_parallel_cpu_enabled(self.old_fuse_parallel)
 
     def assertLastGraphAllFused(self):
         self.assertAllFused(torch.jit.last_executed_optimized_graph())
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 3f0cd14..085291a 100644
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
 
-#include <ATen/Parallel.h>
 #include <ATen/core/interned_strings.h>
 #include <ATen/record_function.h>
 #include <c10/util/FunctionRef.h>
@@ -250,15 +249,6 @@ bool isSupported(Node* node) {
 } // namespace tensorexpr
 
 static bool texpr_fuser_enabled_ = true;
-static bool texpr_parallel_cpu_enabled = false;
-
-bool texprParallelCPUEnabled() {
-  return texpr_parallel_cpu_enabled;
-}
-
-void setTexprParallelCPUEnabled(bool val) {
-  texpr_parallel_cpu_enabled = val;
-}
 
 void setTensorExprFuserEnabled(bool val) {
   texpr_fuser_enabled_ = val;
@@ -898,14 +888,7 @@ class TensorExprFuser {
       return false;
     }
     if (device->is_cpu()) {
-      // CPU fusion is only supported for single-thread.
-      if (!canFuseOnCPU()) {
-        return false;
-      }
-      if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) {
-        return true;
-      }
-      return false;
+      return canFuseOnCPU();
     } else if (device->is_cuda()) {
       return canFuseOnGPU();
     } else if (device->is_xpu()) {
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.h b/torch/csrc/jit/passes/tensorexpr_fuser.h
index 3f6538b..254aebd 100644
@@ -24,8 +24,6 @@ TORCH_API void setTensorExprFuserEnabled(bool val);
 TORCH_API bool tensorExprFuserEnabled();
 TORCH_API bool setTexprReductionsEnabled(bool value);
 TORCH_API bool texprReductionsEnabled();
-TORCH_API bool texprParallelCPUEnabled();
-TORCH_API void setTexprParallelCPUEnabled(bool val);
 
 TORCH_API void RemoveProfileNodesAndSpecializeTypes(
     std::shared_ptr<Graph>& graph);
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index baea47d..645fea2 100644
@@ -714,8 +714,6 @@ void initJITBindings(PyObject* module) {
       .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed)
       .def("_jit_set_texpr_reductions_enabled", &setTexprReductionsEnabled)
       .def("_jit_texpr_reductions_enabled", &texprReductionsEnabled)
-      .def("_jit_set_texpr_parallel_cpu_enabled", &setTexprParallelCPUEnabled)
-      .def("_jit_texpr_parallel_cpu_enabled", &texprParallelCPUEnabled)
       .def(
           "_jit_set_te_generate_block_code",
           [](bool gen_block_code) {
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index fed5e1e..d53e857 100644
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/tensorexpr/kernel.h>
 
 #include <ATen/ExpandUtils.h>
+#include <ATen/Parallel.h>
 #include <ATen/TensorGeometry.h>
 #include <c10/util/irange.h>
 #include <c10/util/string_utils.h>
@@ -2487,6 +2488,86 @@ void fuseAllLoops(StmtPtr st) {
   }
 }
 
+// Compute the trip count of a loop if it is a constant.
+c10::optional<int64_t> tripCount(ForPtr loop) {
+  auto tc = IRSimplifier::simplify(
+      cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
+  if (auto val = to<LongImm>(tc.node())) {
+    return val->value();
+  }
+  return c10::nullopt;
+}
+
+// Prune innermost loops until iterations satisfies a minimum grain size.
+static void pruneByGrainSize(std::vector<ForPtr>& loops) {
+  constexpr int64_t minGrainSize = 32768;
+  int64_t grainSize = 1;
+  for (int64_t i = loops.size(); i > 0; i--) {
+    auto tc = tripCount(loops[i - 1]);
+    if (!tc) {
+      break;
+    }
+    grainSize *= *tc;
+    if (grainSize < minGrainSize) {
+      loops.pop_back();
+    }
+  }
+}
+
+// Retain enough outermost loops to fill the number of threads.
+static void pruneByThreadCount(std::vector<ForPtr>& loops) {
+  int64_t trips = 1;
+  auto threads = at::get_num_threads();
+  auto it = loops.begin();
+  for (; it != loops.end(); it++) {
+    if (trips >= threads) {
+      break;
+    }
+    auto tc = tripCount(*it);
+    if (!tc) {
+      break;
+    }
+    trips *= *tc;
+  }
+  loops.erase(it, loops.end());
+}
+
+// Flatten and parallelize outer loops, subject to a minimum number of elements
+// in the inner loop, and a maximum level of thread-level parallelism in the
+// outer loops.
+template <typename Bufs>
+static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
+  for (auto const& buf : bufs) {
+    auto loops = l.getLoopStmtsFor(buf);
+    pruneByGrainSize(loops);
+    pruneByThreadCount(loops);
+
+    // There are no loops to parallelize; give up.
+    if (loops.size() == 0) {
+      continue;
+    }
+    // The loop nest contains a reduction; give up.
+    auto reductions = NodeFinder<ReduceOp>::find(loops[0]);
+    if (reductions.size() > 0) {
+      continue;
+    }
+    // The loop nest has loop carried dependences; give up.
+    if (LoopNest::hasLoopCarriedDependence(loops[0])) {
+      continue;
+    }
+    // Try to flatten the outer loops and parallelize them if successful.
+    ForPtr flattened = nullptr;
+    if (loops.size() == 1) {
+      flattened = loops[0];
+    } else {
+      LoopNest::flatten(loops, &flattened);
+    }
+    if (flattened) {
+      flattened->set_parallel();
+    }
+  }
+}
+
 StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
   torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
   GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
@@ -2528,6 +2609,8 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
   if (backendType == kLLVMCodeGen) {
     fuseAllLoops(l.root_stmt());
     GRAPH_DEBUG("after fuse", *l.root_stmt());
+    parallelizeOuterLoops(l, bufOutputs_);
+    GRAPH_DEBUG("after parallelize", *l.root_stmt());
   }
 
   if (backendType == kCudaCodeGen) {
@@ -2602,9 +2685,13 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
   }
 
   l.prepareForCodegen();
+  GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
+  l.simplify();
+  GRAPH_DEBUG("after simplification", *l.root_stmt());
 
   if (backendType == kLLVMCodeGen && !hasReduction) {
     l.vectorizeInnerLoops();
+    GRAPH_DEBUG("after vectorization", *l.root_stmt());
   }
 
   StmtPtr stmt = l.root_stmt();
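
The heuristic added above works in two pruning passes before flattening:
innermost loops are dropped from the candidate set while the cumulative
extent is still below a 32768-element grain size, and then only as many
outermost loops are kept as needed to cover at::get_num_threads(). A rough
sketch of that arithmetic on plain integer trip counts (illustrative only;
the real pass walks ForPtr loop nests and bails out on unknown trip counts,
reductions, and loop-carried dependences):

    MIN_GRAIN_SIZE = 32768  # matches minGrainSize in pruneByGrainSize

    def prune_by_grain_size(trips):
        # Drop innermost loops while the cumulative extent (accumulated from
        # the innermost loop outward) is still below the minimum grain size.
        grain, kept = 1, list(trips)
        for tc in reversed(trips):
            grain *= tc
            if grain < MIN_GRAIN_SIZE:
                kept.pop()
        return kept

    def prune_by_thread_count(trips, num_threads):
        # Keep outermost loops only until their combined trip count covers
        # the available threads.
        kept, total = [], 1
        for tc in trips:
            if total >= num_threads:
                break
            kept.append(tc)
            total *= tc
        return kept

    # For the ParallelStrided shapes [5, 3, 40005] with 8 threads:
    outer = prune_by_thread_count(prune_by_grain_size([5, 3, 40005]), 8)
    # outer == [5, 3]: those two loops are flattened into a single
    # 15-iteration loop and marked parallel; the 40005-element loop remains
    # the serial body.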
diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp
index 4ab2d53..5346d36 100644
@@ -274,15 +274,24 @@ class LLVMCodeGenImpl : public IRVisitor {
   }
 };
 
+extern "C" {
 typedef void (*ParallelCallee)(int index, int8_t* packed_data);
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
+void DispatchParallel(
+    int8_t* func,
+    int start,
+    int stop,
+    int8_t* packed_data) noexcept {
   // TODO: preserve the func type.
-  ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
-  at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
-    for (int index = f_begin; index < f_end; index++) {
-      callee(index, packed_data);
-    }
-  });
+  try {
+    ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
+    at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
+      for (int index = f_begin; index < f_end; index++) {
+        callee(index, packed_data);
+      }
+    });
+  } catch (...) {
+  }
+}
 }
 
 } // namespace tensorexpr
@@ -1288,6 +1297,7 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) {
       module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
   llvm::Function* dispatcher =
       llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
+  dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
   irb_.CreateCall(
       dispatcher, {func_value, start, stop, packed_caller_args_ptr});
   value_ = llvm::ConstantInt::get(IntTy_, 0);
diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h
index 30ad531..8585900 100644
@@ -17,7 +17,13 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data);
+extern "C" {
+void DispatchParallel(
+    int8_t* func,
+    int start,
+    int stop,
+    int8_t* packed_data) noexcept;
+}
 
 inline std::string formatError(llvm::Error&& err, const char* msg) {
   static constexpr char* defaultErrorMsg = "Unexpected failure in LLVM JIT";
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 1904999..d3a4b91 100644
@@ -179,6 +179,13 @@ class Vectorizer : public IRMutator {
     });
   }
 
+  ExprPtr mutate(ModPtr v) override {
+    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
+    return try_vectorize(v, inputs, [&]() {
+      return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
+    });
+  }
+
   ExprPtr mutate(AndPtr v) override {
     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
     return try_vectorize(v, inputs, [&]() {