From 269f563a2bcd88eba3fd4c25669713e64ab06f03 Mon Sep 17 00:00:00 2001 From: Evgeniy Brevnov Date: Tue, 19 Oct 2021 15:26:28 +0700 Subject: [PATCH] [NARY-REASSOCIATE] Fix infinite recursion optimizing min\max To guarantee convergence of the algorithm each optimization step should decrease number of instructions when IR is modified. This property is not held in this test case. The problem is that SCEV Expander may do "unexpected" reassociation what results in creation of new min/max chains and introduction of extra instructions. As a result on each step we indefinitely optimize back and forth. The solution is to restrict SCEV Expander to perform uncontrolled reassociations by means of "Unknown" expressions. Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D112060 --- llvm/lib/Transforms/Scalar/NaryReassociate.cpp | 21 +++++++----- llvm/test/Transforms/NaryReassociate/nary-req.ll | 42 ++++++++++++++++++++++-- 2 files changed, 52 insertions(+), 11 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp index ded5caf..ef73cd0 100644 --- a/llvm/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -598,21 +598,24 @@ Value *NaryReassociatePass::tryReassociateMinOrMax(Instruction *I, MaxMinT m_MaxMin(m_Value(A), m_Value(B)); for (unsigned int i = 0; i < 2; ++i) { if (!LHS->hasNUsesOrMore(3) && match(LHS, m_MaxMin)) { + Value *C = RHS; const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B); - const SCEV *RHSExpr = SE->getSCEV(RHS); + const SCEV *CExpr = SE->getSCEV(C); for (unsigned int j = 0; j < 2; ++j) { if (j == 0) { - if (BExpr == RHSExpr) + if (BExpr == CExpr) continue; - // Transform 'I = (A op B) op RHS' to 'I = (A op RHS) op B' on the + // Transform 'I = (A op B) op C' to 'I = (A op C) op B' on the // first iteration. - std::swap(BExpr, RHSExpr); + std::swap(BExpr, CExpr); + std::swap(B, C); } else { - if (AExpr == RHSExpr) + if (AExpr == CExpr) continue; - // Transform 'I = (A op RHS) op B' 'I = (B op RHS) op A' on the second + // Transform 'I = (A op C) op B' to 'I = (B op C) op A' on the second // iteration. - std::swap(AExpr, RHSExpr); + std::swap(AExpr, CExpr); + std::swap(A, C); } // The optimization is profitable only if LHS can be removed in the end. @@ -635,8 +638,8 @@ Value *NaryReassociatePass::tryReassociateMinOrMax(Instruction *I, LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax << "\n"); - R1Expr = SE->getUnknown(R1MinMax); - SmallVector Ops2{ RHSExpr, R1Expr }; + SmallVector Ops2{SE->getUnknown(C), + SE->getUnknown(R1MinMax)}; const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2); Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I); diff --git a/llvm/test/Transforms/NaryReassociate/nary-req.ll b/llvm/test/Transforms/NaryReassociate/nary-req.ll index 5c4efc8..7b0d6b9 100644 --- a/llvm/test/Transforms/NaryReassociate/nary-req.ll +++ b/llvm/test/Transforms/NaryReassociate/nary-req.ll @@ -3,7 +3,7 @@ ; RUN: opt < %s -passes='nary-reassociate' -S | FileCheck %s declare i32 @llvm.smax.i32(i32 %a, i32 %b) -declare i64 @llvm.umin.i64(i64, i64) +declare i64 @llvm.umin.i64(i64, i64) ; This is a negative test. We should not optimize if intermediate result ; has a use outside of optimizable pattern. In other words %smax2 has one @@ -46,7 +46,8 @@ define void @test2(i64 %arg) { ; CHECK-NEXT: [[E4:%.*]] = sub i64 [[ARG]], 0 ; CHECK-NEXT: [[E5:%.*]] = call i64 @llvm.umin.i64(i64 [[E4]], i64 16384) ; CHECK-NEXT: [[E6:%.*]] = icmp ugt i64 [[E5]], 0 -; CHECK-NEXT: [[E10_NARY:%.*]] = call i64 @llvm.umin.i64(i64 [[E5]], i64 [[E]]) +; CHECK-NEXT: [[E7:%.*]] = sub i64 undef, 0 +; CHECK-NEXT: [[E10_NARY:%.*]] = call i64 @llvm.umin.i64(i64 [[E5]], i64 [[E7]]) ; CHECK-NEXT: unreachable ; bb: @@ -64,3 +65,40 @@ bb: unreachable } +; Make sure we don't fall into infinte loop optimizing %sel5. +; The subtle thing is that %sel3 is min/max as well and +; there is "unexpected" reassociation coming from SCEV Expander +; during %sel5 rewrite. That results in a new chain of min/max +; which is matched on the next iteration. +define i32 @nary_infinite_loop_minmax(i32 %d0, i32 %d1, i32 %d2, i32 %d3) { +; CHECK-LABEL: @nary_infinite_loop_minmax( +; CHECK-NEXT: [[CMP0:%.*]] = icmp slt i32 [[D2:%.*]], [[D1:%.*]] +; CHECK-NEXT: [[SEL0:%.*]] = select i1 [[CMP0]], i32 [[D1]], i32 [[D2]] +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 [[D3:%.*]], [[D0:%.*]] +; CHECK-NEXT: [[SEL1:%.*]] = select i1 [[CMP1]], i32 [[D0]], i32 [[D3]] +; CHECK-NEXT: [[CMP2:%.*]] = icmp slt i32 [[SEL1]], [[SEL0]] +; CHECK-NEXT: [[SEL2:%.*]] = select i1 [[CMP2]], i32 [[SEL1]], i32 [[SEL0]] +; CHECK-NEXT: [[CMP3:%.*]] = icmp slt i32 [[D3]], [[D0]] +; CHECK-NEXT: [[SEL3:%.*]] = select i1 [[CMP3]], i32 [[D0]], i32 [[D3]] +; CHECK-NEXT: [[SEL5_NARY:%.*]] = call i32 @llvm.smax.i32(i32 [[SEL0]], i32 [[SEL3]]) +; CHECK-NEXT: ret i32 [[SEL5_NARY]] +; + %cmp0 = icmp slt i32 %d2, %d1 + %sel0 = select i1 %cmp0, i32 %d1, i32 %d2 + + %cmp1 = icmp slt i32 %d3, %d0 + %sel1 = select i1 %cmp1, i32 %d0, i32 %d3 + + %cmp2 = icmp slt i32 %sel1, %sel0 + %sel2 = select i1 %cmp2, i32 %sel1, i32 %sel0 + + %cmp3 = icmp slt i32 %d3, %d0 + %sel3 = select i1 %cmp3, i32 %d0, i32 %d3 + + %cmp4 = icmp slt i32 %sel3, %d2 + %sel4 = select i1 %cmp4, i32 %d2, i32 %sel3 + + %cmp5 = icmp slt i32 %sel4, %d1 + %sel5 = select i1 %cmp5, i32 %d1, i32 %sel4 + ret i32 %sel5 +} -- 2.7.4