From 1bf05fbc987dddebe94de7b33d810d8221eda1a5 Mon Sep 17 00:00:00 2001 From: Chen Zheng Date: Thu, 23 Sep 2021 05:48:46 +0000 Subject: [PATCH] [PowerPC] refactor rewriteLoadStores for reusing; nfc This is split from https://reviews.llvm.org/D108750. Refactor rewriteLoadStores() so that we can reuse the outlined functions. Reviewed By: jsji Differential Revision: https://reviews.llvm.org/D110314 --- llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp | 411 +++++++++++++---------- 1 file changed, 240 insertions(+), 171 deletions(-) diff --git a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp index 6916a28..db91779 100644 --- a/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp +++ b/llvm/lib/Target/PowerPC/PPCLoopInstrFormPrep.cpp @@ -125,10 +125,10 @@ STATISTIC(UpdFormChainRewritten, "Num of update form chain rewritten"); namespace { struct BucketElement { - BucketElement(const SCEVConstant *O, Instruction *I) : Offset(O), Instr(I) {} + BucketElement(const SCEV *O, Instruction *I) : Offset(O), Instr(I) {} BucketElement(Instruction *I) : Offset(nullptr), Instr(I) {} - const SCEVConstant *Offset; + const SCEV *Offset; Instruction *Instr; }; @@ -234,6 +234,19 @@ namespace { bool rewriteLoadStores(Loop *L, Bucket &BucketChain, SmallSet &BBChanged, InstrForm Form); + + /// Rewrite for the base load/store of a chain. + std::pair + rewriteForBase(Loop *L, const SCEVAddRecExpr *BasePtrSCEV, + Instruction *BaseMemI, bool CanPreInc, InstrForm Form, + SCEVExpander &SCEVE, SmallPtrSet &DeletedPtrs); + + /// Rewrite for the other load/stores of a chain according to the new \p + /// Base. + Instruction * + rewriteForBucketElement(std::pair Base, + const BucketElement &Element, Value *OffToBase, + SmallPtrSet &DeletedPtrs); }; } // end anonymous namespace @@ -321,6 +334,193 @@ bool PPCLoopInstrFormPrep::runOnFunction(Function &F) { return MadeChange; } +// Rewrite the new base according to BasePtrSCEV. 
+// bb.loop.preheader: +// %newstart = ... +// bb.loop.body: +// %phinode = phi [ %newstart, %bb.loop.preheader ], [ %add, %bb.loop.body ] +// ... +// %add = getelementptr %phinode, %inc +// +// First returned instruction is %phinode (or a type cast to %phinode), caller +// needs this value to rewrite other load/stores in the same chain. +// Second returned instruction is %add, caller needs this value to rewrite other +// load/stores in the same chain. +std::pair +PPCLoopInstrFormPrep::rewriteForBase(Loop *L, const SCEVAddRecExpr *BasePtrSCEV, + Instruction *BaseMemI, bool CanPreInc, + InstrForm Form, SCEVExpander &SCEVE, + SmallPtrSet &DeletedPtrs) { + + LLVM_DEBUG(dbgs() << "PIP: Transforming: " << *BasePtrSCEV << "\n"); + + assert(BasePtrSCEV->getLoop() == L && "AddRec for the wrong loop?"); + + Value *BasePtr = getPointerOperandAndType(BaseMemI); + assert(BasePtr && "No pointer operand"); + + Type *I8Ty = Type::getInt8Ty(BaseMemI->getParent()->getContext()); + Type *I8PtrTy = + Type::getInt8PtrTy(BaseMemI->getParent()->getContext(), + BasePtr->getType()->getPointerAddressSpace()); + + bool IsConstantInc = false; + const SCEV *BasePtrIncSCEV = BasePtrSCEV->getStepRecurrence(*SE); + Value *IncNode = getNodeForInc(L, BaseMemI, BasePtrIncSCEV); + + const SCEVConstant *BasePtrIncConstantSCEV = + dyn_cast(BasePtrIncSCEV); + if (BasePtrIncConstantSCEV) + IsConstantInc = true; + + // No valid representation for the increment. + if (!IncNode) { + LLVM_DEBUG(dbgs() << "Loop Increasement can not be represented!\n"); + return std::make_pair(nullptr, nullptr); + } + + const SCEV *BasePtrStartSCEV = nullptr; + if (CanPreInc) { + assert(SE->isLoopInvariant(BasePtrIncSCEV, L) && + "Increment is not loop invariant!\n"); + BasePtrStartSCEV = SE->getMinusSCEV(BasePtrSCEV->getStart(), + IsConstantInc ? 
BasePtrIncConstantSCEV + : BasePtrIncSCEV); + } else + BasePtrStartSCEV = BasePtrSCEV->getStart(); + + if (alreadyPrepared(L, BaseMemI, BasePtrStartSCEV, BasePtrIncSCEV, Form)) { + LLVM_DEBUG(dbgs() << "Instruction form is already prepared!\n"); + return std::make_pair(nullptr, nullptr); + } + + LLVM_DEBUG(dbgs() << "PIP: New start is: " << *BasePtrStartSCEV << "\n"); + + BasicBlock *Header = L->getHeader(); + unsigned HeaderLoopPredCount = pred_size(Header); + BasicBlock *LoopPredecessor = L->getLoopPredecessor(); + + PHINode *NewPHI = PHINode::Create(I8PtrTy, HeaderLoopPredCount, + getInstrName(BaseMemI, PHINodeNameSuffix), + Header->getFirstNonPHI()); + + Value *BasePtrStart = SCEVE.expandCodeFor(BasePtrStartSCEV, I8PtrTy, + LoopPredecessor->getTerminator()); + + // Note that LoopPredecessor might occur in the predecessor list multiple + // times, and we need to add it the right number of times. + for (auto PI : predecessors(Header)) { + if (PI != LoopPredecessor) + continue; + + NewPHI->addIncoming(BasePtrStart, LoopPredecessor); + } + + Instruction *PtrInc = nullptr; + Instruction *NewBasePtr = nullptr; + if (CanPreInc) { + Instruction *InsPoint = &*Header->getFirstInsertionPt(); + PtrInc = GetElementPtrInst::Create( + I8Ty, NewPHI, IncNode, getInstrName(BaseMemI, GEPNodeIncNameSuffix), + InsPoint); + cast(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr)); + for (auto PI : predecessors(Header)) { + if (PI == LoopPredecessor) + continue; + + NewPHI->addIncoming(PtrInc, PI); + } + if (PtrInc->getType() != BasePtr->getType()) + NewBasePtr = + new BitCastInst(PtrInc, BasePtr->getType(), + getInstrName(PtrInc, CastNodeNameSuffix), InsPoint); + else + NewBasePtr = PtrInc; + } else { + // Note that LoopPredecessor might occur in the predecessor list multiple + // times, and we need to make sure no more incoming value for them in PHI. 
+ for (auto PI : predecessors(Header)) { + if (PI == LoopPredecessor) + continue; + + // For the latch predecessor, we need to insert a GEP just before the + // terminator to increase the address. + BasicBlock *BB = PI; + Instruction *InsPoint = BB->getTerminator(); + PtrInc = GetElementPtrInst::Create( + I8Ty, NewPHI, IncNode, getInstrName(BaseMemI, GEPNodeIncNameSuffix), + InsPoint); + cast(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr)); + + NewPHI->addIncoming(PtrInc, PI); + } + PtrInc = NewPHI; + if (NewPHI->getType() != BasePtr->getType()) + NewBasePtr = new BitCastInst(NewPHI, BasePtr->getType(), + getInstrName(NewPHI, CastNodeNameSuffix), + &*Header->getFirstInsertionPt()); + else + NewBasePtr = NewPHI; + } + + BasePtr->replaceAllUsesWith(NewBasePtr); + + DeletedPtrs.insert(BasePtr); + + return std::make_pair(NewBasePtr, PtrInc); +} + +Instruction *PPCLoopInstrFormPrep::rewriteForBucketElement( + std::pair Base, const BucketElement &Element, + Value *OffToBase, SmallPtrSet &DeletedPtrs) { + Instruction *NewBasePtr = Base.first; + Instruction *PtrInc = Base.second; + assert((NewBasePtr && PtrInc) && "base does not exist!\n"); + + Type *I8Ty = Type::getInt8Ty(PtrInc->getParent()->getContext()); + + Value *Ptr = getPointerOperandAndType(Element.Instr); + assert(Ptr && "No pointer operand"); + + Instruction *RealNewPtr; + if (!Element.Offset || + (isa(Element.Offset) && + cast(Element.Offset)->getValue()->isZero())) { + RealNewPtr = NewBasePtr; + } else { + Instruction *PtrIP = dyn_cast(Ptr); + if (PtrIP && isa(NewBasePtr) && + cast(NewBasePtr)->getParent() == PtrIP->getParent()) + PtrIP = nullptr; + else if (PtrIP && isa(PtrIP)) + PtrIP = &*PtrIP->getParent()->getFirstInsertionPt(); + else if (!PtrIP) + PtrIP = Element.Instr; + + assert(OffToBase && "There should be an offset for non base element!\n"); + GetElementPtrInst *NewPtr = GetElementPtrInst::Create( + I8Ty, PtrInc, OffToBase, + getInstrName(Element.Instr, GEPNodeOffNameSuffix), PtrIP); + if (!PtrIP) + 
NewPtr->insertAfter(cast(PtrInc)); + NewPtr->setIsInBounds(IsPtrInBounds(Ptr)); + RealNewPtr = NewPtr; + } + + Instruction *ReplNewPtr; + if (Ptr->getType() != RealNewPtr->getType()) { + ReplNewPtr = new BitCastInst(RealNewPtr, Ptr->getType(), + getInstrName(Ptr, CastNodeNameSuffix)); + ReplNewPtr->insertAfter(RealNewPtr); + } else + ReplNewPtr = RealNewPtr; + + Ptr->replaceAllUsesWith(ReplNewPtr); + DeletedPtrs.insert(Ptr); + + return ReplNewPtr; +} + void PPCLoopInstrFormPrep::addOneCandidate(Instruction *MemI, const SCEV *LSCEV, SmallVector &Buckets, unsigned MaxCandidateNum) { @@ -390,8 +590,9 @@ bool PPCLoopInstrFormPrep::prepareBaseForDispFormChain(Bucket &BucketChain, if (!BucketChain.Elements[j].Offset) RemainderOffsetInfo[0] = std::make_pair(0, 1); else { - unsigned Remainder = - BucketChain.Elements[j].Offset->getAPInt().urem(Form); + unsigned Remainder = cast(BucketChain.Elements[j].Offset) + ->getAPInt() + .urem(Form); if (RemainderOffsetInfo.find(Remainder) == RemainderOffsetInfo.end()) RemainderOffsetInfo[Remainder] = std::make_pair(j, 1); else @@ -473,7 +674,7 @@ bool PPCLoopInstrFormPrep::prepareBaseForUpdateFormChain(Bucket &BucketChain) { // If our chosen element has no offset from the base pointer, there's // nothing to do. 
if (!BucketChain.Elements[j].Offset || - BucketChain.Elements[j].Offset->isZero()) + cast(BucketChain.Elements[j].Offset)->isZero()) break; const SCEV *Offset = BucketChain.Elements[j].Offset; @@ -491,157 +692,46 @@ bool PPCLoopInstrFormPrep::prepareBaseForUpdateFormChain(Bucket &BucketChain) { return true; } -bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain, - SmallSet &BBChanged, - InstrForm Form) { +bool PPCLoopInstrFormPrep::rewriteLoadStores( + Loop *L, Bucket &BucketChain, SmallSet &BBChanged, + InstrForm Form) { bool MadeChange = false; + const SCEVAddRecExpr *BasePtrSCEV = cast(BucketChain.BaseSCEV); if (!BasePtrSCEV->isAffine()) return MadeChange; - LLVM_DEBUG(dbgs() << "PIP: Transforming: " << *BasePtrSCEV << "\n"); - - assert(BasePtrSCEV->getLoop() == L && "AddRec for the wrong loop?"); - - // The instruction corresponding to the Bucket's BaseSCEV must be the first - // in the vector of elements. - Instruction *MemI = BucketChain.Elements.begin()->Instr; - Value *BasePtr = getPointerOperandAndType(MemI); - assert(BasePtr && "No pointer operand"); - - Type *I8Ty = Type::getInt8Ty(MemI->getParent()->getContext()); - Type *I8PtrTy = Type::getInt8PtrTy(MemI->getParent()->getContext(), - BasePtr->getType()->getPointerAddressSpace()); - - if (!SE->isLoopInvariant(BasePtrSCEV->getStart(), L)) + if (!isSafeToExpand(BasePtrSCEV->getStart(), *SE)) return MadeChange; - bool IsConstantInc = false; - const SCEV *BasePtrIncSCEV = BasePtrSCEV->getStepRecurrence(*SE); - Value *IncNode = getNodeForInc(L, MemI, BasePtrIncSCEV); - - const SCEVConstant *BasePtrIncConstantSCEV = - dyn_cast(BasePtrIncSCEV); - if (BasePtrIncConstantSCEV) - IsConstantInc = true; + SmallPtrSet DeletedPtrs; - // No valid representation for the increment. 
- if (!IncNode) { - LLVM_DEBUG(dbgs() << "Loop Increasement can not be represented!\n"); - return MadeChange; - } + BasicBlock *Header = L->getHeader(); + SCEVExpander SCEVE(*SE, Header->getModule()->getDataLayout(), "pistart"); // For some DS form load/store instructions, it can also be an update form, // if the stride is constant and is a multipler of 4. Use update form if // prefer it. - bool CanPreInc = - (Form == UpdateForm || - ((Form == DSForm) && IsConstantInc && - !BasePtrIncConstantSCEV->getAPInt().urem(4) && PreferUpdateForm)); - const SCEV *BasePtrStartSCEV = nullptr; - if (CanPreInc) { - assert(SE->isLoopInvariant(BasePtrIncSCEV, L) && - "Increment is not loop invariant!\n"); - BasePtrStartSCEV = SE->getMinusSCEV(BasePtrSCEV->getStart(), - IsConstantInc ? BasePtrIncConstantSCEV - : BasePtrIncSCEV); - } else - BasePtrStartSCEV = BasePtrSCEV->getStart(); - - if (!isSafeToExpand(BasePtrStartSCEV, *SE)) - return MadeChange; - - if (alreadyPrepared(L, MemI, BasePtrStartSCEV, BasePtrIncSCEV, Form)) { - LLVM_DEBUG(dbgs() << "Instruction form is already prepared!\n"); + bool CanPreInc = (Form == UpdateForm || + ((Form == DSForm) && + isa(BasePtrSCEV->getStepRecurrence(*SE)) && + !cast(BasePtrSCEV->getStepRecurrence(*SE)) + ->getAPInt() + .urem(4) && + PreferUpdateForm)); + + std::pair Base = + rewriteForBase(L, BasePtrSCEV, BucketChain.Elements.begin()->Instr, + CanPreInc, Form, SCEVE, DeletedPtrs); + + if (!Base.first || !Base.second) return MadeChange; - } - - LLVM_DEBUG(dbgs() << "PIP: New start is: " << *BasePtrStartSCEV << "\n"); - - BasicBlock *Header = L->getHeader(); - unsigned HeaderLoopPredCount = pred_size(Header); - BasicBlock *LoopPredecessor = L->getLoopPredecessor(); - - PHINode *NewPHI = - PHINode::Create(I8PtrTy, HeaderLoopPredCount, - getInstrName(MemI, PHINodeNameSuffix), - Header->getFirstNonPHI()); - - SCEVExpander SCEVE(*SE, Header->getModule()->getDataLayout(), "pistart"); - Value *BasePtrStart = SCEVE.expandCodeFor(BasePtrStartSCEV, 
I8PtrTy, - LoopPredecessor->getTerminator()); - - // Note that LoopPredecessor might occur in the predecessor list multiple - // times, and we need to add it the right number of times. - for (auto PI : predecessors(Header)) { - if (PI != LoopPredecessor) - continue; - - NewPHI->addIncoming(BasePtrStart, LoopPredecessor); - } - - Instruction *PtrInc = nullptr; - Instruction *NewBasePtr = nullptr; - if (CanPreInc) { - Instruction *InsPoint = &*Header->getFirstInsertionPt(); - PtrInc = GetElementPtrInst::Create(I8Ty, NewPHI, IncNode, - getInstrName(MemI, GEPNodeIncNameSuffix), - InsPoint); - cast(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr)); - for (auto PI : predecessors(Header)) { - if (PI == LoopPredecessor) - continue; - - NewPHI->addIncoming(PtrInc, PI); - } - if (PtrInc->getType() != BasePtr->getType()) - NewBasePtr = new BitCastInst( - PtrInc, BasePtr->getType(), - getInstrName(PtrInc, CastNodeNameSuffix), InsPoint); - else - NewBasePtr = PtrInc; - } else { - // Note that LoopPredecessor might occur in the predecessor list multiple - // times, and we need to make sure no more incoming value for them in PHI. - for (auto PI : predecessors(Header)) { - if (PI == LoopPredecessor) - continue; - - // For the latch predecessor, we need to insert a GEP just before the - // terminator to increase the address. 
- BasicBlock *BB = PI; - Instruction *InsPoint = BB->getTerminator(); - PtrInc = GetElementPtrInst::Create( - I8Ty, NewPHI, IncNode, getInstrName(MemI, GEPNodeIncNameSuffix), - InsPoint); - cast(PtrInc)->setIsInBounds(IsPtrInBounds(BasePtr)); - - NewPHI->addIncoming(PtrInc, PI); - } - PtrInc = NewPHI; - if (NewPHI->getType() != BasePtr->getType()) - NewBasePtr = - new BitCastInst(NewPHI, BasePtr->getType(), - getInstrName(NewPHI, CastNodeNameSuffix), - &*Header->getFirstInsertionPt()); - else - NewBasePtr = NewPHI; - } - - // Clear the rewriter cache, because values that are in the rewriter's cache - // can be deleted below, causing the AssertingVH in the cache to trigger. - SCEVE.clear(); - - if (Instruction *IDel = dyn_cast(BasePtr)) - BBChanged.insert(IDel->getParent()); - BasePtr->replaceAllUsesWith(NewBasePtr); - RecursivelyDeleteTriviallyDeadInstructions(BasePtr); // Keep track of the replacement pointer values we've inserted so that we // don't generate more pointer values than necessary. 
SmallPtrSet NewPtrs; - NewPtrs.insert(NewBasePtr); + NewPtrs.insert(Base.first); for (auto I = std::next(BucketChain.Elements.begin()), IE = BucketChain.Elements.end(); I != IE; ++I) { @@ -650,43 +740,22 @@ bool PPCLoopInstrFormPrep::rewriteLoadStores(Loop *L, Bucket &BucketChain, if (NewPtrs.count(Ptr)) continue; - Instruction *RealNewPtr; - if (!I->Offset || I->Offset->getValue()->isZero()) { - RealNewPtr = NewBasePtr; - } else { - Instruction *PtrIP = dyn_cast(Ptr); - if (PtrIP && isa(NewBasePtr) && - cast(NewBasePtr)->getParent() == PtrIP->getParent()) - PtrIP = nullptr; - else if (PtrIP && isa(PtrIP)) - PtrIP = &*PtrIP->getParent()->getFirstInsertionPt(); - else if (!PtrIP) - PtrIP = I->Instr; - - GetElementPtrInst *NewPtr = GetElementPtrInst::Create( - I8Ty, PtrInc, I->Offset->getValue(), - getInstrName(I->Instr, GEPNodeOffNameSuffix), PtrIP); - if (!PtrIP) - NewPtr->insertAfter(cast(PtrInc)); - NewPtr->setIsInBounds(IsPtrInBounds(Ptr)); - RealNewPtr = NewPtr; - } + Instruction *NewPtr = rewriteForBucketElement( + Base, *I, + I->Offset ? cast(I->Offset)->getValue() : nullptr, + DeletedPtrs); + assert(NewPtr && "wrong rewrite!\n"); + NewPtrs.insert(NewPtr); + } + // Clear the rewriter cache, because values that are in the rewriter's cache + // can be deleted below, causing the AssertingVH in the cache to trigger. + SCEVE.clear(); + + for (auto *Ptr : DeletedPtrs) { if (Instruction *IDel = dyn_cast(Ptr)) BBChanged.insert(IDel->getParent()); - - Instruction *ReplNewPtr; - if (Ptr->getType() != RealNewPtr->getType()) { - ReplNewPtr = new BitCastInst(RealNewPtr, Ptr->getType(), - getInstrName(Ptr, CastNodeNameSuffix)); - ReplNewPtr->insertAfter(RealNewPtr); - } else - ReplNewPtr = RealNewPtr; - - Ptr->replaceAllUsesWith(ReplNewPtr); RecursivelyDeleteTriviallyDeadInstructions(Ptr); - - NewPtrs.insert(RealNewPtr); } MadeChange = true; -- 2.7.4