From 05ae04c396519cca9ef50d3b9cafb0cd9c87d1d7 Mon Sep 17 00:00:00 2001 From: Simon Moll Date: Wed, 30 Sep 2020 17:10:44 +0200 Subject: [PATCH] [DA][SDA] SyncDependenceAnalysis re-write This patch achieves two things: 1. It breaks up the `join_blocks` interface between the SDA and the DA to return two separate sets for divergent loop exits and divergent, disjoint path joins. 2. It updates the SDA algorithm to run in O(n) time and improves the precision on divergent loop exits. This fixes `https://bugs.llvm.org/show_bug.cgi?id=46372` (by virtue of the improved `join_blocks` interface) and reveals an imprecise expected result in the `Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll` test. Reviewed By: sameerds Differential Revision: https://reviews.llvm.org/D84413 --- llvm/include/llvm/Analysis/DivergenceAnalysis.h | 83 ++-- .../include/llvm/Analysis/SyncDependenceAnalysis.h | 42 +- llvm/lib/Analysis/DivergenceAnalysis.cpp | 332 ++++++--------- llvm/lib/Analysis/SyncDependenceAnalysis.cpp | 462 ++++++++++++--------- .../AMDGPU/hidden_loopdiverge.ll | 3 +- .../AMDGPU/trivial-join-at-loop-exit.ll | 3 - 6 files changed, 455 insertions(+), 470 deletions(-) diff --git a/llvm/include/llvm/Analysis/DivergenceAnalysis.h b/llvm/include/llvm/Analysis/DivergenceAnalysis.h index a2da97b..8a32bfb 100644 --- a/llvm/include/llvm/Analysis/DivergenceAnalysis.h +++ b/llvm/include/llvm/Analysis/DivergenceAnalysis.h @@ -59,8 +59,10 @@ public: /// \brief Mark \p UniVal as a value that is always uniform. void addUniformOverride(const Value &UniVal); - /// \brief Mark \p DivVal as a value that is always divergent. - void markDivergent(const Value &DivVal); + /// \brief Mark \p DivVal as a value that is always divergent. Will not do so + /// if `isAlwaysUniform(DivVal)`. + /// \returns Whether the tracked divergence state of \p DivVal changed. + bool markDivergent(const Value &DivVal); /// \brief Propagate divergence to all instructions in the region.
/// Divergence is seeded by calls to \p markDivergent. @@ -76,45 +78,38 @@ public: /// \brief Whether \p Val is divergent at its definition. bool isDivergent(const Value &Val) const; - /// \brief Whether \p U is divergent. Uses of a uniform value can be divergent. + /// \brief Whether \p U is divergent. Uses of a uniform value can be + /// divergent. bool isDivergentUse(const Use &U) const; void print(raw_ostream &OS, const Module *) const; private: - bool updateTerminator(const Instruction &Term) const; - bool updatePHINode(const PHINode &Phi) const; - - /// \brief Computes whether \p Inst is divergent based on the - /// divergence of its operands. - /// - /// \returns Whether \p Inst is divergent. - /// - /// This should only be called for non-phi, non-terminator instructions. - bool updateNormalInstruction(const Instruction &Inst) const; - - /// \brief Mark users of live-out users as divergent. - /// - /// \param LoopHeader the header of the divergent loop. - /// - /// Marks all users of live-out values of the loop headed by \p LoopHeader - /// as divergent and puts them on the worklist. - void taintLoopLiveOuts(const BasicBlock &LoopHeader); - - /// \brief Push all users of \p Val (in the region) to the worklist + /// \brief Mark \p Term as divergent and push all Instructions that become + /// divergent as a result on the worklist. + void analyzeControlDivergence(const Instruction &Term); + /// \brief Mark all phi nodes in \p JoinBlock as divergent and push them on + /// the worklist. + void taintAndPushPhiNodes(const BasicBlock &JoinBlock); + + /// \brief Identify all Instructions that become divergent because \p DivExit + /// is a divergent loop exit of \p DivLoop. Mark those instructions as + /// divergent and push them on the worklist. + void propagateLoopExitDivergence(const BasicBlock &DivExit, + const Loop &DivLoop); + + /// \brief Internal implementation function for propagateLoopExitDivergence. 
+ void analyzeLoopExitDivergence(const BasicBlock &DivExit, + const Loop &OuterDivLoop); + + /// \brief Mark \p I as divergent if it uses a value defined in \p + /// OuterDivLoop. Push its users on the worklist. + void analyzeTemporalDivergence(const Instruction &I, + const Loop &OuterDivLoop); + + /// \brief Push all users of \p Val (in the region) to the worklist. void pushUsers(const Value &I); - /// \brief Push all phi nodes in @block to the worklist - void pushPHINodes(const BasicBlock &Block); - - /// \brief Mark \p Block as join divergent - /// - /// A block is join divergent if two threads may reach it from different - /// incoming blocks at the same time. - void markBlockJoinDivergent(const BasicBlock &Block) { - DivergentJoinBlocks.insert(&Block); - } - /// \brief Whether \p Val is divergent when read in \p ObservingBlock. bool isTemporalDivergent(const BasicBlock &ObservingBlock, const Value &Val) const; @@ -126,24 +121,6 @@ private: return DivergentJoinBlocks.find(&Block) != DivergentJoinBlocks.end(); } - /// \brief Propagate control-induced divergence to users (phi nodes and - /// instructions). - // - // \param JoinBlock is a divergent loop exit or join point of two disjoint - // paths. - // \returns Whether \p JoinBlock is a divergent loop exit of \p TermLoop. - bool propagateJoinDivergence(const BasicBlock &JoinBlock, - const Loop *TermLoop); - - /// \brief Propagate induced value divergence due to control divergence in \p - /// Term. - void propagateBranchDivergence(const Instruction &Term); - - /// \brief Propagate divergent caused by a divergent loop exit. - /// - /// \param ExitingLoop is a divergent loop. - void propagateLoopDivergence(const Loop &ExitingLoop); - private: const Function &F; // If regionLoop != nullptr, analysis is only performed within \p RegionLoop. @@ -166,7 +143,7 @@ private: DenseSet UniformOverrides; // Blocks with joining divergent control from different predecessors.
- DenseSet DivergentJoinBlocks; + DenseSet DivergentJoinBlocks; // FIXME Deprecated // Detected/marked divergent values. DenseSet DivergentValues; diff --git a/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h b/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h index 2f07b31..9838d62 100644 --- a/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h +++ b/llvm/include/llvm/Analysis/SyncDependenceAnalysis.h @@ -21,6 +21,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/LoopInfo.h" #include +#include namespace llvm { @@ -30,6 +31,26 @@ class Loop; class PostDominatorTree; using ConstBlockSet = SmallPtrSet; +struct ControlDivergenceDesc { + // Join points of divergent disjoint paths. + ConstBlockSet JoinDivBlocks; + // Divergent loop exits + ConstBlockSet LoopDivBlocks; +}; + +struct ModifiedPO { + std::vector LoopPO; + std::unordered_map POIndex; + void appendBlock(const BasicBlock &BB) { + POIndex[&BB] = LoopPO.size(); + LoopPO.push_back(&BB); + } + unsigned getIndexOf(const BasicBlock &BB) const { + return POIndex.find(&BB)->second; + } + unsigned size() const { return LoopPO.size(); } + const BasicBlock *getBlockAt(unsigned Idx) const { return LoopPO[Idx]; } +}; /// \brief Relates points of divergent control to join points in /// reducible CFGs. @@ -51,28 +72,19 @@ public: /// header. Those exit blocks are added to the returned set. /// If L is the parent loop of \p Term and an exit of L is in the returned /// set then L is a divergent loop. - const ConstBlockSet &join_blocks(const Instruction &Term); - - /// \brief Computes divergent join points and loop exits (in the surrounding - /// loop) caused by the divergent loop exits of\p Loop. - /// - /// The set of blocks which are reachable by disjoint paths from the - /// loop exits of \p Loop. - /// This treats the loop as a single node in \p Loop's parent loop. - /// The returned set has the same properties as for join_blocks(TermInst&). 
- const ConstBlockSet &join_blocks(const Loop &Loop); + const ControlDivergenceDesc &getJoinBlocks(const Instruction &Term); private: - static ConstBlockSet EmptyBlockSet; + static ControlDivergenceDesc EmptyDivergenceDesc; + + ModifiedPO LoopPO; - ReversePostOrderTraversal FuncRPOT; const DominatorTree &DT; const PostDominatorTree &PDT; const LoopInfo &LI; - std::map> CachedLoopExitJoins; - std::map> - CachedBranchJoins; + std::map> + CachedControlDivDescs; }; } // namespace llvm diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp index 343406c..d01a0b9 100644 --- a/llvm/lib/Analysis/DivergenceAnalysis.cpp +++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp @@ -1,4 +1,4 @@ -//===- DivergenceAnalysis.cpp --------- Divergence Analysis Implementation -==// +//===---- DivergenceAnalysis.cpp --- Divergence Analysis Implementation ----==// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -97,42 +97,18 @@ DivergenceAnalysis::DivergenceAnalysis( : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), IsLCSSAForm(IsLCSSAForm) {} -void DivergenceAnalysis::markDivergent(const Value &DivVal) { +bool DivergenceAnalysis::markDivergent(const Value &DivVal) { + if (isAlwaysUniform(DivVal)) + return false; assert(isa(DivVal) || isa(DivVal)); assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); - DivergentValues.insert(&DivVal); + return DivergentValues.insert(&DivVal).second; } void DivergenceAnalysis::addUniformOverride(const Value &UniVal) { UniformOverrides.insert(&UniVal); } -bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const { - if (Term.getNumSuccessors() <= 1) - return false; - if (auto *BranchTerm = dyn_cast(&Term)) { - assert(BranchTerm->isConditional()); - return isDivergent(*BranchTerm->getCondition()); - } - if (auto *SwitchTerm = dyn_cast(&Term)) { - return isDivergent(*SwitchTerm->getCondition()); - } - if (isa(Term)) { - return false; // ignore abnormal executions through landingpad - } - - llvm_unreachable("unexpected terminator"); -} - -bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const { - // TODO function calls with side effects, etc - for (const auto &Op : I.operands()) { - if (isDivergent(*Op)) - return true; - } - return false; -} - bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, const Value &Val) const { const auto *Inst = dyn_cast(&Val); @@ -150,32 +126,6 @@ bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, return false; } -bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const { - // joining divergent disjoint path in Phi parent block - if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) { - return true; - } - - // An incoming value could be divergent by itself. 
- // Otherwise, an incoming value could be uniform within the loop - // that carries its definition but it may appear divergent - // from outside the loop. This happens when divergent loop exits - // drop definitions of that uniform value in different iterations. - // - // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop - // if (i % thread_id == 0) break; // divergent loop exit - // } - // int divI = i; // divI is divergent - for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) { - const auto *InVal = Phi.getIncomingValue(i); - if (isDivergent(*Phi.getIncomingValue(i)) || - isTemporalDivergent(*Phi.getParent(), *InVal)) { - return true; - } - } - return false; -} - bool DivergenceAnalysis::inRegion(const Instruction &I) const { return I.getParent() && inRegion(*I.getParent()); } @@ -184,35 +134,82 @@ bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const { return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); } -static bool usesLiveOut(const Instruction &I, const Loop *DivLoop) { - for (auto &Op : I.operands()) { - auto *OpInst = dyn_cast(&Op); +void DivergenceAnalysis::pushUsers(const Value &V) { + const auto *I = dyn_cast(&V); + + if (I && I->isTerminator()) { + analyzeControlDivergence(*I); + return; + } + + for (const auto *User : V.users()) { + const auto *UserInst = dyn_cast(User); + if (!UserInst) + continue; + + // only compute divergent inside loop + if (!inRegion(*UserInst)) + continue; + + // All users of divergent values are immediate divergent + if (markDivergent(*UserInst)) + Worklist.push_back(UserInst); + } +} + +static const Instruction *getIfCarriedInstruction(const Use &U, + const Loop &DivLoop) { + const auto *I = dyn_cast(&U); + if (!I) + return nullptr; + if (!DivLoop.contains(I)) + return nullptr; + return I; +} + +void DivergenceAnalysis::analyzeTemporalDivergence(const Instruction &I, + const Loop &OuterDivLoop) { + if (isAlwaysUniform(I)) + return; + if (isDivergent(I)) + return; + + 
LLVM_DEBUG(dbgs() << "Analyze temporal divergence: " << I.getName() << "\n"); + assert((isa(I) || !IsLCSSAForm) && + "In LCSSA form all users of loop-exiting defs are Phi nodes."); + for (const Use &Op : I.operands()) { + const auto *OpInst = getIfCarriedInstruction(Op, OuterDivLoop); if (!OpInst) continue; - if (DivLoop->contains(OpInst->getParent())) - return true; + if (markDivergent(I)) + pushUsers(I); + return; } - return false; } // marks all users of loop-carried values of the loop headed by LoopHeader as // divergent -void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { - auto *DivLoop = LI.getLoopFor(&LoopHeader); - assert(DivLoop && "loopHeader is not actually part of a loop"); +void DivergenceAnalysis::analyzeLoopExitDivergence(const BasicBlock &DivExit, + const Loop &OuterDivLoop) { + // All users are in immediate exit blocks + if (IsLCSSAForm) { + for (const auto &Phi : DivExit.phis()) { + analyzeTemporalDivergence(Phi, OuterDivLoop); + } + return; + } - SmallVector TaintStack; - DivLoop->getExitBlocks(TaintStack); + // For non-LCSSA we have to follow all live out edges wherever they may lead. 
+ const BasicBlock &LoopHeader = *OuterDivLoop.getHeader(); + SmallVector TaintStack; + TaintStack.push_back(&DivExit); // Otherwise potential users of loop-carried values could be anywhere in the // dominance region of DivLoop (including its fringes for phi nodes) DenseSet Visited; - for (auto *Block : TaintStack) { - Visited.insert(Block); - } - Visited.insert(&LoopHeader); + Visited.insert(&DivExit); - while (!TaintStack.empty()) { + do { auto *UserBlock = TaintStack.back(); TaintStack.pop_back(); @@ -220,33 +217,21 @@ void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { if (!inRegion(*UserBlock)) continue; - assert(!DivLoop->contains(UserBlock) && + assert(!OuterDivLoop.contains(UserBlock) && "irreducible control flow detected"); // phi nodes at the fringes of the dominance region if (!DT.dominates(&LoopHeader, UserBlock)) { // all PHI nodes of UserBlock become divergent for (auto &Phi : UserBlock->phis()) { - Worklist.push_back(&Phi); + analyzeTemporalDivergence(Phi, OuterDivLoop); } continue; } - // taint outside users of values carried by DivLoop + // Taint outside users of values carried by OuterDivLoop. 
for (auto &I : *UserBlock) { - if (isAlwaysUniform(I)) - continue; - if (isDivergent(I)) - continue; - if (!usesLiveOut(I, DivLoop)) - continue; - - markDivergent(I); - if (I.isTerminator()) { - propagateBranchDivergence(I); - } else { - pushUsers(I); - } + analyzeTemporalDivergence(I, OuterDivLoop); } // visit all blocks in the dominance region @@ -256,56 +241,57 @@ void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { } TaintStack.push_back(SuccBlock); } - } + } while (!TaintStack.empty()); } -void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) { - for (const auto &Phi : Block.phis()) { - if (isDivergent(Phi)) - continue; - Worklist.push_back(&Phi); +void DivergenceAnalysis::propagateLoopExitDivergence(const BasicBlock &DivExit, + const Loop &InnerDivLoop) { + LLVM_DEBUG(dbgs() << "\tpropLoopExitDiv " << DivExit.getName() << "\n"); + + // Find outer-most loop that does not contain \p DivExit + const Loop *DivLoop = &InnerDivLoop; + const Loop *OuterDivLoop = DivLoop; + const Loop *ExitLevelLoop = LI.getLoopFor(&DivExit); + const unsigned LoopExitDepth = + ExitLevelLoop ? 
ExitLevelLoop->getLoopDepth() : 0; + while (DivLoop && DivLoop->getLoopDepth() > LoopExitDepth) { + DivergentLoops.insert(DivLoop); // all crossed loops are divergent + OuterDivLoop = DivLoop; + DivLoop = DivLoop->getParentLoop(); } -} - -void DivergenceAnalysis::pushUsers(const Value &V) { - for (const auto *User : V.users()) { - const auto *UserInst = dyn_cast(User); - if (!UserInst) - continue; - - if (isDivergent(*UserInst)) - continue; + LLVM_DEBUG(dbgs() << "\tOuter-most left loop: " << OuterDivLoop->getName() + << "\n"); - // only compute divergent inside loop - if (!inRegion(*UserInst)) - continue; - Worklist.push_back(UserInst); - } + analyzeLoopExitDivergence(DivExit, *OuterDivLoop); } -bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock, - const Loop *BranchLoop) { - LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n"); +// this is a divergent join point - mark all phi nodes as divergent and push +// them onto the stack. +void DivergenceAnalysis::taintAndPushPhiNodes(const BasicBlock &JoinBlock) { + LLVM_DEBUG(dbgs() << "taintAndPushPhiNodes in " << JoinBlock.getName() + << "\n"); // ignore divergence outside the region if (!inRegion(JoinBlock)) { - return false; + return; } // push non-divergent phi nodes in JoinBlock to the worklist - pushPHINodes(JoinBlock); - - // disjoint-paths divergent at JoinBlock - markBlockJoinDivergent(JoinBlock); - - // JoinBlock is a divergent loop exit - return BranchLoop && !BranchLoop->contains(&JoinBlock); + for (const auto &Phi : JoinBlock.phis()) { + if (isDivergent(Phi)) + continue; + // FIXME Theoretically, the 'undef' value could be replaced by any other + // value causing spurious divergence.
+ if (Phi.hasConstantOrUndefValue()) + continue; + if (markDivergent(Phi)) + Worklist.push_back(&Phi); + } } -void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) { - LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n"); - - markDivergent(Term); +void DivergenceAnalysis::analyzeControlDivergence(const Instruction &Term) { + LLVM_DEBUG(dbgs() << "analyzeControlDiv " << Term.getParent()->getName() + << "\n"); // Don't propagate divergence from unreachable blocks. if (!DT.isReachableFromEntry(Term.getParent())) @@ -313,104 +299,36 @@ void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) { const auto *BranchLoop = LI.getLoopFor(Term.getParent()); - // whether there is a divergent loop exit from BranchLoop (if any) - bool IsBranchLoopDivergent = false; + const auto &DivDesc = SDA.getJoinBlocks(Term); - // iterate over all blocks reachable by disjoint from Term within the loop - // also iterates over loop exits that become divergent due to Term. - for (const auto *JoinBlock : SDA.join_blocks(Term)) { - IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); + // Iterate over all blocks now reachable by a disjoint path join + for (const auto *JoinBlock : DivDesc.JoinDivBlocks) { + taintAndPushPhiNodes(*JoinBlock); } - // Branch loop is a divergent loop due to the divergent branch in Term - if (IsBranchLoopDivergent) { - assert(BranchLoop); - if (!DivergentLoops.insert(BranchLoop).second) { - return; - } - propagateLoopDivergence(*BranchLoop); - } -} - -void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) { - LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n"); - - // don't propagate beyond region - if (!inRegion(*ExitingLoop.getHeader())) - return; - - const auto *BranchLoop = ExitingLoop.getParentLoop(); - - // Uses of loop-carried values could occur anywhere - // within the dominance region of the definition. 
All loop-carried - // definitions are dominated by the loop header (reducible control). - // Thus all users have to be in the dominance region of the loop header, - // except PHI nodes that can also live at the fringes of the dom region - // (incoming defining value). - if (!IsLCSSAForm) - taintLoopLiveOuts(*ExitingLoop.getHeader()); - - // whether there is a divergent loop exit from BranchLoop (if any) - bool IsBranchLoopDivergent = false; - - // iterate over all blocks reachable by disjoint paths from exits of - // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn - // become divergent. - for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) { - IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); - } - - // Branch loop is a divergent due to divergent loop exit in ExitingLoop - if (IsBranchLoopDivergent) { - assert(BranchLoop); - if (!DivergentLoops.insert(BranchLoop).second) { - return; - } - propagateLoopDivergence(*BranchLoop); + assert(DivDesc.LoopDivBlocks.empty() || BranchLoop); + for (const auto *DivExitBlock : DivDesc.LoopDivBlocks) { + propagateLoopExitDivergence(*DivExitBlock, *BranchLoop); } } void DivergenceAnalysis::compute() { - for (auto *DivVal : DivergentValues) { + // Initialize worklist. + auto DivValuesCopy = DivergentValues; + for (const auto *DivVal : DivValuesCopy) { + assert(isDivergent(*DivVal) && "Worklist invariant violated!"); pushUsers(*DivVal); } - // propagate divergence + // All values on the Worklist are divergent. + // Their users may not have been updated yet.
while (!Worklist.empty()) { const Instruction &I = *Worklist.back(); Worklist.pop_back(); - // maintain uniformity of overrides - if (isAlwaysUniform(I)) - continue; - - bool WasDivergent = isDivergent(I); - if (WasDivergent) - continue; - - // propagate divergence caused by terminator - if (I.isTerminator()) { - if (updateTerminator(I)) { - // propagate control divergence to affected instructions - propagateBranchDivergence(I); - continue; - } - } - - // update divergence of I due to divergent operands - bool DivergentUpd = false; - const auto *Phi = dyn_cast(&I); - if (Phi) { - DivergentUpd = updatePHINode(*Phi); - } else { - DivergentUpd = updateNormalInstruction(I); - } - // propagate value divergence to users - if (DivergentUpd) { - markDivergent(I); - pushUsers(I); - } + assert(isDivergent(I) && "Worklist invariant violated!"); + pushUsers(I); } } @@ -444,7 +362,7 @@ GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F, const PostDominatorTree &PDT, const LoopInfo &LI, const TargetTransformInfo &TTI) - : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) { + : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, /* LCSSA */ false) { for (auto &I : instructions(F)) { if (TTI.isSourceOfDivergence(&I)) { DA.markDivergent(I); diff --git a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp index 36bef70..0771bb5 100644 --- a/llvm/lib/Analysis/SyncDependenceAnalysis.cpp +++ b/llvm/lib/Analysis/SyncDependenceAnalysis.cpp @@ -1,4 +1,4 @@ -//==- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation -==// +//===--- SyncDependenceAnalysis.cpp - Compute Control Divergence Effects --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -107,271 +107,353 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include #include #include #define DEBUG_TYPE "sync-dependence" +// The SDA algorithm operates on a modified CFG - we modify the edges leaving +// loop headers as follows: +// +// * We remove all edges leaving all loop headers. +// * We add additional edges from the loop headers to their exit blocks. +// +// The modification is virtual, that is whenever we visit a loop header we +// pretend it had different successors. +namespace { +using namespace llvm; + +// Custom Post-Order Traversal +// +// We cannot use the vanilla (R)PO computation of LLVM because: +// * We (virtually) modify the CFG. +// * We want a loop-compact block enumeration, that is the numbers assigned by +// the traversal to the blocks of a loop are an interval. +using POCB = std::function; +using VisitedSet = std::set; +using BlockStack = std::vector; + +// forward +static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, + VisitedSet &Finalized); + +// for a nested region (top-level loop or nested loop) +static void computeStackPO(BlockStack &Stack, const LoopInfo &LI, Loop *Loop, + POCB CallBack, VisitedSet &Finalized) { + const auto *LoopHeader = Loop ?
Loop->getHeader() : nullptr; + while (!Stack.empty()) { + const auto *NextBB = Stack.back(); + + auto *NestedLoop = LI.getLoopFor(NextBB); + bool IsNestedLoop = NestedLoop != Loop; + + // Treat the loop as a node + if (IsNestedLoop) { + SmallVector NestedExits; + NestedLoop->getUniqueExitBlocks(NestedExits); + bool PushedNodes = false; + for (const auto *NestedExitBB : NestedExits) { + if (NestedExitBB == LoopHeader) + continue; + if (Loop && !Loop->contains(NestedExitBB)) + continue; + if (Finalized.count(NestedExitBB)) + continue; + PushedNodes = true; + Stack.push_back(NestedExitBB); + } + if (!PushedNodes) { + // All loop exits finalized -> finish this node + Stack.pop_back(); + computeLoopPO(LI, *NestedLoop, CallBack, Finalized); + } + continue; + } + + // DAG-style + bool PushedNodes = false; + for (const auto *SuccBB : successors(NextBB)) { + if (SuccBB == LoopHeader) + continue; + if (Loop && !Loop->contains(SuccBB)) + continue; + if (Finalized.count(SuccBB)) + continue; + PushedNodes = true; + Stack.push_back(SuccBB); + } + if (!PushedNodes) { + // Never push nodes twice + Stack.pop_back(); + if (!Finalized.insert(NextBB).second) + continue; + CallBack(*NextBB); + } + } +} + +static void computeTopLevelPO(Function &F, const LoopInfo &LI, POCB CallBack) { + VisitedSet Finalized; + BlockStack Stack; + Stack.reserve(24); // FIXME made-up number + Stack.push_back(&F.getEntryBlock()); + computeStackPO(Stack, LI, nullptr, CallBack, Finalized); +} + +static void computeLoopPO(const LoopInfo &LI, Loop &Loop, POCB CallBack, + VisitedSet &Finalized) { + /// Call CallBack on all loop blocks. 
+ std::vector Stack; + const auto *LoopHeader = Loop.getHeader(); + + // Visit the header last + Finalized.insert(LoopHeader); + CallBack(*LoopHeader); + + // Initialize with immediate successors + for (const auto *BB : successors(LoopHeader)) { + if (!Loop.contains(BB)) + continue; + if (BB == LoopHeader) + continue; + Stack.push_back(BB); + } + + // Compute PO inside region + computeStackPO(Stack, LI, &Loop, CallBack, Finalized); +} + +} // namespace + namespace llvm { -ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; +ControlDivergenceDesc SyncDependenceAnalysis::EmptyDivergenceDesc; SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, const PostDominatorTree &PDT, const LoopInfo &LI) - : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} + : DT(DT), PDT(PDT), LI(LI) { + computeTopLevelPO(*DT.getRoot()->getParent(), LI, + [&](const BasicBlock &BB) { LoopPO.appendBlock(BB); }); +} SyncDependenceAnalysis::~SyncDependenceAnalysis() {} -using FunctionRPOT = ReversePostOrderTraversal; - // divergence propagator for reducible CFGs struct DivergencePropagator { - const FunctionRPOT &FuncRPOT; + const ModifiedPO &LoopPOT; const DominatorTree &DT; const PostDominatorTree &PDT; const LoopInfo &LI; - - // identified join points - std::unique_ptr JoinBlocks; - - // reached loop exits (by a path disjoint to a path to the loop header) - SmallPtrSet ReachedLoopExits; - - // if DefMap[B] == C then C is the dominating definition at block B - // if DefMap[B] ~ undef then we haven't seen B yet - // if DefMap[B] == B then B is a join point of disjoint paths from X or B is - // an immediate successor of X (initial value). 
- using DefiningBlockMap = std::map; - DefiningBlockMap DefMap; - - // all blocks with pending visits - std::unordered_set PendingUpdates; - - DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, - const PostDominatorTree &PDT, const LoopInfo &LI) - : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), - JoinBlocks(new ConstBlockSet) {} - - // set the definition at @block and mark @block as pending for a visit - void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { - bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; - if (WasAdded) - PendingUpdates.insert(&Block); - } + const BasicBlock &DivTermBlock; + + // * if BlockLabels[IndexOf(B)] == C then C is the dominating definition at + // block B + // * if BlockLabels[IndexOf(B)] ~ undef then we haven't seen B yet + // * if BlockLabels[IndexOf(B)] == B then B is a join point of disjoint paths + // from X or B is an immediate successor of X (initial value). + using BlockLabelVec = std::vector; + BlockLabelVec BlockLabels; + // divergent join and loop exit descriptor. + std::unique_ptr DivDesc; + + DivergencePropagator(const ModifiedPO &LoopPOT, const DominatorTree &DT, + const PostDominatorTree &PDT, const LoopInfo &LI, + const BasicBlock &DivTermBlock) + : LoopPOT(LoopPOT), DT(DT), PDT(PDT), LI(LI), DivTermBlock(DivTermBlock), + BlockLabels(LoopPOT.size(), nullptr), + DivDesc(new ControlDivergenceDesc) {} void printDefs(raw_ostream &Out) { - Out << "Propagator::DefMap {\n"; - for (const auto *Block : FuncRPOT) { - auto It = DefMap.find(Block); - Out << Block->getName() << " : "; - if (It == DefMap.end()) { - Out << "\n"; + Out << "Propagator::BlockLabels {\n"; + for (int BlockIdx = (int)BlockLabels.size() - 1; BlockIdx > 0; --BlockIdx) { + const auto *Label = BlockLabels[BlockIdx]; + Out << LoopPOT.getBlockAt(BlockIdx)->getName().str() << "(" << BlockIdx + << ") : "; + if (!Label) { + Out << "\n"; } else { - const auto *DefBlock = It->second; - Out << (DefBlock ? 
DefBlock->getName() : "") << "\n"; + Out << Label->getName() << "\n"; } } Out << "}\n"; } - // process @succBlock with reaching definition @defBlock - // the original divergent branch was in @parentLoop (if any) - void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, - const BasicBlock &DefBlock) { + // Push a definition (\p PushedLabel) to \p SuccBlock and return whether this + // causes a divergent join. + bool computeJoin(const BasicBlock &SuccBlock, const BasicBlock &PushedLabel) { + auto SuccIdx = LoopPOT.getIndexOf(SuccBlock); - // @succBlock is a loop exit - if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { - DefMap.emplace(&SuccBlock, &DefBlock); - ReachedLoopExits.insert(&SuccBlock); - return; + // unset or same reaching label + const auto *OldLabel = BlockLabels[SuccIdx]; + if (!OldLabel || (OldLabel == &PushedLabel)) { + BlockLabels[SuccIdx] = &PushedLabel; + return false; } - // first reaching def? - auto ItLastDef = DefMap.find(&SuccBlock); - if (ItLastDef == DefMap.end()) { - addPending(SuccBlock, DefBlock); - return; - } + // Update the definition + BlockLabels[SuccIdx] = &SuccBlock; + return true; + } - // a join of at least two definitions - if (ItLastDef->second != &DefBlock) { - // do we know this join already? - if (!JoinBlocks->insert(&SuccBlock).second) - return; + // visiting a virtual loop exit edge from the loop header --> temporal + // divergence on join + bool visitLoopExitEdge(const BasicBlock &ExitBlock, + const BasicBlock &DefBlock, bool FromParentLoop) { + // Pushing from a non-parent loop cannot cause temporal divergence. 
+ if (!FromParentLoop) + return visitEdge(ExitBlock, DefBlock); - // update the definition - addPending(SuccBlock, SuccBlock); - } + if (!computeJoin(ExitBlock, DefBlock)) + return false; + + // Identified a divergent loop exit + DivDesc->LoopDivBlocks.insert(&ExitBlock); + LLVM_DEBUG(dbgs() << "\tDivergent loop exit: " << ExitBlock.getName() + << "\n"); + return true; } - // find all blocks reachable by two disjoint paths from @rootTerm. - // This method works for both divergent terminators and loops with - // divergent exits. - // @rootBlock is either the block containing the branch or the header of the - // divergent loop. - // @nodeSuccessors is the set of successors of the node (Loop or Terminator) - // headed by @rootBlock. - // @parentLoop is the parent loop of the Loop or the loop that contains the - // Terminator. - template - std::unique_ptr - computeJoinPoints(const BasicBlock &RootBlock, - SuccessorIterable NodeSuccessors, const Loop *ParentLoop) { - assert(JoinBlocks); - - LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints. Parent loop: " - << (ParentLoop ? ParentLoop->getName() : "") + // process \p SuccBlock with reaching definition \p DefBlock + bool visitEdge(const BasicBlock &SuccBlock, const BasicBlock &DefBlock) { + if (!computeJoin(SuccBlock, DefBlock)) + return false; + + // Divergent, disjoint paths join. 
+ DivDesc->JoinDivBlocks.insert(&SuccBlock); + LLVM_DEBUG(dbgs() << "\tDivergent join: " << SuccBlock.getName()); + return true; + } + + std::unique_ptr computeJoinPoints() { + assert(DivDesc); + + LLVM_DEBUG(dbgs() << "SDA:computeJoinPoints: " << DivTermBlock.getName() << "\n"); + const auto *DivBlockLoop = LI.getLoopFor(&DivTermBlock); + + // Early stopping criterion + int FloorIdx = LoopPOT.size() - 1; + const BasicBlock *FloorLabel = nullptr; + // bootstrap with branch targets - for (const auto *SuccBlock : NodeSuccessors) { - DefMap.emplace(SuccBlock, SuccBlock); + int BlockIdx = 0; - if (ParentLoop && !ParentLoop->contains(SuccBlock)) { - // immediate loop exit from node. - ReachedLoopExits.insert(SuccBlock); - } else { - // regular successor - PendingUpdates.insert(SuccBlock); - } - } + for (const auto *SuccBlock : successors(&DivTermBlock)) { + auto SuccIdx = LoopPOT.getIndexOf(*SuccBlock); + BlockLabels[SuccIdx] = SuccBlock; - LLVM_DEBUG(dbgs() << "SDA: rpo order:\n"; for (const auto *RpoBlock - : FuncRPOT) { - dbgs() << "- " << RpoBlock->getName() << "\n"; - }); + // Find the successor with the highest index to start with + BlockIdx = std::max(BlockIdx, SuccIdx); + FloorIdx = std::min(FloorIdx, SuccIdx); - auto ItBeginRPO = FuncRPOT.begin(); - auto ItEndRPO = FuncRPOT.end(); + // Identify immediate divergent loop exits + if (!DivBlockLoop) + continue; - // skip until term (TODO RPOT won't let us start at @term directly) - for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) { - assert(ItBeginRPO != ItEndRPO && "Unable to find RootBlock"); + const auto *BlockLoop = LI.getLoopFor(SuccBlock); + if (BlockLoop && DivBlockLoop->contains(BlockLoop)) + continue; + DivDesc->LoopDivBlocks.insert(SuccBlock); + LLVM_DEBUG(dbgs() << "\tImmediate divergent loop exit: " + << SuccBlock->getName() << "\n"); } // propagate definitions at the immediate successors of the node in RPO - auto ItBlockRPO = ItBeginRPO; - while ((++ItBlockRPO != ItEndRPO) && !PendingUpdates.empty()) { 
- const auto *Block = *ItBlockRPO; - LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); + for (; BlockIdx >= FloorIdx; --BlockIdx) { + LLVM_DEBUG(dbgs() << "Before next visit:\n"; printDefs(dbgs())); - // skip Block if not pending update - auto ItPending = PendingUpdates.find(Block); - if (ItPending == PendingUpdates.end()) + // Any label available here + const auto *Label = BlockLabels[BlockIdx]; + if (!Label) continue; - PendingUpdates.erase(ItPending); - // propagate definition at Block to its successors - auto ItDef = DefMap.find(Block); - const auto *DefBlock = ItDef->second; - assert(DefBlock); + // Ok. Get the block + const auto *Block = LoopPOT.getBlockAt(BlockIdx); + LLVM_DEBUG(dbgs() << "SDA::joins. visiting " << Block->getName() << "\n"); auto *BlockLoop = LI.getLoopFor(Block); - if (ParentLoop && - (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) { - // if the successor is the header of a nested loop pretend its a - // single node with the loop's exits as successors + bool IsLoopHeader = BlockLoop && BlockLoop->getHeader() == Block; + bool CausedJoin = false; + int LoweredFloorIdx = FloorIdx; + if (IsLoopHeader) { + // Disconnect from immediate successors and propagate directly to loop + // exits. 
SmallVector BlockLoopExits; BlockLoop->getExitBlocks(BlockLoopExits); + + bool IsParentLoop = BlockLoop->contains(&DivTermBlock); for (const auto *BlockLoopExit : BlockLoopExits) { - visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock); + CausedJoin |= visitLoopExitEdge(*BlockLoopExit, *Label, IsParentLoop); + LoweredFloorIdx = std::min(LoweredFloorIdx, + LoopPOT.getIndexOf(*BlockLoopExit)); } - } else { - // the successors are either on the same loop level or loop exits + // Acyclic successor case for (const auto *SuccBlock : successors(Block)) { - visitSuccessor(*SuccBlock, ParentLoop, *DefBlock); + CausedJoin |= visitEdge(*SuccBlock, *Label); + LoweredFloorIdx = + std::min(LoweredFloorIdx, LoopPOT.getIndexOf(*SuccBlock)); } } - } - LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs())); - - // We need to know the definition at the parent loop header to decide - // whether the definition at the header is different from the definition at - // the loop exits, which would indicate a divergent loop exits. - // - // A // loop header - // | - // B // nested loop header - // | - // C -> X (exit from B loop) -..-> (A latch) - // | - // D -> back to B (B latch) - // | - // proper exit from both loops - // - // analyze reached loop exits - if (!ReachedLoopExits.empty()) { - const BasicBlock *ParentLoopHeader = - ParentLoop ? ParentLoop->getHeader() : nullptr; - - assert(ParentLoop); - auto ItHeaderDef = DefMap.find(ParentLoopHeader); - const auto *HeaderDefBlock = - (ItHeaderDef == DefMap.end()) ? nullptr : ItHeaderDef->second; - - LLVM_DEBUG(printDefs(dbgs())); - assert(HeaderDefBlock && "no definition at header of carrying loop"); - - for (const auto *ExitBlock : ReachedLoopExits) { - auto ItExitDef = DefMap.find(ExitBlock); - assert((ItExitDef != DefMap.end()) && - "no reaching def at reachable loop exit"); - if (ItExitDef->second != HeaderDefBlock) { - JoinBlocks->insert(ExitBlock); - } + // Floor update + if (CausedJoin) { + // 1. 
Different labels pushed to successors + FloorIdx = LoweredFloorIdx; + } else if (FloorLabel != Label) { + // 2. No join caused BUT we pushed a label that is different than the + // last pushed label + FloorIdx = LoweredFloorIdx; + FloorLabel = Label; } } - return std::move(JoinBlocks); + LLVM_DEBUG(dbgs() << "SDA::joins. After propagation:\n"; printDefs(dbgs())); + + return std::move(DivDesc); } }; -const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { - using LoopExitVec = SmallVector; - LoopExitVec LoopExits; - Loop.getExitBlocks(LoopExits); - if (LoopExits.size() < 1) { - return EmptyBlockSet; +static void printBlockSet(ConstBlockSet &Blocks, raw_ostream &Out) { + Out << "["; + bool First = true; + for (const auto *BB : Blocks) { + if (!First) + Out << ", "; + First = false; + Out << BB->getName(); } - - // already available in cache? - auto ItCached = CachedLoopExitJoins.find(&Loop); - if (ItCached != CachedLoopExitJoins.end()) { - return *ItCached->second; - } - - // compute all join points - DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; - auto JoinBlocks = Propagator.computeJoinPoints( - *Loop.getHeader(), LoopExits, Loop.getParentLoop()); - - auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); - assert(ItInserted.second); - return *ItInserted.first->second; + Out << "]"; } -const ConstBlockSet & -SyncDependenceAnalysis::join_blocks(const Instruction &Term) { +const ControlDivergenceDesc & +SyncDependenceAnalysis::getJoinBlocks(const Instruction &Term) { // trivial case - if (Term.getNumSuccessors() < 1) { - return EmptyBlockSet; + if (Term.getNumSuccessors() <= 1) { + return EmptyDivergenceDesc; } // already available in cache? 
- auto ItCached = CachedBranchJoins.find(&Term); - if (ItCached != CachedBranchJoins.end()) + auto ItCached = CachedControlDivDescs.find(&Term); + if (ItCached != CachedControlDivDescs.end()) return *ItCached->second; // compute all join points - DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + // Special handling of divergent loop exits is not needed for LCSSA const auto &TermBlock = *Term.getParent(); - auto JoinBlocks = Propagator.computeJoinPoints( - TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock)); + DivergencePropagator Propagator(LoopPO, DT, PDT, LI, TermBlock); + auto DivDesc = Propagator.computeJoinPoints(); + + LLVM_DEBUG(dbgs() << "Result (" << Term.getParent()->getName() << "):\n"; + dbgs() << "JoinDivBlocks: "; + printBlockSet(DivDesc->JoinDivBlocks, dbgs()); + dbgs() << "\nLoopDivBlocks: "; + printBlockSet(DivDesc->LoopDivBlocks, dbgs()); dbgs() << "\n";); - auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); + auto ItInserted = CachedControlDivDescs.emplace(&Term, std::move(DivDesc)); assert(ItInserted.second); return *ItInserted.first->second; } diff --git a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll index 12e2b0f..774e995 100644 --- a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll +++ b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/hidden_loopdiverge.ll @@ -119,9 +119,8 @@ L: br i1 %uni.cond, label %D, label %G X: - %div.merge.x = phi i32 [ %a, %entry ], [ %uni.merge.h, %B ] ; temporal divergent phi + %uni.merge.x = phi i32 [ %a, %entry ], [ %uni.merge.h, %B ] br i1 %uni.cond, label %Y, label %exit -; CHECK: DIVERGENT: %div.merge.x = Y: %div.merge.y = phi i32 [ 42, %X ], [ %b, %C ] diff --git a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll index 8ad848a..b872dd8 100644 --- 
a/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll +++ b/llvm/test/Analysis/DivergenceAnalysis/AMDGPU/trivial-join-at-loop-exit.ll @@ -1,7 +1,4 @@ ; RUN: opt -mtriple amdgcn-unknown-amdhsa -analyze -divergence -use-gpu-divergence-analysis %s | FileCheck %s -; XFAIL: * - -; https://bugs.llvm.org/show_bug.cgi?id=46372 ; CHECK: bb2: ; CHECK-NOT: DIVERGENT: %Guard.bb2 = phi i1 [ true, %bb1 ], [ false, %bb0 ] -- 2.7.4