From 614958912784a13737720de39b2da40fe6f26e75 Mon Sep 17 00:00:00 2001 From: Prabhdeep Singh Soni Date: Fri, 7 Oct 2022 16:55:13 -0400 Subject: [PATCH] [OMPIRBuilder] Support depend clause for task This patch adds support for the `depend` clause for the `task` construct. Reviewed By: jdoerfert Differential Revision: https://reviews.llvm.org/D135695 --- clang/lib/CodeGen/CGOpenMPRuntime.cpp | 55 +++++++++-------- llvm/include/llvm/Frontend/OpenMP/OMPConstants.h | 13 ++++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h | 14 ++++- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def | 1 + llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 68 +++++++++++++++++++-- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 75 ++++++++++++++++++++++++ 6 files changed, 191 insertions(+), 35 deletions(-) diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index 8453636..7570974 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -4377,39 +4377,26 @@ CGOpenMPRuntime::emitTaskInit(CodeGenFunction &CGF, SourceLocation Loc, return Result; } -namespace { -/// Dependence kind for RTL. -enum RTLDependenceKindTy { - DepIn = 0x01, - DepInOut = 0x3, - DepMutexInOutSet = 0x4, - DepInOutSet = 0x8, - DepOmpAllMem = 0x80, -}; -/// Fields ids in kmp_depend_info record. -enum RTLDependInfoFieldsTy { BaseAddr, Len, Flags }; -} // namespace - /// Translates internal dependency kind into the runtime kind. static RTLDependenceKindTy translateDependencyKind(OpenMPDependClauseKind K) { RTLDependenceKindTy DepKind; switch (K) { case OMPC_DEPEND_in: - DepKind = DepIn; + DepKind = RTLDependenceKindTy::DepIn; break; // Out and InOut dependencies must use the same code. case OMPC_DEPEND_out: case OMPC_DEPEND_inout: - DepKind = DepInOut; + DepKind = RTLDependenceKindTy::DepInOut; break; case OMPC_DEPEND_mutexinoutset: - DepKind = DepMutexInOutSet; + DepKind = RTLDependenceKindTy::DepMutexInOutSet; break; case OMPC_DEPEND_inoutset: - DepKind = DepInOutSet; + DepKind = RTLDependenceKindTy::DepInOutSet; break; case OMPC_DEPEND_outallmemory: - DepKind = DepOmpAllMem; + DepKind = RTLDependenceKindTy::DepOmpAllMem; break; case OMPC_DEPEND_source: case OMPC_DEPEND_sink: @@ -4457,7 +4444,9 @@ CGOpenMPRuntime::getDepobjElements(CodeGenFunction &CGF, LValue DepobjLVal, DepObjAddr, KmpDependInfoTy, Base.getBaseInfo(), Base.getTBAAInfo()); // NumDeps = deps[i].base_addr; LValue BaseAddrLVal = CGF.EmitLValueForField( - NumDepsBase, *std::next(KmpDependInfoRD->field_begin(), BaseAddr)); + NumDepsBase, + *std::next(KmpDependInfoRD->field_begin(), + static_cast(RTLDependInfoFields::BaseAddr))); llvm::Value *NumDeps = CGF.EmitLoadOfScalar(BaseAddrLVal, Loc); return std::make_pair(NumDeps, Base); } @@ -4503,18 +4492,24 @@ static void emitDependData(CodeGenFunction &CGF, QualType &KmpDependInfoTy, } // deps[i].base_addr = &; LValue BaseAddrLVal = CGF.EmitLValueForField( - Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr)); + Base, + *std::next(KmpDependInfoRD->field_begin(), + static_cast(RTLDependInfoFields::BaseAddr))); CGF.EmitStoreOfScalar(Addr, BaseAddrLVal); // deps[i].len = sizeof(); LValue LenLVal = CGF.EmitLValueForField( - Base, *std::next(KmpDependInfoRD->field_begin(), Len)); + Base, *std::next(KmpDependInfoRD->field_begin(), + static_cast(RTLDependInfoFields::Len))); CGF.EmitStoreOfScalar(Size, LenLVal); // deps[i].flags = ; RTLDependenceKindTy DepKind = translateDependencyKind(Data.DepKind); LValue FlagsLVal = CGF.EmitLValueForField( - Base, *std::next(KmpDependInfoRD->field_begin(), Flags)); - CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind), - FlagsLVal); + Base, + *std::next(KmpDependInfoRD->field_begin(), + static_cast(RTLDependInfoFields::Flags))); + CGF.EmitStoreOfScalar( + llvm::ConstantInt::get(LLVMFlagsTy, static_cast(DepKind)), + FlagsLVal); if (unsigned *P = Pos.dyn_cast()) { ++(*P); } else { @@ -4790,7 +4785,9 @@ Address CGOpenMPRuntime::emitDepobjDependClause( LValue Base = CGF.MakeAddrLValue(DependenciesArray, KmpDependInfoTy); // deps[i].base_addr = NumDependencies; LValue BaseAddrLVal = CGF.EmitLValueForField( - Base, *std::next(KmpDependInfoRD->field_begin(), BaseAddr)); + Base, + *std::next(KmpDependInfoRD->field_begin(), + static_cast(RTLDependInfoFields::BaseAddr))); CGF.EmitStoreOfScalar(NumDepsVal, BaseAddrLVal); llvm::PointerUnion Pos; unsigned Idx = 1; @@ -4870,9 +4867,11 @@ void CGOpenMPRuntime::emitUpdateClause(CodeGenFunction &CGF, LValue DepobjLVal, // deps[i].flags = NewDepKind; RTLDependenceKindTy DepKind = translateDependencyKind(NewDepKind); LValue FlagsLVal = CGF.EmitLValueForField( - Base, *std::next(KmpDependInfoRD->field_begin(), Flags)); - CGF.EmitStoreOfScalar(llvm::ConstantInt::get(LLVMFlagsTy, DepKind), - FlagsLVal); + Base, *std::next(KmpDependInfoRD->field_begin(), + static_cast(RTLDependInfoFields::Flags))); + CGF.EmitStoreOfScalar( + llvm::ConstantInt::get(LLVMFlagsTy, static_cast(DepKind)), + FlagsLVal); // Shift the address forward by one element. Address ElementNext = diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h index 76104f6..b0e9c53 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h @@ -207,6 +207,19 @@ enum class OMPInteropType { Unknown, Target, TargetSync }; /// Atomic compare operations. Currently OpenMP only supports ==, >, and <. enum class OMPAtomicCompareOp : unsigned { EQ, MIN, MAX }; +/// Fields ids in kmp_depend_info record. +enum class RTLDependInfoFields { BaseAddr, Len, Flags }; + +/// Dependence kind for RTL. +enum class RTLDependenceKindTy { + DepUnknown = 0x0, + DepIn = 0x01, + DepInOut = 0x3, + DepMutexInOutSet = 0x4, + DepInOutSet = 0x8, + DepOmpAllMem = 0x80, +}; + } // end namespace omp } // end namespace llvm diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index c16230f..c59adc7 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -645,6 +645,17 @@ public: /// \param Loc The location where the taskyield directive was encountered. void createTaskyield(const LocationDescription &Loc); + /// A struct to pack the relevant information for an OpenMP depend clause. + struct DependData { + omp::RTLDependenceKindTy DepKind = omp::RTLDependenceKindTy::DepUnknown; + Type *DepValueType; + Value *DepVal; + explicit DependData() = default; + DependData(omp::RTLDependenceKindTy DepKind, Type *DepValueType, + Value *DepVal) + : DepKind(DepKind), DepValueType(DepValueType), DepVal(DepVal) {} + }; + /// Generator for `#omp task` /// /// \param Loc The location where the task construct was encountered. @@ -662,7 +673,8 @@ public: InsertPointTy createTask(const LocationDescription &Loc, InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, bool Tied = true, Value *Final = nullptr, - Value *IfCondition = nullptr); + Value *IfCondition = nullptr, + ArrayRef Dependencies = {}); /// Generator for the taskgroup construct /// diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def index 6e7c0d3..71abc88 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -92,6 +92,7 @@ __OMP_STRUCT_TYPE(OffloadEntry, __tgt_offload_entry, Int8Ptr, Int8Ptr, SizeTy, __OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, Int32, Int32, VoidPtrPtr, VoidPtrPtr, Int64Ptr, Int64Ptr, VoidPtrPtr, VoidPtrPtr, Int64) __OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, Int8Ptr) +__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, SizeTy, SizeTy, Int8) #undef __OMP_STRUCT_TYPE #undef OMP_STRUCT_TYPE diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index adc5316..91bd2fe 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1290,7 +1290,8 @@ void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) { OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTask(const LocationDescription &Loc, InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB, - bool Tied, Value *Final, Value *IfCondition) { + bool Tied, Value *Final, Value *IfCondition, + ArrayRef Dependencies) { if (!updateToLocation(Loc)) return InsertPointTy(); @@ -1322,8 +1323,8 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, OI.EntryBB = TaskAllocaBB; OI.OuterAllocaBB = AllocaIP.getBlock(); OI.ExitBB = TaskExitBB; - OI.PostOutlineCB = [this, Ident, Tied, Final, - IfCondition](Function &OutlinedFn) { + OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, + Dependencies](Function &OutlinedFn) { // The input IR here looks like the following- // ``` // func @current_fn() { @@ -1433,6 +1434,49 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, TaskSize); } + Value *DepArrayPtr = nullptr; + if (Dependencies.size()) { + InsertPointTy OldIP = Builder.saveIP(); + Builder.SetInsertPoint( + &OldIP.getBlock()->getParent()->getEntryBlock().back()); + + Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size()); + Value *DepArray = + Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr"); + + unsigned P = 0; + for (DependData *Dep : Dependencies) { + Value *Base = + Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P); + // Store the pointer to the variable + Value *Addr = Builder.CreateStructGEP( + DependInfo, Base, + static_cast(RTLDependInfoFields::BaseAddr)); + Value *DepValPtr = + Builder.CreatePtrToInt(Dep->DepVal, Builder.getInt64Ty()); + Builder.CreateStore(DepValPtr, Addr); + // Store the size of the variable + Value *Size = Builder.CreateStructGEP( + DependInfo, Base, + static_cast(RTLDependInfoFields::Len)); + Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize( + Dep->DepValueType)), + Size); + // Store the dependency kind + Value *Flags = Builder.CreateStructGEP( + DependInfo, Base, + static_cast(RTLDependInfoFields::Flags)); + Builder.CreateStore( + ConstantInt::get(Builder.getInt8Ty(), + static_cast(Dep->DepKind)), + Flags); + ++P; + } + + DepArrayPtr = Builder.CreateBitCast(DepArray, Builder.getInt8PtrTy()); + Builder.restoreIP(OldIP); + } + // In the presence of the `if` clause, the following IR is generated: // ... // %data = call @__kmpc_omp_task_alloc(...) @@ -1471,9 +1515,21 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc, Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData}); Builder.SetInsertPoint(ThenTI); } - // Emit the @__kmpc_omp_task runtime call to spawn the task - Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task); - Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData}); + + if (Dependencies.size()) { + Function *TaskFn = + getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps); + Builder.CreateCall( + TaskFn, + {Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()), + DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0), + ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))}); + + } else { + // Emit the @__kmpc_omp_task runtime call to spawn the task + Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task); + Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData}); + } StaleCI->eraseFromParent(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index af96ac2..7ae13a5 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -5092,6 +5092,81 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) { EXPECT_FALSE(verifyModule(*M, &errs())); } +TEST_F(OpenMPIRBuilderTest, CreateTaskDepend) { + using InsertPointTy = OpenMPIRBuilder::InsertPointTy; + OpenMPIRBuilder OMPBuilder(*M); + OMPBuilder.initialize(); + F->setName("func"); + IRBuilder<> Builder(BB); + auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {}; + BasicBlock *AllocaBB = Builder.GetInsertBlock(); + BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split"); + OpenMPIRBuilder::LocationDescription Loc( + InsertPointTy(BodyBB, BodyBB->getFirstInsertionPt()), DL); + AllocaInst *InDep = Builder.CreateAlloca(Type::getInt32Ty(M->getContext())); + OpenMPIRBuilder::DependData DDIn(RTLDependenceKindTy::DepIn, + Type::getInt32Ty(M->getContext()), InDep); + SmallVector DDS; + DDS.push_back(&DDIn); + Builder.restoreIP(OMPBuilder.createTask( + Loc, InsertPointTy(AllocaBB, AllocaBB->getFirstInsertionPt()), BodyGenCB, + /*Tied=*/false, /*Final*/ nullptr, /*IfCondition*/ nullptr, DDS)); + OMPBuilder.finalize(); + Builder.CreateRetVoid(); + + // Check for the `NumDeps` argument + CallInst *TaskAllocCall = dyn_cast( + OMPBuilder + .getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps) + ->user_back()); + ASSERT_NE(TaskAllocCall, nullptr); + ConstantInt *NumDeps = dyn_cast(TaskAllocCall->getArgOperand(3)); + ASSERT_NE(NumDeps, nullptr); + EXPECT_EQ(NumDeps->getZExtValue(), 1U); + + // Check for the `DepInfo` array argument + BitCastInst *DepArrayPtr = + dyn_cast(TaskAllocCall->getOperand(4)); + ASSERT_NE(DepArrayPtr, nullptr); + AllocaInst *DepArray = dyn_cast(DepArrayPtr->getOperand(0)); + ASSERT_NE(DepArray, nullptr); + Value::user_iterator DepArrayI = DepArray->user_begin(); + EXPECT_EQ(*DepArrayI, DepArrayPtr); + ++DepArrayI; + Value::user_iterator DepInfoI = DepArrayI->user_begin(); + // Check for the `DependKind` flag in the `DepInfo` array + Value *Flag = findStoredValue(*DepInfoI); + ASSERT_NE(Flag, nullptr); + ConstantInt *FlagInt = dyn_cast(Flag); + ASSERT_NE(FlagInt, nullptr); + EXPECT_EQ(FlagInt->getZExtValue(), + static_cast(RTLDependenceKindTy::DepIn)); + ++DepInfoI; + // Check for the size in the `DepInfo` array + Value *Size = findStoredValue(*DepInfoI); + ASSERT_NE(Size, nullptr); + ConstantInt *SizeInt = dyn_cast(Size); + ASSERT_NE(SizeInt, nullptr); + EXPECT_EQ(SizeInt->getZExtValue(), 4U); + ++DepInfoI; + // Check for the variable address in the `DepInfo` array + Value *AddrStored = findStoredValue(*DepInfoI); + ASSERT_NE(AddrStored, nullptr); + PtrToIntInst *AddrInt = dyn_cast(AddrStored); + ASSERT_NE(AddrInt, nullptr); + Value *Addr = AddrInt->getPointerOperand(); + EXPECT_EQ(Addr, InDep); + + ConstantInt *NumDepsNoAlias = + dyn_cast(TaskAllocCall->getArgOperand(5)); + ASSERT_NE(NumDepsNoAlias, nullptr); + EXPECT_EQ(NumDepsNoAlias->getZExtValue(), 0U); + EXPECT_EQ(TaskAllocCall->getOperand(6), + ConstantPointerNull::get(Type::getInt8PtrTy(M->getContext()))); + + EXPECT_FALSE(verifyModule(*M, &errs())); +} + TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) { using InsertPointTy = OpenMPIRBuilder::InsertPointTy; OpenMPIRBuilder OMPBuilder(*M); -- 2.7.4