[OpenMP][IRBuilder] Added if clause to task
authorShraiysh Vaishay <shraiysh@gmail.com>
Fri, 23 Sep 2022 01:24:01 +0000 (01:24 +0000)
committerShraiysh <shraiysh@gmail.com>
Fri, 23 Sep 2022 01:39:41 +0000 (01:39 +0000)
This patch adds support for if clause to task construct in OpenMP
IRBuilder.

Reviewed By: raghavendhra

Differential Revision: https://reviews.llvm.org/D130615

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

index 369392d..ba63353 100644 (file)
@@ -647,9 +647,16 @@ public:
   /// \param Tied True if the task is tied, false if the task is untied.
   /// \param Final i1 value which is `true` if the task is final, `false` if the
   ///              task is not final.
+  /// \param IfCondition i1 value. If it evaluates to `false`, an undeferred
+  ///                    task is generated, and the encountering thread must
+  ///                    suspend the current task region, for which execution
+  ///                    cannot be resumed until execution of the structured
+  ///                    block that is associated with the generated task is
+  ///                    completed.
   InsertPointTy createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                           bool Tied = true, Value *Final = nullptr);
+                           bool Tied = true, Value *Final = nullptr,
+                           Value *IfCondition = nullptr);
 
   /// Generator for the taskgroup construct
   ///
index 92f6289..de2ac6e 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/MDBuilder.h"
@@ -1289,7 +1290,7 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
-                            bool Tied, Value *Final) {
+                            bool Tied, Value *Final, Value *IfCondition) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1321,7 +1322,8 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final](Function &OutlinedFn) {
+  OI.PostOutlineCB = [this, Ident, Tied, Final,
+                      IfCondition](Function &OutlinedFn) {
     // The input IR here looks like the following-
     // ```
     // func @current_fn() {
@@ -1431,6 +1433,44 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                            TaskSize);
     }
 
+    // In the presence of the `if` clause, the following IR is generated:
+    //    ...
+    //    %data = call @__kmpc_omp_task_alloc(...)
+    //    br i1 %if_condition, label %then, label %else
+    //  then:
+    //    call @__kmpc_omp_task(...)
+    //    br label %exit
+    //  else:
+    //    call @__kmpc_omp_task_begin_if0(...)
+    //    call @wrapper_fn(...)
+    //    call @__kmpc_omp_task_complete_if0(...)
+    //    br label %exit
+    //  exit:
+    //    ...
+    if (IfCondition) {
+      // `SplitBlockAndInsertIfThenElse` requires the block to have a
+      // terminator.
+      BasicBlock *NewBasicBlock =
+          splitBB(Builder, /*CreateBranch=*/true, "if.end");
+      Instruction *IfTerminator =
+          NewBasicBlock->getSinglePredecessor()->getTerminator();
+      Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
+      Builder.SetInsertPoint(IfTerminator);
+      SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
+                                    &ElseTI);
+      Builder.SetInsertPoint(ElseTI);
+      Function *TaskBeginFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
+      Function *TaskCompleteFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
+      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
+      if (HasTaskData)
+        Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
+      else
+        Builder.CreateCall(WrapperFunc, {ThreadID});
+      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
+      Builder.SetInsertPoint(ThenTI);
+    }
     // Emit the @__kmpc_omp_task runtime call to spawn the task
     Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
     Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
index aa120c1..92a118b 100644 (file)
@@ -5040,6 +5040,71 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  Builder.SetInsertPoint(BodyBB);
+  Value *IfCondition = Builder.CreateICmp(
+      CmpInst::Predicate::ICMP_EQ, F->getArg(0),
+      ConstantInt::get(Type::getInt32Ty(M->getContext()), 0U));
+  OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
+  Builder.restoreIP(OMPBuilder.createTask(Loc, AllocaIP, BodyGenCB,
+                                          /*Tied=*/false, /*Final=*/nullptr,
+                                          IfCondition));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  ASSERT_NE(TaskAllocCall, nullptr);
+
+  // Check the branching is based on the if condition argument.
+  BranchInst *IfConditionBranchInst =
+      dyn_cast<BranchInst>(TaskAllocCall->getParent()->getTerminator());
+  ASSERT_NE(IfConditionBranchInst, nullptr);
+  ASSERT_TRUE(IfConditionBranchInst->isConditional());
+  EXPECT_EQ(IfConditionBranchInst->getCondition(), IfCondition);
+
+  // Check that the `__kmpc_omp_task` executes only in the then branch.
+  CallInst *TaskCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task)
+          ->user_back());
+  ASSERT_NE(TaskCall, nullptr);
+  EXPECT_EQ(TaskCall->getParent(), IfConditionBranchInst->getSuccessor(0));
+
+  // Check that the OpenMP Runtime Functions specific to `if` clause execute
+  // only in the else branch. Also check that the function call is between the
+  // `__kmpc_omp_task_begin_if0` and `__kmpc_omp_task_complete_if0` calls.
+  CallInst *TaskBeginIfCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0)
+          ->user_back());
+  CallInst *TaskCompleteCall = dyn_cast<CallInst>(
+      OMPBuilder
+          .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0)
+          ->user_back());
+  ASSERT_NE(TaskBeginIfCall, nullptr);
+  ASSERT_NE(TaskCompleteCall, nullptr);
+  Function *WrapperFunc =
+      dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
+  ASSERT_NE(WrapperFunc, nullptr);
+  CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
+  ASSERT_NE(WrapperFuncCall, nullptr);
+  EXPECT_EQ(TaskBeginIfCall->getParent(),
+            IfConditionBranchInst->getSuccessor(1));
+  EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
+  EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);