From 58eac856ccc0d7d0f6760b07c6be1537d970cda3 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 14 Feb 2023 18:36:07 -0500 Subject: [PATCH] [LICM] Ensure LICM can hoist invariant.group Invariant.group's are not sufficiently handled by LICM. Specifically, if a given invariant.group loaded pointer is not overwritten between the start of a loop, and its use in the load, it can be hoisted. The invariant.group (on an already invariant pointer operand) ensures the result is the same. If it is not overwritten between the start of the loop and the load, it is therefore legal to hoist. Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D144053 --- llvm/lib/Transforms/Scalar/LICM.cpp | 26 +++++++++-- llvm/test/Transforms/LICM/invariant.group.ll | 69 ++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 5 deletions(-) create mode 100644 llvm/test/Transforms/LICM/invariant.group.ll diff --git a/llvm/lib/Transforms/Scalar/LICM.cpp b/llvm/lib/Transforms/Scalar/LICM.cpp index be7a625..8a301de 100644 --- a/llvm/lib/Transforms/Scalar/LICM.cpp +++ b/llvm/lib/Transforms/Scalar/LICM.cpp @@ -163,7 +163,8 @@ static bool isSafeToExecuteUnconditionally( AssumptionCache *AC, bool AllowSpeculation); static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, Loop *CurLoop, Instruction &I, - SinkAndHoistLICMFlags &Flags); + SinkAndHoistLICMFlags &Flags, + bool InvariantGroup); static bool pointerInvalidatedByBlock(BasicBlock &BB, MemorySSA &MSSA, MemoryUse &MU); static Instruction *cloneInstructionInExitBlock( @@ -1176,8 +1177,12 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (isLoadInvariantInLoop(LI, DT, CurLoop)) return true; + auto MU = cast(MSSA->getMemoryAccess(LI)); + + bool InvariantGroup = LI->hasMetadata(LLVMContext::MD_invariant_group); + bool Invalidated = pointerInvalidatedByLoop( - MSSA, cast(MSSA->getMemoryAccess(LI)), CurLoop, I, Flags); + MSSA, MU, CurLoop, I, Flags, InvariantGroup); // Check loop-invariant address because this may also be a sinkable load // whose address is not necessarily loop-invariant. if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand())) @@ -1228,7 +1233,7 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (Op->getType()->isPointerTy() && pointerInvalidatedByLoop( MSSA, cast(MSSA->getMemoryAccess(CI)), CurLoop, I, - Flags)) + Flags, /*InvariantGroup=*/false)) return false; return true; } @@ -2330,7 +2335,8 @@ collectPromotionCandidates(MemorySSA *MSSA, AliasAnalysis *AA, Loop *L) { static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, Loop *CurLoop, Instruction &I, - SinkAndHoistLICMFlags &Flags) { + SinkAndHoistLICMFlags &Flags, + bool InvariantGroup) { // For hoisting, use the walker to determine safety if (!Flags.getIsSink()) { MemoryAccess *Source; @@ -2341,8 +2347,18 @@ static bool pointerInvalidatedByLoop(MemorySSA *MSSA, MemoryUse *MU, Source = MSSA->getSkipSelfWalker()->getClobberingMemoryAccess(MU); Flags.incrementClobberingCalls(); } + // If hoisting an invariant group, we only need to check that there + // is no store to the loaded pointer between the start of the loop, + // and the load (since all values must be the same). + + // This can be checked in two conditions: + // 1) if the memoryaccess is outside the loop + // 2) the earliest access is at the loop header, + // if the memory loaded is the phi node + return !MSSA->isLiveOnEntryDef(Source) && - CurLoop->contains(Source->getBlock()); + CurLoop->contains(Source->getBlock()) && + !(InvariantGroup && Source->getBlock() == CurLoop->getHeader() && isa(Source)); } // For sinking, we'd need to check all Defs below this use. The getClobbering diff --git a/llvm/test/Transforms/LICM/invariant.group.ll b/llvm/test/Transforms/LICM/invariant.group.ll new file mode 100644 index 0000000..c2a2d21 --- /dev/null +++ b/llvm/test/Transforms/LICM/invariant.group.ll @@ -0,0 +1,69 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -passes=licm < %s -S | FileCheck %s + +define void @test(ptr %arg, ptr %arg1) { +; CHECK-LABEL: @test( +; CHECK-NEXT: bb2: +; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[ARG1:%.*]], align 4, !invariant.group !0 +; CHECK-NEXT: br label [[BB5:%.*]] +; CHECK: bb5: +; CHECK-NEXT: [[TMP6:%.*]] = phi i64 [ 0, [[BB2:%.*]] ], [ [[TMP10:%.*]], [[BB5]] ] +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[ARG:%.*]], i64 [[TMP6]] +; CHECK-NEXT: store i32 [[TMP3]], ptr [[TMP7]], align 8 +; CHECK-NEXT: [[TMP10]] = add nuw nsw i64 [[TMP6]], 1 +; CHECK-NEXT: [[TMP11:%.*]] = icmp eq i64 [[TMP10]], 200 +; CHECK-NEXT: br i1 [[TMP11]], label [[BB12:%.*]], label [[BB5]] +; CHECK: bb12: +; CHECK-NEXT: ret void +; +bb2: ; preds = %bb + br label %bb5 + +bb5: ; preds = %bb5, %bb2 + %tmp6 = phi i64 [ 0, %bb2 ], [ %tmp10, %bb5 ] + %tmp3 = load i32, ptr %arg1, align 4, !invariant.group !0 + %tmp7 = getelementptr inbounds i32, ptr %arg, i64 %tmp6 + store i32 %tmp3, ptr %tmp7, align 8 + %tmp10 = add nuw nsw i64 %tmp6, 1 + %tmp11 = icmp eq i64 %tmp10, 200 + br i1 %tmp11, label %bb12, label %bb5 + +bb12: ; preds = %bb5, %bb + ret void +} + + +define void @test_fail(ptr %arg, ptr %arg1) { +; CHECK-LABEL: @test_fail( +; CHECK-NEXT: bb2: +; CHECK-NEXT: br label [[BB5:%.*]] +; CHECK: bb5: +; CHECK-NEXT: [[TMP6:%.*]] = phi i64 [ 0, [[BB2:%.*]] ], [ [[TMP10:%.*]], [[BB5]] ] +; CHECK-NEXT: store i32 3, ptr [[ARG1:%.*]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = load i32, ptr [[ARG1]], align 4, !invariant.group !0 +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds i32, ptr [[ARG:%.*]], i64 [[TMP6]] +; CHECK-NEXT: store i32 [[TMP3]], ptr [[TMP7]], align 8 +; CHECK-NEXT: [[TMP10]] = add nuw nsw i64 [[TMP6]], 1 +; CHECK-NEXT: [[TMP11:%.*]] = icmp eq i64 [[TMP10]], 200 +; CHECK-NEXT: br i1 [[TMP11]], label [[BB12:%.*]], label [[BB5]] +; CHECK: bb12: +; CHECK-NEXT: ret void +; +bb2: ; preds = %bb + br label %bb5 + +bb5: ; preds = %bb5, %bb2 + %tmp6 = phi i64 [ 0, %bb2 ], [ %tmp10, %bb5 ] + store i32 3, ptr %arg1 + %tmp3 = load i32, ptr %arg1, align 4, !invariant.group !0 + %tmp7 = getelementptr inbounds i32, ptr %arg, i64 %tmp6 + store i32 %tmp3, ptr %tmp7, align 8 + %tmp10 = add nuw nsw i64 %tmp6, 1 + %tmp11 = icmp eq i64 %tmp10, 200 + br i1 %tmp11, label %bb12, label %bb5 + +bb12: ; preds = %bb5, %bb + ret void +} + +!0 = !{} -- 2.7.4