[nnc] Support thread level parallelism in fused kernels (#63386)
authorBert Maher <bertrand@fb.com>
Fri, 20 Aug 2021 18:11:49 +0000 (11:11 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 20 Aug 2021 18:18:17 +0000 (11:18 -0700)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63386

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D30360382

Pulled By: bertmaher

fbshipit-source-id: 29acf4e932c669ce0f35823faea9099bcd8119b6

test/cpp/tensorexpr/test_kernel.cpp
torch/csrc/jit/tensorexpr/kernel.cpp
torch/csrc/jit/tensorexpr/llvm_codegen.cpp
torch/csrc/jit/tensorexpr/llvm_jit.h
torch/csrc/jit/tensorexpr/loopnest.cpp

index 8f36f54395f495664e5e890c6fbee155c7c0f0f1..8d4e48c4a0bff323aa6a5c95b6ee1159ee337cab 100644 (file)
@@ -206,6 +206,36 @@ TEST_F(Kernel, _3) {
   }
 }
 
+TEST_F(Kernel, ParallelStrided) {
+  KernelScope kernel_scope;
+
+  const auto graph_string = R"IR(
+      graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
+            %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
+        %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
+        %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
+        return (%3))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph);
+
+  auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
+  auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
+               .index(
+                   {Slice(None, None, 2),
+                    Slice(None, None, 2),
+                    Slice(None, None, 2)});
+  auto ref = a * (a * b);
+  auto o = at::zeros_like(ref);
+  TensorExprKernel k(graph);
+  std::vector<at::Tensor> inputs = {a, b};
+  std::vector<IValue> stack = fmap<IValue>(inputs);
+  k.run(stack);
+  o = stack[0].toTensor();
+  for (size_t i = 0; i < 5 * 3; i++) {
+    CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
+  }
+}
+
 TEST_F(Kernel, DISABLED_Shape_Inference) {
   // disabled: doesn't do stride propagation, and isn't being used currently
 
index faacd022e7e0b2dbc6a67e537423b9072893068a..c5333b20106103c5a7bde1d159531f0073ec1de0 100644 (file)
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/tensorexpr/kernel.h>
 
 #include <ATen/ExpandUtils.h>
+#include <ATen/Parallel.h>
 #include <ATen/TensorGeometry.h>
 #include <c10/util/irange.h>
 #include <c10/util/string_utils.h>
@@ -2487,6 +2488,86 @@ void fuseAllLoops(StmtPtr st) {
   }
 }
 
+// Compute the trip count of a loop if it is a constant.
+c10::optional<int64_t> tripCount(ForPtr loop) {
+  auto tc = IRSimplifier::simplify(
+      cast<int64_t>(ExprHandle(loop->stop()) - ExprHandle(loop->start())));
+  if (auto val = to<LongImm>(tc.node())) {
+    return val->value();
+  }
+  return c10::nullopt;
+}
+
+// Prune innermost loops until iterations satisfies a minimum grain size.
+static void pruneByGrainSize(std::vector<ForPtr>& loops) {
+  constexpr int64_t minGrainSize = 32768;
+  int64_t grainSize = 1;
+  for (int64_t i = loops.size(); i > 0; i--) {
+    auto tc = tripCount(loops[i - 1]);
+    if (!tc) {
+      break;
+    }
+    grainSize *= *tc;
+    if (grainSize < minGrainSize) {
+      loops.pop_back();
+    }
+  }
+}
+
+// Retain enough outermost loops to fill the number of threads.
+static void pruneByThreadCount(std::vector<ForPtr>& loops) {
+  int64_t trips = 1;
+  auto threads = at::get_num_threads();
+  auto it = loops.begin();
+  for (; it != loops.end(); it++) {
+    if (trips >= threads) {
+      break;
+    }
+    auto tc = tripCount(*it);
+    if (!tc) {
+      break;
+    }
+    trips *= *tc;
+  }
+  loops.erase(it, loops.end());
+}
+
+// Flatten and parallelize outer loops, subject to a minimum number of elements
+// in the inner loop, and a maximum level of thread-level parallelism in the
+// outer loops.
+template <typename Bufs>
+static void parallelizeOuterLoops(LoopNest& l, Bufs&& bufs) {
+  for (auto const& buf : bufs) {
+    auto loops = l.getLoopStmtsFor(buf);
+    pruneByGrainSize(loops);
+    pruneByThreadCount(loops);
+
+    // There are no loops to parallelize; give up.
+    if (loops.size() == 0) {
+      continue;
+    }
+    // The loop nest contains a reduction; give up.
+    auto reductions = NodeFinder<ReduceOp>::find(loops[0]);
+    if (reductions.size() > 0) {
+      continue;
+    }
+    // The loop nest has loop carried dependences; give up.
+    if (LoopNest::hasLoopCarriedDependence(loops[0])) {
+      continue;
+    }
+    // Try to flatten the outer loops and parallelize them if successful.
+    ForPtr flattened = nullptr;
+    if (loops.size() == 1) {
+      flattened = loops[0];
+    } else {
+      LoopNest::flatten(loops, &flattened);
+    }
+    if (flattened) {
+      flattened->set_parallel();
+    }
+  }
+}
+
 StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
   torch::jit::tensorexpr::LoopNest l(st, bufOutputs_);
   GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n");
@@ -2528,6 +2609,8 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
   if (backendType == kLLVMCodeGen) {
     fuseAllLoops(l.root_stmt());
     GRAPH_DEBUG("after fuse", *l.root_stmt());
+    parallelizeOuterLoops(l, bufOutputs_);
+    GRAPH_DEBUG("after parallelize", *l.root_stmt());
   }
 
   if (backendType == kCudaCodeGen) {
@@ -2602,9 +2685,13 @@ StmtPtr TensorExprKernel::transformLoops(BackendType backendType, StmtPtr st) {
   }
 
   l.prepareForCodegen();
+  GRAPH_DEBUG("after prepareForCodegen", *l.root_stmt());
+  l.simplify();
+  GRAPH_DEBUG("after simplification", *l.root_stmt());
 
   if (backendType == kLLVMCodeGen && !hasReduction) {
     l.vectorizeInnerLoops();
+    GRAPH_DEBUG("after vectorization", *l.root_stmt());
   }
 
   StmtPtr stmt = l.root_stmt();
index eac1f82f25c4bbd2f74eef2d554ac39534665a53..d5a95bc4cf886cbfcef93672fc5052a56c9ed0bb 100644 (file)
@@ -274,15 +274,24 @@ class LLVMCodeGenImpl : public IRVisitor {
   }
 };
 
+extern "C" {
 typedef void (*ParallelCallee)(int index, int8_t* packed_data);
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data) {
+void DispatchParallel(
+    int8_t* func,
+    int start,
+    int stop,
+    int8_t* packed_data) noexcept {
   // TODO: preserve the func type.
-  ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
-  at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
-    for (int index = f_begin; index < f_end; index++) {
-      callee(index, packed_data);
-    }
-  });
+  try {
+    ParallelCallee callee = reinterpret_cast<ParallelCallee>(func);
+    at::parallel_for(start, stop, 1, [&](int64_t f_begin, int64_t f_end) {
+      for (int index = f_begin; index < f_end; index++) {
+        callee(index, packed_data);
+      }
+    });
+  } catch (...) {
+  }
+}
 }
 
 } // namespace tensorexpr
@@ -1287,6 +1296,7 @@ void LLVMCodeGenImpl::processParallelFor(ForPtr v) {
       module_->getOrInsertFunction("DispatchParallel", dispatcher_fntype);
   llvm::Function* dispatcher =
       llvm::cast<llvm::Function>(dispatcher_callee.getCallee());
+  dispatcher->addFnAttr(llvm::Attribute::NoUnwind);
   irb_.CreateCall(
       dispatcher, {func_value, start, stop, packed_caller_args_ptr});
   value_ = llvm::ConstantInt::get(IntTy_, 0);
index 30ad5317a1b3c82bd1eb79f9341066f824c5d4ba..8585900abc8d6a9455364a7f7bfd66b206776d6c 100644 (file)
@@ -17,7 +17,13 @@ namespace torch {
 namespace jit {
 namespace tensorexpr {
 
-void DispatchParallel(int8_t* func, int start, int stop, int8_t* packed_data);
+extern "C" {
+void DispatchParallel(
+    int8_t* func,
+    int start,
+    int stop,
+    int8_t* packed_data) noexcept;
+}
 
 inline std::string formatError(llvm::Error&& err, const char* msg) {
   static constexpr char* defaultErrorMsg = "Unexpected failure in LLVM JIT";
index a296d8c7af79b3053dc937fbf2c7f4a276797ad7..7bcdd1a666f7b580e1b370126238f580f7eafdab 100644 (file)
@@ -179,6 +179,13 @@ class Vectorizer : public IRMutator {
     });
   }
 
+  ExprPtr mutate(ModPtr v) override {
+    std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
+    return try_vectorize(v, inputs, [&]() {
+      return ExprHandle(inputs[0]) % ExprHandle(inputs[1]);
+    });
+  }
+
   ExprPtr mutate(AndPtr v) override {
     std::vector<ExprPtr> inputs = {v->lhs(), v->rhs()};
     return try_vectorize(v, inputs, [&]() {