From b72f1ec9fbb14cd7d2f5112d2c52ef5cdd1aa94a Mon Sep 17 00:00:00 2001 From: David Truby Date: Tue, 22 Nov 2022 13:32:47 +0000 Subject: [PATCH] [openmp][mlir] Lower parallel if to new fork_call_if function. This patch adds a new runtime function `fork_call_if` and uses that to lower parallel if statements when going through OpenMPIRBuilder. This fixes an issue where the OpenMPIRBuilder passes all arguments to fork_call as a struct but this struct is not filled corretly in the non-if branch by handling the fork inside the runtime. Differential Revision: https://reviews.llvm.org/D138495 --- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def | 2 + llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 86 +++++++++---------------- llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 32 +++------ mlir/test/Target/LLVMIR/openmp-llvm.mlir | 74 ++------------------- openmp/runtime/src/kmp.h | 3 + openmp/runtime/src/kmp_csupport.cpp | 31 +++++++++ openmp/runtime/test/lit.cfg | 4 ++ openmp/runtime/test/parallel/omp_parallel_if.c | 1 + 8 files changed, 85 insertions(+), 148 deletions(-) diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def index 87ac4d5..bef838a 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/llvm/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -203,6 +203,8 @@ __OMP_RTL(__kmpc_flush, false, Void, IdentPtr) __OMP_RTL(__kmpc_global_thread_num, false, Int32, IdentPtr) __OMP_RTL(__kmpc_get_hardware_thread_id_in_block, false, Int32, ) __OMP_RTL(__kmpc_fork_call, true, Void, IdentPtr, Int32, ParallelTaskPtr) +__OMP_RTL(__kmpc_fork_call_if, false, Void, IdentPtr, Int32, ParallelTaskPtr, + Int32, VoidPtr) __OMP_RTL(__kmpc_omp_taskwait, false, Int32, IdentPtr, Int32) __OMP_RTL(__kmpc_omp_taskwait_51, false, Int32, IdentPtr, Int32, Int32) __OMP_RTL(__kmpc_omp_taskyield, false, Int32, IdentPtr, Int32, /* Int */ Int32) diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 5e08c29..f002644 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -914,34 +914,21 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel( AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr"); AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr"); - // If there is an if condition we actually use the TIDAddr and ZeroAddr in the - // program, otherwise we only need them for modeling purposes to get the - // associated arguments in the outlined function. In the former case, - // initialize the allocas properly, in the latter case, delete them later. - if (IfCondition) { - Builder.CreateStore(Constant::getNullValue(Int32), TIDAddr); - Builder.CreateStore(Constant::getNullValue(Int32), ZeroAddr); - } else { - ToBeDeleted.push_back(TIDAddr); - ToBeDeleted.push_back(ZeroAddr); - } + // We only need TIDAddr and ZeroAddr for modeling purposes to get the + // associated arguments in the outlined function, so we delete them later. + ToBeDeleted.push_back(TIDAddr); + ToBeDeleted.push_back(ZeroAddr); // Create an artificial insertion point that will also ensure the blocks we // are about to split are not degenerated. auto *UI = new UnreachableInst(Builder.getContext(), InsertBB); - Instruction *ThenTI = UI, *ElseTI = nullptr; - if (IfCondition) - SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI); - - BasicBlock *ThenBB = ThenTI->getParent(); - BasicBlock *PRegEntryBB = ThenBB->splitBasicBlock(ThenTI, "omp.par.entry"); - BasicBlock *PRegBodyBB = - PRegEntryBB->splitBasicBlock(ThenTI, "omp.par.region"); + BasicBlock *EntryBB = UI->getParent(); + BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry"); + BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region"); BasicBlock *PRegPreFiniBB = - PRegBodyBB->splitBasicBlock(ThenTI, "omp.par.pre_finalize"); - BasicBlock *PRegExitBB = - PRegPreFiniBB->splitBasicBlock(ThenTI, "omp.par.exit"); + PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize"); + BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit"); auto FiniCBWrapper = [&](InsertPointTy IP) { // Hide "open-ended" blocks from the given FiniCB by setting the right jump @@ -975,7 +962,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel( Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use"); ToBeDeleted.push_back(ZeroAddrUse); - // ThenBB + // EntryBB // | // V // PRegionEntryBB <- Privatization allocas are placed here. @@ -998,8 +985,12 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel( BodyGenCB(InnerAllocaIP, CodeGenIP); LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n"); + FunctionCallee RTLFn; + if (IfCondition) + RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if); + else + RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call); - FunctionCallee RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call); if (auto *F = dyn_cast(RTLFn.getCallee())) { if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) { llvm::LLVMContext &Ctx = F->getContext(); @@ -1034,15 +1025,30 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel( CI->getParent()->setName("omp_parallel"); Builder.SetInsertPoint(CI); - // Build call __kmpc_fork_call(Ident, n, microtask, var1, .., varn); + // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn); Value *ForkCallArgs[] = { Ident, Builder.getInt32(NumCapturedVars), Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)}; SmallVector RealArgs; RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs)); + if (IfCondition) { + Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, + Type::getInt32Ty(M.getContext())); + RealArgs.push_back(Cond); + } RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end()); + // __kmpc_fork_call_if always expects a void ptr as the last argument + // If there are no arguments, pass a null pointer. + auto PtrTy = Type::getInt8PtrTy(M.getContext()); + if (IfCondition && NumCapturedVars == 0) { + llvm::Value *Void = ConstantPointerNull::get(PtrTy); + RealArgs.push_back(Void); + } + if (IfCondition && RealArgs.back()->getType() != PtrTy) + RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy); + Builder.CreateCall(RTLFn, RealArgs); LLVM_DEBUG(dbgs() << "With fork_call placed: " @@ -1055,35 +1061,7 @@ IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel( Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin(); Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr); - // If no "if" clause was present we do not need the call created during - // outlining, otherwise we reuse it in the serialized parallel region. - if (!ElseTI) { - CI->eraseFromParent(); - } else { - - // If an "if" clause was present we are now generating the serialized - // version into the "else" branch. - Builder.SetInsertPoint(ElseTI); - - // Build calls __kmpc_serialized_parallel(&Ident, GTid); - Value *SerializedParallelCallArgs[] = {Ident, ThreadID}; - Builder.CreateCall( - getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_serialized_parallel), - SerializedParallelCallArgs); - - // OutlinedFn(>id, &zero, CapturedStruct); - CI->removeFromParent(); - Builder.Insert(CI); - - // __kmpc_end_serialized_parallel(&Ident, GTid); - Value *EndArgs[] = {Ident, ThreadID}; - Builder.CreateCall( - getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_serialized_parallel), - EndArgs); - - LLVM_DEBUG(dbgs() << "With serialized parallel region: " - << *Builder.GetInsertBlock()->getParent() << "\n"); - } + CI->eraseFromParent(); for (Instruction *I : ToBeDeleted) I->eraseFromParent(); diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp index c5a9e6e..30d8aee 100644 --- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp +++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp @@ -986,38 +986,22 @@ TEST_F(OpenMPIRBuilderTest, ParallelIfCond) { EXPECT_EQ(OutlinedFn->arg_size(), 3U); EXPECT_EQ(&OutlinedFn->getEntryBlock(), PrivAI->getParent()); - ASSERT_EQ(OutlinedFn->getNumUses(), 2U); + ASSERT_EQ(OutlinedFn->getNumUses(), 1U); - CallInst *DirectCI = nullptr; CallInst *ForkCI = nullptr; for (User *Usr : OutlinedFn->users()) { - if (isa(Usr)) { - ASSERT_EQ(DirectCI, nullptr); - DirectCI = cast(Usr); - } else { - ASSERT_TRUE(isa(Usr)); - ASSERT_EQ(Usr->getNumUses(), 1U); - ASSERT_TRUE(isa(Usr->user_back())); - ForkCI = cast(Usr->user_back()); - } + ASSERT_TRUE(isa(Usr)); + ASSERT_EQ(Usr->getNumUses(), 1U); + ASSERT_TRUE(isa(Usr->user_back())); + ForkCI = cast(Usr->user_back()); } - EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call"); - EXPECT_EQ(ForkCI->arg_size(), 4U); + EXPECT_EQ(ForkCI->getCalledFunction()->getName(), "__kmpc_fork_call_if"); + EXPECT_EQ(ForkCI->arg_size(), 5U); EXPECT_TRUE(isa(ForkCI->getArgOperand(0))); EXPECT_EQ(ForkCI->getArgOperand(1), ConstantInt::get(Type::getInt32Ty(Ctx), 1)); - Value *StoredForkArg = - findStoredValueInAggregateAt(Ctx, ForkCI->getArgOperand(3), 0); - EXPECT_EQ(StoredForkArg, F->arg_begin()); - - EXPECT_EQ(DirectCI->getCalledFunction(), OutlinedFn); - EXPECT_EQ(DirectCI->arg_size(), 3U); - EXPECT_TRUE(isa(DirectCI->getArgOperand(0))); - EXPECT_TRUE(isa(DirectCI->getArgOperand(1))); - Value *StoredDirectArg = - findStoredValueInAggregateAt(Ctx, DirectCI->getArgOperand(2), 0); - EXPECT_EQ(StoredDirectArg, F->arg_begin()); + EXPECT_EQ(ForkCI->getArgOperand(3)->getType(), Type::getInt32Ty(Ctx)); } TEST_F(OpenMPIRBuilderTest, ParallelCancelBarrier) { diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 702b8f3..94e37c1 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -151,33 +151,19 @@ llvm.func @test_omp_parallel_num_threads_3() -> () { // CHECK: define void @test_omp_parallel_if_1(i32 %[[IF_VAR_1:.*]]) llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () { -// Check that the allocas are emitted by the OpenMPIRBuilder at the top of the -// function, before the condition. Allocas are only emitted by the builder when -// the `if` clause is present. We match specific SSA value names since LLVM -// actually produces those names. -// CHECK: %tid.addr{{.*}} = alloca i32 -// CHECK: %zero.addr{{.*}} = alloca i32 - -// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0 %0 = llvm.mlir.constant(0 : index) : i32 %1 = llvm.icmp "slt" %arg0, %0 : i32 +// CHECK: %[[IF_COND_VAR_1:.*]] = icmp slt i32 %[[IF_VAR_1]], 0 + // CHECK: %[[GTN_IF_1:.*]] = call i32 @__kmpc_global_thread_num(ptr @[[SI_VAR_IF_1:.*]]) -// CHECK: br i1 %[[IF_COND_VAR_1]], label %[[IF_COND_TRUE_BLOCK_1:.*]], label %[[IF_COND_FALSE_BLOCK_1:.*]] -// CHECK: [[IF_COND_TRUE_BLOCK_1]]: // CHECK: br label %[[OUTLINED_CALL_IF_BLOCK_1:.*]] // CHECK: [[OUTLINED_CALL_IF_BLOCK_1]]: -// CHECK: call void {{.*}} @__kmpc_fork_call(ptr @[[SI_VAR_IF_1]], {{.*}} @[[OMP_OUTLINED_FN_IF_1:.*]]) +// CHECK: %[[I32_IF_COND_VAR_1:.*]] = sext i1 %[[IF_COND_VAR_1]] to i32 +// CHECK: call void @__kmpc_fork_call_if(ptr @[[SI_VAR_IF_1]], i32 0, ptr @[[OMP_OUTLINED_FN_IF_1:.*]], i32 %[[I32_IF_COND_VAR_1]], ptr null) // CHECK: br label %[[OUTLINED_EXIT_IF_1:.*]] // CHECK: [[OUTLINED_EXIT_IF_1]]: -// CHECK: br label %[[OUTLINED_EXIT_IF_2:.*]] -// CHECK: [[OUTLINED_EXIT_IF_2]]: // CHECK: br label %[[RETURN_BLOCK_IF_1:.*]] -// CHECK: [[IF_COND_FALSE_BLOCK_1]]: -// CHECK: call void @__kmpc_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]]) -// CHECK: call void @[[OMP_OUTLINED_FN_IF_1]] -// CHECK: call void @__kmpc_end_serialized_parallel(ptr @[[SI_VAR_IF_1]], i32 %[[GTN_IF_1]]) -// CHECK: br label %[[RETURN_BLOCK_IF_1]] omp.parallel if(%1 : i1) { omp.barrier omp.terminator @@ -193,58 +179,6 @@ llvm.func @test_omp_parallel_if_1(%arg0: i32) -> () { // ----- -// CHECK-LABEL: @test_nested_alloca_ip -llvm.func @test_nested_alloca_ip(%arg0: i32) -> () { - - // Check that the allocas are emitted by the OpenMPIRBuilder at the top of - // the function, before the condition. Allocas are only emitted by the - // builder when the `if` clause is present. We match specific SSA value names - // since LLVM actually produces those names and ensure they come before the - // "icmp" that is the first operation we emit. - // CHECK: %tid.addr{{.*}} = alloca i32 - // CHECK: %zero.addr{{.*}} = alloca i32 - // CHECK: icmp slt i32 %{{.*}}, 0 - %0 = llvm.mlir.constant(0 : index) : i32 - %1 = llvm.icmp "slt" %arg0, %0 : i32 - - omp.parallel if(%1 : i1) { - // The "parallel" operation will be outlined, check the the function is - // produced. Inside that function, further allocas should be placed before - // another "icmp". - // CHECK: define - // CHECK: %tid.addr{{.*}} = alloca i32 - // CHECK: %zero.addr{{.*}} = alloca i32 - // CHECK: icmp slt i32 %{{.*}}, 1 - %2 = llvm.mlir.constant(1 : index) : i32 - %3 = llvm.icmp "slt" %arg0, %2 : i32 - - omp.parallel if(%3 : i1) { - // One more nesting level. - // CHECK: define - // CHECK: %tid.addr{{.*}} = alloca i32 - // CHECK: %zero.addr{{.*}} = alloca i32 - // CHECK: icmp slt i32 %{{.*}}, 2 - - %4 = llvm.mlir.constant(2 : index) : i32 - %5 = llvm.icmp "slt" %arg0, %4 : i32 - - omp.parallel if(%5 : i1) { - omp.barrier - omp.terminator - } - - omp.barrier - omp.terminator - } - omp.barrier - omp.terminator - } - - llvm.return -} - -// ----- - // CHECK-LABEL: define void @test_omp_parallel_3() llvm.func @test_omp_parallel_3() -> () { // CHECK: [[OMP_THREAD_3_1:%.*]] = call i32 @__kmpc_global_thread_num(ptr @{{[0-9]+}}) diff --git a/openmp/runtime/src/kmp.h b/openmp/runtime/src/kmp.h index f2d030f..99bd39b 100644 --- a/openmp/runtime/src/kmp.h +++ b/openmp/runtime/src/kmp.h @@ -3901,6 +3901,9 @@ KMP_EXPORT kmp_int32 __kmpc_bound_num_threads(ident_t *); KMP_EXPORT kmp_int32 __kmpc_ok_to_fork(ident_t *); KMP_EXPORT void __kmpc_fork_call(ident_t *, kmp_int32 nargs, kmpc_micro microtask, ...); +KMP_EXPORT void __kmpc_fork_call_if(ident_t *loc, kmp_int32 nargs, + kmpc_micro microtask, kmp_int32 cond, + void *args); KMP_EXPORT void __kmpc_serialized_parallel(ident_t *, kmp_int32 global_tid); KMP_EXPORT void __kmpc_end_serialized_parallel(ident_t *, kmp_int32 global_tid); diff --git a/openmp/runtime/src/kmp_csupport.cpp b/openmp/runtime/src/kmp_csupport.cpp index 97b15be..64b9d16 100644 --- a/openmp/runtime/src/kmp_csupport.cpp +++ b/openmp/runtime/src/kmp_csupport.cpp @@ -332,6 +332,37 @@ void __kmpc_fork_call(ident_t *loc, kmp_int32 argc, kmpc_micro microtask, ...) { /*! @ingroup PARALLEL +@param loc source location information +@param microtask pointer to callback routine consisting of outlined parallel +construct +@param cond condition for running in parallel +@param args struct of pointers to shared variables that aren't global + +Perform a fork only if the condition is true. +*/ +void __kmpc_fork_call_if(ident_t *loc, kmp_int32 argc, kmpc_micro microtask, + kmp_int32 cond, void *args) { + int gtid = __kmp_entry_gtid(); + int zero = 0; + if (cond) { + if (args) + __kmpc_fork_call(loc, argc, microtask, args); + else + __kmpc_fork_call(loc, argc, microtask); + } else { + __kmpc_serialized_parallel(loc, gtid); + + if (args) + microtask(>id, &zero, args); + else + microtask(>id, &zero); + + __kmpc_end_serialized_parallel(loc, gtid); + } +} + +/*! +@ingroup PARALLEL @param loc source location information @param global_tid global thread number @param num_teams number of teams requested for the teams construct diff --git a/openmp/runtime/test/lit.cfg b/openmp/runtime/test/lit.cfg index c1cf24a..f49f39a 100644 --- a/openmp/runtime/test/lit.cfg +++ b/openmp/runtime/test/lit.cfg @@ -133,6 +133,8 @@ if 'INTEL_LICENSE_FILE' in os.environ: # substitutions config.substitutions.append(("%libomp-compile-and-run", \ "%libomp-compile && %libomp-run")) +config.substitutions.append(("%libomp-irbuilder-compile-and-run", \ + "%libomp-irbuilder-compile && %libomp-run")) config.substitutions.append(("%libomp-c99-compile-and-run", \ "%libomp-c99-compile && %libomp-run")) config.substitutions.append(("%libomp-cxx-compile-and-run", \ @@ -143,6 +145,8 @@ config.substitutions.append(("%libomp-cxx-compile", \ "%clangXX %openmp_flags %flags -std=c++17 %s -o %t" + libs)) config.substitutions.append(("%libomp-compile", \ "%clang %openmp_flags %flags %s -o %t" + libs)) +config.substitutions.append(("%libomp-irbuilder-compile", \ + "%clang %openmp_flags %flags -fopenmp-enable-irbuilder %s -o %t" + libs)) config.substitutions.append(("%libomp-c99-compile", \ "%clang %openmp_flags %flags -std=c99 %s -o %t" + libs)) config.substitutions.append(("%libomp-run", "%t")) diff --git a/openmp/runtime/test/parallel/omp_parallel_if.c b/openmp/runtime/test/parallel/omp_parallel_if.c index abbf3cd..7a92402 100644 --- a/openmp/runtime/test/parallel/omp_parallel_if.c +++ b/openmp/runtime/test/parallel/omp_parallel_if.c @@ -1,4 +1,5 @@ // RUN: %libomp-compile-and-run +// RUN: %libomp-irbuilder-compile-and-run #include #include "omp_testsuite.h" -- 2.7.4