From cefa5cefdce2d5090002c3116403f7e5ca5700b9 Mon Sep 17 00:00:00 2001 From: Johannes Doerfert Date: Wed, 11 Jan 2023 01:22:45 -0800 Subject: [PATCH] [OpenMP] Replace ExternalizationRAII with virtual uses The externalization was always a stopgap solution. One of the drawbacks is that it is very conservative no matter if we actually require the functions at the end of the pass. The new concept is more generic and properly integrates into the dependence graph. Whenever we might need a function, it has a "virtual use" that cannot be analyzed. If we do not because of some AA state, there will be a dependence to ensure state changes trigger revisits of uses, including a potentially new virtual use. --- llvm/include/llvm/Transforms/IPO/Attributor.h | 10 ++ llvm/lib/Transforms/IPO/Attributor.cpp | 9 ++ llvm/lib/Transforms/IPO/OpenMPOpt.cpp | 138 ++++++++++++--------- .../get_hardware_num_threads_in_block_fold.ll | 8 +- 4 files changed, 100 insertions(+), 65 deletions(-) diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h index 41b555b..98af41d 100644 --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -1885,11 +1885,21 @@ struct Attributor { return SimplificationCallbacks.count(IRP); } + using VirtualUseCallbackTy = + std::function; + void registerVirtualUseCallback(const Value &V, + const VirtualUseCallbackTy &CB) { + VirtualUseCallbacks[&V].emplace_back(CB); + } + private: /// The vector with all simplification callbacks registered by outside AAs. DenseMap> SimplificationCallbacks; + DenseMap> + VirtualUseCallbacks; + public: /// Translate \p V from the callee context into the call site context. 
std::optional diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp index c3cb22d..30450c3 100644 --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -1472,6 +1472,11 @@ bool Attributor::checkForAllUses( bool IgnoreDroppableUses, function_ref EquivalentUseCB) { + // Check virtual uses first. + for (VirtualUseCallbackTy &CB : VirtualUseCallbacks.lookup(&V)) + if (!CB(*this, &QueryingAA)) + return false; + // Check the trivial case first as it catches void values. if (V.use_empty()) return true; @@ -1611,6 +1616,10 @@ bool Attributor::checkForAllCallSites(function_ref Pred, << " has no internal linkage, hence not all call sites are known\n"); return false; } + // Check virtual uses first. + for (VirtualUseCallbackTy &CB : VirtualUseCallbacks.lookup(&Fn)) + if (!CB(*this, QueryingAA)) + return false; SmallVector Uses(make_pointer_range(Fn.uses())); for (unsigned u = 0; u < Uses.size(); ++u) { diff --git a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp index 5ea3d2b..89723a3 100644 --- a/llvm/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/llvm/lib/Transforms/IPO/OpenMPOpt.cpp @@ -2087,30 +2087,6 @@ private: [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); }); } - /// RAII struct to temporarily change an RTL function's linkage to external. - /// This prevents it from being mistakenly removed by other optimizations. - struct ExternalizationRAII { - ExternalizationRAII(OMPInformationCache &OMPInfoCache, - RuntimeFunction RFKind) - : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) { - if (!Declaration) - return; - - LinkageType = Declaration->getLinkage(); - Declaration->setLinkage(GlobalValue::ExternalLinkage); - } - - ~ExternalizationRAII() { - if (!Declaration) - return; - - Declaration->setLinkage(LinkageType); - } - - Function *Declaration; - GlobalValue::LinkageTypes LinkageType; - }; - /// The underlying module. 
Module &M; @@ -2135,21 +2111,6 @@ private: if (SCC.empty()) return false; - // Temporarily make these function have external linkage so the Attributor - // doesn't remove them when we try to look them up later. - ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel); - ExternalizationRAII EndParallel(OMPInfoCache, - OMPRTL___kmpc_kernel_end_parallel); - ExternalizationRAII BarrierSPMD(OMPInfoCache, - OMPRTL___kmpc_barrier_simple_spmd); - ExternalizationRAII BarrierGeneric(OMPInfoCache, - OMPRTL___kmpc_barrier_simple_generic); - ExternalizationRAII ThreadId(OMPInfoCache, - OMPRTL___kmpc_get_hardware_thread_id_in_block); - ExternalizationRAII NumThreads( - OMPInfoCache, OMPRTL___kmpc_get_hardware_num_threads_in_block); - ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size); - registerAAs(IsModulePass); ChangeStatus Changed = A.run(); @@ -3296,27 +3257,6 @@ struct AAKernelInfoFunction : AAKernelInfo { return Val; }; - Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB = - [&](const IRPosition &IRP, const AbstractAttribute *AA, - bool &UsedAssumedInformation) -> std::optional { - // IRP represents the "RequiresFullRuntime" argument of an - // __kmpc_target_init or __kmpc_target_deinit call. We will answer this - // one with the internal state of the SPMDCompatibilityTracker, so if - // generic then true, if SPMD then false. 
- if (!SPMDCompatibilityTracker.isValidState()) - return nullptr; - if (!SPMDCompatibilityTracker.isAtFixpoint()) { - if (AA) - A.recordDependence(*this, *AA, DepClassTy::OPTIONAL); - UsedAssumedInformation = true; - } else { - UsedAssumedInformation = false; - } - auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(), - !SPMDCompatibilityTracker.isAssumed()); - return Val; - }; - constexpr const int InitModeArgNo = 1; constexpr const int DeinitModeArgNo = 1; constexpr const int InitUseStateMachineArgNo = 2; @@ -3338,6 +3278,84 @@ struct AAKernelInfoFunction : AAKernelInfo { // This is a generic region but SPMDization is disabled so stop tracking. else if (DisableOpenMPOptSPMDization) SPMDCompatibilityTracker.indicatePessimisticFixpoint(); + + // Register virtual uses of functions we might need to preserve. + auto RegisterVirtualUse = [&](RuntimeFunction RFKind, + Attributor::VirtualUseCallbackTy &CB) { + if (!OMPInfoCache.RFIs[RFKind].Declaration) + return; + A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB); + }; + + // Add a dependence to ensure updates if the state changes. + auto AddDependence = [](Attributor &A, const AAKernelInfo *KI, + const AbstractAttribute *QueryingAA) { + if (QueryingAA) { + A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL); + } + return true; + }; + + Attributor::VirtualUseCallbackTy CustomStateMachineUseCB = + [&](Attributor &A, const AbstractAttribute *QueryingAA) { + // Whenever we create a custom state machine we will insert calls to + // __kmpc_get_hardware_num_threads_in_block, + // __kmpc_get_warp_size, + // __kmpc_barrier_simple_generic, + // __kmpc_kernel_parallel, and + // __kmpc_kernel_end_parallel. + // Not needed if we are on track for SPMDzation. + if (SPMDCompatibilityTracker.isValidState()) + return AddDependence(A, this, QueryingAA); + // Not needed if we can't rewrite due to an invalid state. 
+ if (!ReachedKnownParallelRegions.isValidState()) + return AddDependence(A, this, QueryingAA); + return false; + }; + + // Not needed if we are pre-runtime merge. + if (!KernelInitCB->getCalledFunction()->isDeclaration()) { + RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block, + CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic, + CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel, + CustomStateMachineUseCB); + RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel, + CustomStateMachineUseCB); + } + + // If we do not perform SPMDzation we do not need the virtual uses below. + if (SPMDCompatibilityTracker.isAtFixpoint()) + return; + + Attributor::VirtualUseCallbackTy HWThreadIdUseCB = + [&](Attributor &A, const AbstractAttribute *QueryingAA) { + // Whenever we perform SPMDzation we will insert + // __kmpc_get_hardware_thread_id_in_block calls. + if (!SPMDCompatibilityTracker.isValidState()) + return AddDependence(A, this, QueryingAA); + return false; + }; + RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block, + HWThreadIdUseCB); + + Attributor::VirtualUseCallbackTy SPMDBarrierUseCB = + [&](Attributor &A, const AbstractAttribute *QueryingAA) { + // Whenever we perform SPMDzation with guarding we will insert + // __kmpc_barrier_simple_spmd calls. If SPMDzation failed, there is + // nothing to guard, or there are no parallel regions, we don't need + // the calls. + if (!SPMDCompatibilityTracker.isValidState()) + return AddDependence(A, this, QueryingAA); + if (SPMDCompatibilityTracker.empty()) + return AddDependence(A, this, QueryingAA); + if (!mayContainParallelRegion()) + return AddDependence(A, this, QueryingAA); + return false; + }; + RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB); } /// Sanitize the string \p S such that it is a suitable global symbol name. 
diff --git a/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll b/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll index 5376a88..8254630 100644 --- a/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll +++ b/llvm/test/Transforms/OpenMP/get_hardware_num_threads_in_block_fold.ll @@ -12,6 +12,9 @@ target triple = "nvptx64" ; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i32 ; CHECK: @[[KERNEL1_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 3 ; CHECK: @[[KERNEL2_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 3 +; CHECK: @[[KERNEL0_NESTED_PARALLELISM:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 +; CHECK: @[[KERNEL1_NESTED_PARALLELISM:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 +; CHECK: @[[KERNEL2_NESTED_PARALLELISM:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0 ; CHECK: @[[GLOB0:[0-9]+]] = private unnamed_addr constant [23 x i8] c" ; CHECK: @[[GLOB1:[0-9]+]] = private unnamed_addr constant [[STRUCT_IDENT_T:%.*]] { i32 0, i32 2, i32 0, i32 22, ptr @[[GLOB0]] }, align 8 ;. @@ -191,11 +194,6 @@ entry: } define internal i32 @__kmpc_get_hardware_num_threads_in_block() { -; CHECK-LABEL: define {{[^@]+}}@__kmpc_get_hardware_num_threads_in_block -; CHECK-SAME: () #[[ATTR1]] { -; CHECK-NEXT: [[RET:%.*]] = call i32 @__kmpc_get_hardware_num_threads_in_block_dummy() -; CHECK-NEXT: ret i32 [[RET]] -; %ret = call i32 @__kmpc_get_hardware_num_threads_in_block_dummy() ret i32 %ret } -- 2.7.4