[OpenMP][MLIR] Fix for nested parallel regions
authorKiran Chandramohan <kiran.chandramohan@arm.com>
Sat, 17 Oct 2020 21:24:27 +0000 (22:24 +0100)
committerKiran Chandramohan <kiran.chandramohan@arm.com>
Mon, 19 Oct 2020 07:45:50 +0000 (08:45 +0100)
Usage of nested parallel regions were not working correctly and leading
to assertion failures. Fix contains the following changes,
1) Don't set the insertion point in the body callback.
2) Save the continuation IP in a stack and set the branch to
continuationIP at the terminator.

Reviewed By: SouraVX, jdoerfert, ftynse

Differential Revision: https://reviews.llvm.org/D88720

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/test/Target/openmp-llvm.mlir

index d28c796..7fc66f1 100644 (file)
@@ -127,6 +127,10 @@ private:
   /// OpenMP dialect hasn't been loaded (it is always loaded if there are OpenMP
   /// operations in the module though).
   const Dialect *ompDialect;
+  /// Stack which stores the target block to which a branch a must be added when
+  /// a terminator is seen. A stack is required to handle nested OpenMP parallel
+  /// regions.
+  SmallVector<llvm::BasicBlock *, 4> ompContinuationIPStack;
 
   /// Mappings between llvm.mlir.global definitions and corresponding globals.
   DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;
index 23f5698..3d5e091 100644 (file)
@@ -391,8 +391,8 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
 
     llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
     llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
+    ompContinuationIPStack.push_back(&continuationIP);
 
-    builder.SetInsertPoint(codeGenIPBB);
     // ParallelOp has only `1` region associated with it.
     auto &region = cast<omp::ParallelOp>(opInst).getRegion();
     for (auto &bb : region) {
@@ -407,22 +407,22 @@ ModuleTranslation::convertOmpParallel(Operation &opInst,
     for (auto indexedBB : llvm::enumerate(blocks)) {
       Block *bb = indexedBB.value();
       llvm::BasicBlock *curLLVMBB = blockMapping[bb];
-      if (bb->isEntryBlock())
+      if (bb->isEntryBlock()) {
+        assert(codeGenIPBBTI->getNumSuccessors() == 1 &&
+               "OpenMPIRBuilder provided entry block has multiple successors");
+        assert(codeGenIPBBTI->getSuccessor(0) == &continuationIP &&
+               "ContinuationIP is not the successor of OpenMPIRBuilder "
+               "provided entry block");
         codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+      }
 
       // TODO: Error not returned up the hierarchy
       if (failed(convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
         return;
-
-      // If this block has the terminator then add a jump to
-      // continuation bb
-      for (auto &op : *bb) {
-        if (isa<omp::TerminatorOp>(op)) {
-          builder.SetInsertPoint(curLLVMBB);
-          builder.CreateBr(&continuationIP);
-        }
-      }
     }
+
+    ompContinuationIPStack.pop_back();
+
     // Finally, after all blocks have been traversed and values mapped,
     // connect the PHI nodes to the results of preceding blocks.
     connectPHINodes(region, valueMapping, blockMapping);
@@ -498,7 +498,10 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
         ompBuilder->CreateFlush(builder.saveIP());
         return success();
       })
-      .Case([&](omp::TerminatorOp) { return success(); })
+      .Case([&](omp::TerminatorOp) {
+        builder.CreateBr(ompContinuationIPStack.back());
+        return success();
+      })
       .Case(
           [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
       .Default([&](Operation *inst) {
index 518fa6f..4b3de0b 100644 (file)
@@ -214,3 +214,53 @@ llvm.func @test_omp_parallel_3() -> () {
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_3]]
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_2]]
 // CHECK: define internal void @[[OMP_OUTLINED_FN_3_1]]
+
+// CHECK-LABEL: define void @test_omp_parallel_4()
+llvm.func @test_omp_parallel_4() -> () {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_4_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_4_1]]
+// CHECK: call void @__kmpc_barrier
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_4_1_1:.*]] to
+// CHECK: call void @__kmpc_barrier
+  omp.parallel {
+    omp.barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_4_1_1]]
+// CHECK: call void @__kmpc_barrier
+    omp.parallel {
+      omp.barrier
+      omp.terminator
+    }
+
+    omp.barrier
+    omp.terminator
+  }
+  llvm.return
+}
+
+llvm.func @test_omp_parallel_5() -> () {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1]]
+// CHECK: call void @__kmpc_barrier
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1_1:.*]] to
+// CHECK: call void @__kmpc_barrier
+  omp.parallel {
+    omp.barrier
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1_1]]
+    omp.parallel {
+// CHECK: call void {{.*}}@__kmpc_fork_call{{.*}} @[[OMP_OUTLINED_FN_5_1_1_1:.*]] to
+// CHECK: define internal void @[[OMP_OUTLINED_FN_5_1_1_1]]
+// CHECK: call void @__kmpc_barrier
+      omp.parallel {
+        omp.barrier
+        omp.terminator
+      }
+      omp.terminator
+    }
+
+    omp.barrier
+    omp.terminator
+  }
+  llvm.return
+}