Lowering of OpenMP Parallel operation to LLVM IR 1/n
authorKiran Chandramohan <kiran.chandramohan@arm.com>
Mon, 13 Jul 2020 22:13:04 +0000 (23:13 +0100)
committerKiran Chandramohan <kiran.chandramohan@arm.com>
Mon, 13 Jul 2020 22:55:45 +0000 (23:55 +0100)
This patch introduces lowering of the OpenMP parallel operation to LLVM
IR using the OpenMPIRBuilder.

Functions topologicalSort and connectPhiNodes are generalised so that
they work with operations also. connectPhiNodes is also made static.

Lowering works for a parallel region with multiple blocks. Clauses and
arguments of the OpenMP operation are not handled.

Reviewed By: rriddle, anchu-rajendran

Differential Revision: https://reviews.llvm.org/D81660

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/openmp-llvm.mlir

index 3be6c97..642282f 100644 (file)
@@ -24,7 +24,6 @@ def OpenMP_Dialect : Dialect {
 class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
       Op<OpenMP_Dialect, mnemonic, traits>;
 
-
 //===----------------------------------------------------------------------===//
 // 2.6 parallel Construct
 //===----------------------------------------------------------------------===//
@@ -81,8 +80,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
     of the parallel region.
   }];
 
-  let arguments = (ins Optional<I1>:$if_expr_var,
-             Optional<AnyInteger>:$num_threads_var,
+  let arguments = (ins Optional<AnyType>:$if_expr_var,
+             Optional<AnyType>:$num_threads_var,
              OptionalAttr<ClauseDefault>:$default_val,
              Variadic<AnyType>:$private_vars,
              Variadic<AnyType>:$firstprivate_vars,
index 3a70101..e44ae97 100644 (file)
@@ -87,6 +87,8 @@ protected:
                                          llvm::IRBuilder<> &builder);
   virtual LogicalResult convertOmpOperation(Operation &op,
                                             llvm::IRBuilder<> &builder);
+  virtual LogicalResult convertOmpParallel(Operation &op,
+                                           llvm::IRBuilder<> &builder);
   static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
 
   /// A helper to look up remapped operands in the value remapping table.
@@ -100,7 +102,6 @@ private:
   LogicalResult convertFunctions();
   LogicalResult convertGlobals();
   LogicalResult convertOneFunction(LLVMFuncOp func);
-  void connectPHINodes(LLVMFuncOp func);
   LogicalResult convertBlock(Block &bb, bool ignoreArguments);
 
   llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
index 657aa84..0defea6 100644 (file)
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 
 using namespace mlir;
@@ -304,7 +306,160 @@ ModuleTranslation::ModuleTranslation(Operation *module,
   assert(satisfiesLLVMModule(mlirModule) &&
          "mlirModule should honor LLVM's module semantics.");
 }
-ModuleTranslation::~ModuleTranslation() {}
+ModuleTranslation::~ModuleTranslation() {
+  if (ompBuilder)
+    ompBuilder->finalize();
+}
+
+/// Get the SSA value passed to the current block from the terminator operation
+/// of its predecessor.
+static Value getPHISourceValue(Block *current, Block *pred,
+                               unsigned numArguments, unsigned index) {
+  Operation &terminator = *pred->getTerminator();
+  if (isa<LLVM::BrOp>(terminator))
+    return terminator.getOperand(index);
+
+  // For conditional branches, we need to check if the current block is reached
+  // through the "true" or the "false" branch and take the relevant operands.
+  auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
+  assert(condBranchOp &&
+         "only branch operations can be terminators of a block that "
+         "has successors");
+  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
+         "successors with arguments in LLVM conditional branches must be "
+         "different blocks");
+
+  return condBranchOp.getSuccessor(0) == current
+             ? condBranchOp.trueDestOperands()[index]
+             : condBranchOp.falseDestOperands()[index];
+}
+
+/// Connect the PHI nodes to the results of preceding blocks.
+template <typename T>
+static void
+connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
+                const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
+  // Skip the first block, it cannot be branched to and its arguments correspond
+  // to the arguments of the LLVM function.
+  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
+    Block *bb = &*it;
+    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
+    auto phis = llvmBB->phis();
+    auto numArguments = bb->getNumArguments();
+    assert(numArguments == std::distance(phis.begin(), phis.end()));
+    for (auto &numberedPhiNode : llvm::enumerate(phis)) {
+      auto &phiNode = numberedPhiNode.value();
+      unsigned index = numberedPhiNode.index();
+      for (auto *pred : bb->getPredecessors()) {
+        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
+                                bb, pred, numArguments, index)),
+                            blockMapping.lookup(pred));
+      }
+    }
+  }
+}
+
+// TODO: implement an iterative version
+static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
+  blocks.insert(b);
+  for (Block *bb : b->getSuccessors()) {
+    if (blocks.count(bb) == 0)
+      topologicalSortImpl(blocks, bb);
+  }
+}
+
+/// Sort function blocks topologically.
+template <typename T>
+static llvm::SetVector<Block *> topologicalSort(T &f) {
+  // For each blocks that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list and traverse its successors in DFS
+  // preorder.
+  llvm::SetVector<Block *> blocks;
+  for (Block &b : f) {
+    if (blocks.count(&b) == 0)
+      topologicalSortImpl(blocks, &b);
+  }
+  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
+
+  return blocks;
+}
+
+/// Convert the OpenMP parallel Operation to LLVM IR.
+LogicalResult
+ModuleTranslation::convertOmpParallel(Operation &opInst,
+                                      llvm::IRBuilder<> &builder) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+
+  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                       llvm::BasicBlock &continuationIP) {
+    llvm::LLVMContext &llvmContext = llvmModule->getContext();
+
+    llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
+    llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
+
+    builder.SetInsertPoint(codeGenIPBB);
+
+    for (auto &region : opInst.getRegions()) {
+      for (auto &bb : region) {
+        auto *llvmBB = llvm::BasicBlock::Create(
+            llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent());
+        blockMapping[&bb] = llvmBB;
+      }
+
+      // Then, convert blocks one by one in topological order to ensure
+      // defs are converted before uses.
+      llvm::SetVector<Block *> blocks = topologicalSort(region);
+      for (auto indexedBB : llvm::enumerate(blocks)) {
+        Block *bb = indexedBB.value();
+        llvm::BasicBlock *curLLVMBB = blockMapping[bb];
+        if (bb->isEntryBlock())
+          codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+
+        // TODO: Error not returned up the hierarchy
+        if (failed(
+                convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+          return;
+
+        // If this block has the terminator then add a jump to
+        // continuation bb
+        for (auto &op : *bb) {
+          if (isa<omp::TerminatorOp>(op)) {
+            builder.SetInsertPoint(curLLVMBB);
+            builder.CreateBr(&continuationIP);
+          }
+        }
+      }
+      // Finally, after all blocks have been traversed and values mapped,
+      // connect the PHI nodes to the results of preceding blocks.
+      connectPHINodes(region, valueMapping, blockMapping);
+    }
+  };
+
+  // TODO: Perform appropriate actions according to the data-sharing
+  // attribute (shared, private, firstprivate, ...) of variables.
+  // Currently defaults to shared.
+  auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                    llvm::Value &vPtr,
+                    llvm::Value *&replacementValue) -> InsertPointTy {
+    replacementValue = &vPtr;
+
+    return codeGenIP;
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  // TODO: The various operands of parallel operation are not handled.
+  // Parallel operation is created with some default options for now.
+  llvm::Value *ifCond = nullptr;
+  llvm::Value *numThreads = nullptr;
+  bool isCancellable = false;
+  builder.restoreIP(ompBuilder->CreateParallel(
+      builder, bodyGenCB, privCB, finiCB, ifCond, numThreads,
+      llvm::omp::OMP_PROC_BIND_default, isCancellable));
+  return success();
+}
 
 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
 /// (including OpenMP runtime calls).
@@ -340,6 +495,9 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
         ompBuilder->CreateFlush(builder.saveIP());
         return success();
       })
+      .Case([&](omp::TerminatorOp) { return success(); })
+      .Case(
+          [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
       .Default([&](Operation *inst) {
         return inst->emitError("unsupported OpenMP operation: ")
                << inst->getName();
@@ -556,75 +714,6 @@ LogicalResult ModuleTranslation::convertGlobals() {
   return success();
 }
 
-/// Get the SSA value passed to the current block from the terminator operation
-/// of its predecessor.
-static Value getPHISourceValue(Block *current, Block *pred,
-                               unsigned numArguments, unsigned index) {
-  auto &terminator = *pred->getTerminator();
-  if (isa<LLVM::BrOp>(terminator)) {
-    return terminator.getOperand(index);
-  }
-
-  // For conditional branches, we need to check if the current block is reached
-  // through the "true" or the "false" branch and take the relevant operands.
-  auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
-  assert(condBranchOp &&
-         "only branch operations can be terminators of a block that "
-         "has successors");
-  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
-         "successors with arguments in LLVM conditional branches must be "
-         "different blocks");
-
-  return condBranchOp.getSuccessor(0) == current
-             ? condBranchOp.trueDestOperands()[index]
-             : condBranchOp.falseDestOperands()[index];
-}
-
-void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
-  // Skip the first block, it cannot be branched to and its arguments correspond
-  // to the arguments of the LLVM function.
-  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
-    Block *bb = &*it;
-    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
-    auto phis = llvmBB->phis();
-    auto numArguments = bb->getNumArguments();
-    assert(numArguments == std::distance(phis.begin(), phis.end()));
-    for (auto &numberedPhiNode : llvm::enumerate(phis)) {
-      auto &phiNode = numberedPhiNode.value();
-      unsigned index = numberedPhiNode.index();
-      for (auto *pred : bb->getPredecessors()) {
-        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
-                                bb, pred, numArguments, index)),
-                            blockMapping.lookup(pred));
-      }
-    }
-  }
-}
-
-// TODO: implement an iterative version
-static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
-  blocks.insert(b);
-  for (Block *bb : b->getSuccessors()) {
-    if (blocks.count(bb) == 0)
-      topologicalSortImpl(blocks, bb);
-  }
-}
-
-/// Sort function blocks topologically.
-static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
-  // For each blocks that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list and traverse its successors in DFS
-  // preorder.
-  llvm::SetVector<Block *> blocks;
-  for (Block &b : f) {
-    if (blocks.count(&b) == 0)
-      topologicalSortImpl(blocks, &b);
-  }
-  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
-
-  return blocks;
-}
-
 /// Attempts to add an attribute identified by `key`, optionally with the given
 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
@@ -772,7 +861,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
 
   // Finally, after all blocks have been traversed and values mapped, connect
   // the PHI nodes to the results of preceding blocks.
-  connectPHINodes(func);
+  connectPHINodes(func, valueMapping, blockMapping);
   return success();
 }
 
index ddfc2a4..c8acd80 100644 (file)
@@ -32,3 +32,49 @@ llvm.func @test_flush_construct(%arg0: !llvm.i32) {
   // CHECK-NEXT:    ret void
   llvm.return
 }
+
+// CHECK-LABEL: define void @test_omp_parallel_1()
+llvm.func @test_omp_parallel_1() -> () {
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_1:.*]] to {{.*}}
+  omp.parallel {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_1]]
+  // CHECK: call void @__kmpc_barrier
+
+llvm.func @body(!llvm.i64)
+
+// CHECK-LABEL: define void @test_omp_parallel_2()
+llvm.func @test_omp_parallel_2() -> () {
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_2:.*]] to {{.*}}
+  omp.parallel {
+    ^bb0:
+      %0 = llvm.mlir.constant(1 : index) : !llvm.i64
+      %1 = llvm.mlir.constant(42 : index) : !llvm.i64
+      llvm.call @body(%0) : (!llvm.i64) -> ()
+      llvm.call @body(%1) : (!llvm.i64) -> ()
+      llvm.br ^bb1
+
+    ^bb1:
+      %2 = llvm.add %0, %1 : !llvm.i64
+      llvm.call @body(%2) : (!llvm.i64) -> ()
+      omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_2]]
+  // CHECK-LABEL: omp.par.region:
+  // CHECK: br label %omp.par.region1
+  // CHECK-LABEL: omp.par.region1:
+  // CHECK: call void @body(i64 1)
+  // CHECK: call void @body(i64 42)
+  // CHECK: br label %omp.par.region2
+  // CHECK-LABEL: omp.par.region2:
+  // CHECK: call void @body(i64 43)
+  // CHECK: br label %omp.par.pre_finalize