From: Sanjay Patel Date: Mon, 3 Feb 2020 13:55:43 +0000 (-0500) Subject: [InstCombine] reassociate splatted vector ops X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=e78fb556c5520161fb5943b665da3ca98f3ae53d;p=platform%2Fupstream%2Fllvm.git [InstCombine] reassociate splatted vector ops bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp This patch depends on the splat analysis enhancement in D73549. See the test with comment: ; Negative test - mismatched splat elements ...as the motivation for that first patch. The motivating case for reassociating splatted ops is shown in PR42174: https://bugs.llvm.org/show_bug.cgi?id=42174 In that example, a slight change in order-of-associative math results in a big difference in IR and codegen. This patch gets all of the unnecessary shuffles out of the way, but doesn't address the potential scalarization (see D50992 or D73480 for that). Differential Revision: https://reviews.llvm.org/D73703 --- diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index bded1d7..f11fa27 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -60,6 +60,7 @@ #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" @@ -1683,6 +1684,54 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { } } + // Try to reassociate to sink a splat shuffle after a binary operation. + if (Inst.isAssociative() && Inst.isCommutative()) { + // Canonicalize shuffle operand as LHS. 
+ if (auto *ShufR = dyn_cast<ShuffleVectorInst>(RHS)) + std::swap(LHS, RHS); + + Value *X; + Constant *MaskC; + const APInt *SplatIndex; + BinaryOperator *BO; + if (!match(LHS, m_OneUse(m_ShuffleVector(m_Value(X), m_Undef(), + m_Constant(MaskC)))) || + !match(MaskC, m_APIntAllowUndef(SplatIndex)) || + X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(BO))) || + BO->getOpcode() != Opcode) + return nullptr; + + Value *Y, *OtherOp; + if (isSplatValue(BO->getOperand(0), SplatIndex->getZExtValue())) { + Y = BO->getOperand(0); + OtherOp = BO->getOperand(1); + } else if (isSplatValue(BO->getOperand(1), SplatIndex->getZExtValue())) { + Y = BO->getOperand(1); + OtherOp = BO->getOperand(0); + } else { + return nullptr; + } + + // X and Y are splatted values, so perform the binary operation on those + // values followed by a splat followed by the 2nd binary operation: + // bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp + Value *NewBO = Builder.CreateBinOp(Opcode, X, Y); + UndefValue *Undef = UndefValue::get(Inst.getType()); + Constant *NewMask = ConstantInt::get(MaskC->getType(), *SplatIndex); + Value *NewSplat = Builder.CreateShuffleVector(NewBO, Undef, NewMask); + Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp); + + // Intersect FMF on both new binops. Other (poison-generating) flags are + // dropped to be safe. 
+ if (isa<FPMathOperator>(R)) { + R->copyFastMathFlags(&Inst); + R->andIRFlags(BO); + } + if (auto *NewInstBO = dyn_cast<BinaryOperator>(NewBO)) + NewInstBO->copyIRFlags(R); + return R; + } + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle.ll b/llvm/test/Transforms/InstCombine/vec_shuffle.ll index abbf1df..e806fac 100644 --- a/llvm/test/Transforms/InstCombine/vec_shuffle.ll +++ b/llvm/test/Transforms/InstCombine/vec_shuffle.ll @@ -1457,9 +1457,9 @@ define <4 x float> @insert_subvector_crash_invalid_mask_elt(<2 x float> %x, <4 x define <4 x i32> @splat_assoc_add(<4 x i32> %x, <4 x i32> %y) { ; CHECK-LABEL: @splat_assoc_add( -; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> undef, <4 x i32> zeroinitializer -; CHECK-NEXT: [[A:%.*]] = add <4 x i32> [[Y:%.*]], -; CHECK-NEXT: [[R:%.*]] = add <4 x i32> [[SPLATX]], [[A]] +; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <4 x i32> zeroinitializer +; CHECK-NEXT: [[R:%.*]] = add <4 x i32> [[TMP2]], [[Y:%.*]] ; CHECK-NEXT: ret <4 x i32> [[R]] ; %splatx = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> zeroinitializer @@ -1468,11 +1468,13 @@ define <4 x i32> @splat_assoc_add(<4 x i32> %x, <4 x i32> %y) { ret <4 x i32> %r } +; Non-zero splat index; commute operands; FMF intersect + define <2 x float> @splat_assoc_fmul(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fmul( -; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <2 x float> [[X:%.*]], <2 x float> undef, <2 x i32> -; CHECK-NEXT: [[A:%.*]] = fmul reassoc nsz <2 x float> [[Y:%.*]], -; CHECK-NEXT: [[R:%.*]] = fmul reassoc nnan nsz <2 x float> [[A]], [[SPLATX]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nsz <2 x float> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x float> [[TMP1]], <2 x float> undef, <2 x i32> +; CHECK-NEXT: [[R:%.*]] = fmul reassoc nsz <2 x float> [[TMP2]], [[Y:%.*]] ; CHECK-NEXT: ret <2 x float> [[R]] ; %splatx = 
shufflevector <2 x float> %x, <2 x float> undef, <2 x i32> @@ -1481,12 +1483,13 @@ define <2 x float> @splat_assoc_fmul(<2 x float> %x, <2 x float> %y) { ret <2 x float> %r } +; Two splat shuffles; drop poison-generating flags + define <3 x i8> @splat_assoc_mul(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ; CHECK-LABEL: @splat_assoc_mul( -; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <3 x i8> [[X:%.*]], <3 x i8> undef, <3 x i32> -; CHECK-NEXT: [[SPLATZ:%.*]] = shufflevector <3 x i8> [[Z:%.*]], <3 x i8> undef, <3 x i32> -; CHECK-NEXT: [[A:%.*]] = mul nsw <3 x i8> [[SPLATZ]], [[Y:%.*]] -; CHECK-NEXT: [[R:%.*]] = mul <3 x i8> [[A]], [[SPLATX]] +; CHECK-NEXT: [[TMP1:%.*]] = mul <3 x i8> [[X:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <3 x i8> [[TMP1]], <3 x i8> undef, <3 x i32> +; CHECK-NEXT: [[R:%.*]] = mul <3 x i8> [[TMP2]], [[Y:%.*]] ; CHECK-NEXT: ret <3 x i8> [[R]] ; %splatx = shufflevector <3 x i8> %x, <3 x i8> undef, <3 x i32> @@ -1496,7 +1499,7 @@ define <3 x i8> @splat_assoc_mul(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ret <3 x i8> %r } -; Mismatched splat elements +; Negative test - mismatched splat elements define <3 x i8> @splat_assoc_or(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ; CHECK-LABEL: @splat_assoc_or( @@ -1513,7 +1516,7 @@ define <3 x i8> @splat_assoc_or(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ret <3 x i8> %r } -; Not associative +; Negative test - not associative define <2 x float> @splat_assoc_fdiv(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fdiv( @@ -1528,7 +1531,7 @@ define <2 x float> @splat_assoc_fdiv(<2 x float> %x, <2 x float> %y) { ret <2 x float> %r } -; Extra use +; Negative test - extra use define <2 x float> @splat_assoc_fadd(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fadd( @@ -1545,7 +1548,7 @@ define <2 x float> @splat_assoc_fadd(<2 x float> %x, <2 x float> %y) { ret <2 x float> %r } -; Narrowing splat +; Negative test - narrowing splat define <3 x i32> @splat_assoc_and(<4 x i32> %x, <3 
x i32> %y) { ; CHECK-LABEL: @splat_assoc_and( @@ -1560,7 +1563,7 @@ define <3 x i32> @splat_assoc_and(<4 x i32> %x, <3 x i32> %y) { ret <3 x i32> %r } -; Widening splat +; Negative test - widening splat define <5 x i32> @splat_assoc_xor(<4 x i32> %x, <5 x i32> %y) { ; CHECK-LABEL: @splat_assoc_xor( @@ -1575,7 +1578,7 @@ define <5 x i32> @splat_assoc_xor(<4 x i32> %x, <5 x i32> %y) { ret <5 x i32> %r } -; Opcode mismatch +; Negative test - opcode mismatch define <4 x i32> @splat_assoc_add_mul(<4 x i32> %x, <4 x i32> %y) { ; CHECK-LABEL: @splat_assoc_add_mul( diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll index 6bcf03f..e093ee4 100644 --- a/llvm/test/Transforms/LoopVectorize/induction.ll +++ b/llvm/test/Transforms/LoopVectorize/induction.ll @@ -427,7 +427,7 @@ for.end: ; UNROLL: %[[i1:.+]] = or i64 %index, 1 ; UNROLL: %[[i2:.+]] = or i64 %index, 2 ; UNROLL: %[[i3:.+]] = or i64 %index, 3 -; UNROLL: %step.add3 = add <2 x i32> %vec.ind2, +; UNROLL: %[[add:.+]]= add <2 x i32> %[[splat:.+]], ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %index, i32 1 ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i1]], i32 1 ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i2]], i32 1