From 228cc2c38bfb9d703c33561dd88ce1b9e16601ce Mon Sep 17 00:00:00 2001 From: Andrew Litteken Date: Sun, 13 Mar 2022 23:44:46 -0500 Subject: [PATCH] [IROutliner] Ensure merged PHINodes respect order and incoming blocks, not just incoming values When matching PHINodes when margining functions the IROutliner only checks that an incoming value exists in phi node in overall function. It doesn't check the length, the order, or that the incoming block also matches. In the given example, we see that both phi nodes have the same incoming values, but from different blocks. The fix is to to enforce stricter a match of the incoming value, and the incoming block as well when matching the created phi nodes. Reviewers: paquette Differential Revision: https://reviews.llvm.org/D121310 --- llvm/include/llvm/Transforms/IPO/IROutliner.h | 9 ++ llvm/lib/Transforms/IPO/IROutliner.cpp | 87 ++++++++++++---- .../IROutliner/different-order-phi-merges.ll | 115 +++++++++++++++++++++ 3 files changed, 193 insertions(+), 18 deletions(-) create mode 100644 llvm/test/Transforms/IROutliner/different-order-phi-merges.ll diff --git a/llvm/include/llvm/Transforms/IPO/IROutliner.h b/llvm/include/llvm/Transforms/IPO/IROutliner.h index 283c98c..07e0059 100644 --- a/llvm/include/llvm/Transforms/IPO/IROutliner.h +++ b/llvm/include/llvm/Transforms/IPO/IROutliner.h @@ -169,6 +169,15 @@ struct OutlinableRegion { /// \return The corresponding Value to \p V if it exists, otherwise nullptr. Value *findCorrespondingValueIn(const OutlinableRegion &Other, Value *V); + /// Find a corresponding BasicBlock for \p BB in similar OutlinableRegion \p Other. + /// + /// \param Other [in] - The OutlinableRegion to find the corresponding + /// BasicBlock in. + /// \param BB [in] - The BasicBlock to look for in the other region. + /// \return The corresponding Value to \p V if it exists, otherwise nullptr. + BasicBlock *findCorrespondingBlockIn(const OutlinableRegion &Other, + BasicBlock *BB); + /// Get the size of the code removed from the region. /// /// \param [in] TTI - The TargetTransformInfo for the parent function. diff --git a/llvm/lib/Transforms/IPO/IROutliner.cpp b/llvm/lib/Transforms/IPO/IROutliner.cpp index 9028573..a7b77ee 100644 --- a/llvm/lib/Transforms/IPO/IROutliner.cpp +++ b/llvm/lib/Transforms/IPO/IROutliner.cpp @@ -190,6 +190,19 @@ Value *OutlinableRegion::findCorrespondingValueIn(const OutlinableRegion &Other, return FoundValueOpt.getValueOr(nullptr); } +BasicBlock * +OutlinableRegion::findCorrespondingBlockIn(const OutlinableRegion &Other, + BasicBlock *BB) { + Instruction *FirstNonPHI = BB->getFirstNonPHI(); + assert(FirstNonPHI && "block is empty?"); + Value *CorrespondingVal = findCorrespondingValueIn(Other, FirstNonPHI); + if (!CorrespondingVal) + return nullptr; + BasicBlock *CorrespondingBlock = + cast(CorrespondingVal)->getParent(); + return CorrespondingBlock; +} + /// Rewrite the BranchInsts in the incoming blocks to \p PHIBlock that are found /// in \p Included to branch to BasicBlock \p Replace if they currently branch /// to the BasicBlock \p Find. This is used to fix up the incoming basic blocks @@ -1530,17 +1543,18 @@ getPassedArgumentAndAdjustArgumentLocation(const Argument *A, /// \param OutputMappings [in] - The mapping of output values from outlined /// region to their original values. /// \param CanonNums [out] - The canonical numbering for the incoming values to -/// \p PN. +/// \p PN paired with their incoming block. /// \param ReplacedWithOutlinedCall - A flag to use the extracted function call /// of \p Region rather than the overall function's call. -static void -findCanonNumsForPHI(PHINode *PN, OutlinableRegion &Region, - const DenseMap &OutputMappings, - DenseSet &CanonNums, - bool ReplacedWithOutlinedCall = true) { +static void findCanonNumsForPHI( + PHINode *PN, OutlinableRegion &Region, + const DenseMap &OutputMappings, + SmallVector> &CanonNums, + bool ReplacedWithOutlinedCall = true) { // Iterate over the incoming values. for (unsigned Idx = 0, EIdx = PN->getNumIncomingValues(); Idx < EIdx; Idx++) { Value *IVal = PN->getIncomingValue(Idx); + BasicBlock *IBlock = PN->getIncomingBlock(Idx); // If we have an argument as incoming value, we need to grab the passed // value from the call itself. if (Argument *A = dyn_cast(IVal)) { @@ -1558,7 +1572,7 @@ findCanonNumsForPHI(PHINode *PN, OutlinableRegion &Region, assert(GVN.hasValue() && "No GVN for incoming value"); Optional CanonNum = Region.Candidate->getCanonicalNum(*GVN); assert(CanonNum.hasValue() && "No Canonical Number for GVN"); - CanonNums.insert(*CanonNum); + CanonNums.push_back(std::make_pair(*CanonNum, IBlock)); } } @@ -1582,7 +1596,11 @@ findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, DenseSet &UsedPHIs) { OutlinableGroup &Group = *Region.Parent; - DenseSet PNCanonNums; + + // A list of the canonical numbering assigned to each incoming value, paired + // with the incoming block for the PHINode passed into this function. + SmallVector> PNCanonNums; + // We have to use the extracted function since we have merged this region into // the overall function yet. We make sure to reassign the argument numbering // since it is possible that the argument ordering is different between the @@ -1591,7 +1609,12 @@ findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, /* ReplacedWithOutlinedCall = */ false); OutlinableRegion *FirstRegion = Group.Regions[0]; - DenseSet CurrentCanonNums; + + // A list of the canonical numbering assigned to each incoming value, paired + // with the incoming block for the PHINode that we are currently comparing + // the passed PHINode to. + SmallVector> CurrentCanonNums; + // Find the Canonical Numbering for each PHINode, if it matches, we replace // the uses of the PHINode we are searching for, with the found PHINode. for (PHINode &CurrPN : OverallPhiBlock->phis()) { @@ -1603,9 +1626,41 @@ findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, CurrentCanonNums.clear(); findCanonNumsForPHI(&CurrPN, *FirstRegion, OutputMappings, CurrentCanonNums, /* ReplacedWithOutlinedCall = */ true); - if (all_of(PNCanonNums, [&CurrentCanonNums](unsigned CanonNum) { - return CurrentCanonNums.contains(CanonNum); - })) { + + // If the list of incoming values is not the same length, then they cannot + // match since there is not an analogue for each incoming value. + if (PNCanonNums.size() != CurrentCanonNums.size()) + continue; + + bool FoundMatch = true; + + // We compare the canonical value for each incoming value in the passed + // in PHINode to one already present in the outlined region. If the + // incoming values do not match, then the PHINodes do not match. + + // We also check to make sure that the incoming block matches as well by + // finding the corresponding incoming block in the combined outlined region + // for the current outlined region. + for (unsigned Idx = 0, Edx = PNCanonNums.size(); Idx < Edx; ++Idx) { + std::pair ToCompareTo = CurrentCanonNums[Idx]; + std::pair ToAdd = PNCanonNums[Idx]; + if (ToCompareTo.first != ToAdd.first) { + FoundMatch = false; + break; + } + + BasicBlock *CorrespondingBlock = + Region.findCorrespondingBlockIn(*FirstRegion, ToAdd.second); + assert(CorrespondingBlock && "Found block is nullptr"); + if (CorrespondingBlock != ToCompareTo.second) { + FoundMatch = false; + break; + } + } + + // If all incoming values and branches matched, then we can merge + // into the found PHINode. + if (FoundMatch) { UsedPHIs.insert(&CurrPN); return &CurrPN; } @@ -1622,12 +1677,8 @@ findOrCreatePHIInBlock(PHINode &PN, OutlinableRegion &Region, // Find corresponding basic block in the overall function for the incoming // block. - Instruction *FirstNonPHI = IncomingBlock->getFirstNonPHI(); - assert(FirstNonPHI && "Incoming block is empty?"); - Value *CorrespondingVal = - Region.findCorrespondingValueIn(*FirstRegion, FirstNonPHI); - assert(CorrespondingVal && "Value is nullptr?"); - BasicBlock *BlockToUse = cast(CorrespondingVal)->getParent(); + BasicBlock *BlockToUse = + Region.findCorrespondingBlockIn(*FirstRegion, IncomingBlock); NewPN->setIncomingBlock(Idx, BlockToUse); // If we have an argument we make sure we replace using the argument from diff --git a/llvm/test/Transforms/IROutliner/different-order-phi-merges.ll b/llvm/test/Transforms/IROutliner/different-order-phi-merges.ll new file mode 100644 index 0000000..7539f83 --- /dev/null +++ b/llvm/test/Transforms/IROutliner/different-order-phi-merges.ll @@ -0,0 +1,115 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs +; RUN: opt -S -verify -iroutliner -ir-outlining-no-cost < %s | FileCheck %s + +; Check that differently ordered phi nodes are not matched when merged, instead +; generating two output paths. + +define void @f1() { +bb1: + %0 = add i32 1, 2 + %1 = add i32 3, 4 + %2 = add i32 5, 6 + %3 = add i32 7, 8 + br i1 true, label %bb2, label %bb5 +bb2: + %4 = mul i32 5, 4 + br label %bb5 + +placeholder: + %a = sub i32 5, 4 + ret void + +bb5: + %phinode = phi i32 [%3, %bb1], [%2, %bb2] + ret void +} + +define void @f2() { +bb1: + %0 = add i32 1, 2 + %1 = add i32 3, 4 + %2 = add i32 5, 6 + %3 = add i32 7, 8 + br i1 true, label %bb2, label %bb5 +bb2: + %4 = mul i32 5, 4 + br label %bb5 + +placeholder: + %a = sub i32 5, 4 + ret void + +bb5: + %phinode = phi i32 [%2, %bb1], [%3, %bb2] + ret void +} +; CHECK-LABEL: @f1( +; CHECK-NEXT: bb1: +; CHECK-NEXT: [[PHINODE_CE_LOC:%.*]] = alloca i32, align 4 +; CHECK-NEXT: [[LT_CAST:%.*]] = bitcast i32* [[PHINODE_CE_LOC]] to i8* +; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 -1, i8* [[LT_CAST]]) +; CHECK-NEXT: [[TMP0:%.*]] = call i1 @outlined_ir_func_0(i32* [[PHINODE_CE_LOC]], i32 0) +; CHECK-NEXT: [[PHINODE_CE_RELOAD:%.*]] = load i32, i32* [[PHINODE_CE_LOC]], align 4 +; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 -1, i8* [[LT_CAST]]) +; CHECK-NEXT: br i1 [[TMP0]], label [[BB5:%.*]], label [[BB1_AFTER_OUTLINE:%.*]] +; CHECK: bb1_after_outline: +; CHECK-NEXT: ret void +; CHECK: bb5: +; CHECK-NEXT: [[PHINODE:%.*]] = phi i32 [ [[PHINODE_CE_RELOAD]], [[BB1:%.*]] ] +; CHECK-NEXT: ret void +; +; +; CHECK-LABEL: @f2( +; CHECK-NEXT: bb1: +; CHECK-NEXT: [[PHINODE_CE_LOC:%.*]] = alloca i32, align 4 +; CHECK-NEXT: [[LT_CAST:%.*]] = bitcast i32* [[PHINODE_CE_LOC]] to i8* +; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 -1, i8* [[LT_CAST]]) +; CHECK-NEXT: [[TMP0:%.*]] = call i1 @outlined_ir_func_0(i32* [[PHINODE_CE_LOC]], i32 1) +; CHECK-NEXT: [[PHINODE_CE_RELOAD:%.*]] = load i32, i32* [[PHINODE_CE_LOC]], align 4 +; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 -1, i8* [[LT_CAST]]) +; CHECK-NEXT: br i1 [[TMP0]], label [[BB5:%.*]], label [[BB1_AFTER_OUTLINE:%.*]] +; CHECK: bb1_after_outline: +; CHECK-NEXT: ret void +; CHECK: bb5: +; CHECK-NEXT: [[PHINODE:%.*]] = phi i32 [ [[PHINODE_CE_RELOAD]], [[BB1:%.*]] ] +; CHECK-NEXT: ret void +; +; +; CHECK-LABEL: define internal i1 @outlined_ir_func_0( +; CHECK-NEXT: newFuncRoot: +; CHECK-NEXT: br label [[BB1_TO_OUTLINE:%.*]] +; CHECK: bb1_to_outline: +; CHECK-NEXT: [[TMP2:%.*]] = add i32 1, 2 +; CHECK-NEXT: [[TMP3:%.*]] = add i32 3, 4 +; CHECK-NEXT: [[TMP4:%.*]] = add i32 5, 6 +; CHECK-NEXT: [[TMP5:%.*]] = add i32 7, 8 +; CHECK-NEXT: br i1 true, label [[BB2:%.*]], label [[BB5_SPLIT:%.*]] +; CHECK: bb2: +; CHECK-NEXT: [[TMP6:%.*]] = mul i32 5, 4 +; CHECK-NEXT: br label [[BB5_SPLIT]] +; CHECK: placeholder: +; CHECK-NEXT: [[A:%.*]] = sub i32 5, 4 +; CHECK-NEXT: br label [[BB1_AFTER_OUTLINE_EXITSTUB:%.*]] +; CHECK: bb5.split: +; CHECK-NEXT: [[TMP7:%.*]] = phi i32 [ [[TMP4]], [[BB1_TO_OUTLINE]] ], [ [[TMP5]], [[BB2]] ] +; CHECK-NEXT: [[PHINODE_CE:%.*]] = phi i32 [ [[TMP5]], [[BB1_TO_OUTLINE]] ], [ [[TMP4]], [[BB2]] ] +; CHECK-NEXT: br label [[BB5_EXITSTUB:%.*]] +; CHECK: bb5.exitStub: +; CHECK-NEXT: switch i32 [[TMP1:%.*]], label [[FINAL_BLOCK_1:%.*]] [ +; CHECK-NEXT: i32 0, label [[OUTPUT_BLOCK_0_1:%.*]] +; CHECK-NEXT: i32 1, label [[OUTPUT_BLOCK_1_1:%.*]] +; CHECK-NEXT: ] +; CHECK: bb1_after_outline.exitStub: +; CHECK-NEXT: switch i32 [[TMP1]], label [[FINAL_BLOCK_0:%.*]] [ +; CHECK-NEXT: ] +; CHECK: output_block_0_1: +; CHECK-NEXT: store i32 [[PHINODE_CE]], i32* [[TMP0:%.*]], align 4 +; CHECK-NEXT: br label [[FINAL_BLOCK_1]] +; CHECK: output_block_1_1: +; CHECK-NEXT: store i32 [[TMP7]], i32* [[TMP0]], align 4 +; CHECK-NEXT: br label [[FINAL_BLOCK_1]] +; CHECK: final_block_0: +; CHECK-NEXT: ret i1 false +; CHECK: final_block_1: +; CHECK-NEXT: ret i1 true +; -- 2.7.4