From e78fb556c5520161fb5943b665da3ca98f3ae53d Mon Sep 17 00:00:00 2001 From: Sanjay Patel Date: Mon, 3 Feb 2020 08:55:43 -0500 Subject: [PATCH] [InstCombine] reassociate splatted vector ops bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp This patch depends on the splat analysis enhancement in D73549. See the test with comment: ; Negative test - mismatched splat elements ...as the motivation for that first patch. The motivating case for reassociating splatted ops is shown in PR42174: https://bugs.llvm.org/show_bug.cgi?id=42174 In that example, a slight change in order-of-associative math results in a big difference in IR and codegen. This patch gets all of the unnecessary shuffles out of the way, but doesn't address the potential scalarization (see D50992 or D73480 for that). Differential Revision: https://reviews.llvm.org/D73703 --- .../InstCombine/InstructionCombining.cpp | 49 ++++++++++++++++++++++ llvm/test/Transforms/InstCombine/vec_shuffle.ll | 35 +++++++++------- llvm/test/Transforms/LoopVectorize/induction.ll | 2 +- 3 files changed, 69 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index bded1d7..f11fa27 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -60,6 +60,7 @@ #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" @@ -1683,6 +1684,54 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { } } + // Try to reassociate to sink a splat shuffle after a binary operation. + if (Inst.isAssociative() && Inst.isCommutative()) { + // Canonicalize shuffle operand as LHS. 
+    if (auto *ShufR = dyn_cast<ShuffleVectorInst>(RHS))
+      std::swap(LHS, RHS);
+
+    Value *X;
+    Constant *MaskC;
+    const APInt *SplatIndex;
+    BinaryOperator *BO;
+    if (!match(LHS, m_OneUse(m_ShuffleVector(m_Value(X), m_Undef(),
+                                             m_Constant(MaskC)))) ||
+        !match(MaskC, m_APIntAllowUndef(SplatIndex)) ||
+        X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(BO))) ||
+        BO->getOpcode() != Opcode)
+      return nullptr;
+
+    Value *Y, *OtherOp;
+    if (isSplatValue(BO->getOperand(0), SplatIndex->getZExtValue())) {
+      Y = BO->getOperand(0);
+      OtherOp = BO->getOperand(1);
+    } else if (isSplatValue(BO->getOperand(1), SplatIndex->getZExtValue())) {
+      Y = BO->getOperand(1);
+      OtherOp = BO->getOperand(0);
+    } else {
+      return nullptr;
+    }
+
+    // X and Y are splatted values, so perform the binary operation on those
+    // values followed by a splat followed by the 2nd binary operation:
+    // bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp
+    Value *NewBO = Builder.CreateBinOp(Opcode, X, Y);
+    UndefValue *Undef = UndefValue::get(Inst.getType());
+    Constant *NewMask = ConstantInt::get(MaskC->getType(), *SplatIndex);
+    Value *NewSplat = Builder.CreateShuffleVector(NewBO, Undef, NewMask);
+    Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp);
+
+    // Intersect FMF on both new binops. Other (poison-generating) flags are
+    // dropped to be safe.
+    if (isa<FPMathOperator>(R)) {
+      R->copyFastMathFlags(&Inst);
+      R->andIRFlags(BO);
+    }
+    if (auto *NewInstBO = dyn_cast<Instruction>(NewBO))
+      NewInstBO->copyIRFlags(R);
+    return R;
+  }
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle.ll b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
index abbf1df..e806fac 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle.ll
@@ -1457,9 +1457,9 @@ define <4 x float> @insert_subvector_crash_invalid_mask_elt(<2 x float> %x, <4 x
 
 define <4 x i32> @splat_assoc_add(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @splat_assoc_add(
-; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT: [[A:%.*]] = add <4 x i32> [[Y:%.*]],
-; CHECK-NEXT: [[R:%.*]] = add <4 x i32> [[SPLATX]], [[A]]
+; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[X:%.*]],
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <4 x i32> zeroinitializer
+; CHECK-NEXT: [[R:%.*]] = add <4 x i32> [[TMP2]], [[Y:%.*]]
 ; CHECK-NEXT: ret <4 x i32> [[R]]
 ;
   %splatx = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> zeroinitializer
@@ -1468,11 +1468,13 @@ define <4 x i32> @splat_assoc_add(<4 x i32> %x, <4 x i32> %y) {
   ret <4 x i32> %r
 }
 
+; Non-zero splat index; commute operands; FMF intersect
+
 define <2 x float> @splat_assoc_fmul(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @splat_assoc_fmul(
-; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <2 x float> [[X:%.*]], <2 x float> undef, <2 x i32>
-; CHECK-NEXT: [[A:%.*]] = fmul reassoc nsz <2 x float> [[Y:%.*]],
-; CHECK-NEXT: [[R:%.*]] = fmul reassoc nnan nsz <2 x float> [[A]], [[SPLATX]]
+; CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nsz <2 x float> [[X:%.*]],
+; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x float> [[TMP1]], <2 x float> undef, <2 x i32>
+; CHECK-NEXT: [[R:%.*]] = fmul reassoc nsz <2 x float> [[TMP2]], [[Y:%.*]]
 ; CHECK-NEXT: ret <2 x float> [[R]]
 ;
   %splatx = 
shufflevector <2 x float> %x, <2 x float> undef, <2 x i32> @@ -1481,12 +1483,13 @@ define <2 x float> @splat_assoc_fmul(<2 x float> %x, <2 x float> %y) { ret <2 x float> %r } +; Two splat shuffles; drop poison-generating flags + define <3 x i8> @splat_assoc_mul(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ; CHECK-LABEL: @splat_assoc_mul( -; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <3 x i8> [[X:%.*]], <3 x i8> undef, <3 x i32> -; CHECK-NEXT: [[SPLATZ:%.*]] = shufflevector <3 x i8> [[Z:%.*]], <3 x i8> undef, <3 x i32> -; CHECK-NEXT: [[A:%.*]] = mul nsw <3 x i8> [[SPLATZ]], [[Y:%.*]] -; CHECK-NEXT: [[R:%.*]] = mul <3 x i8> [[A]], [[SPLATX]] +; CHECK-NEXT: [[TMP1:%.*]] = mul <3 x i8> [[X:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <3 x i8> [[TMP1]], <3 x i8> undef, <3 x i32> +; CHECK-NEXT: [[R:%.*]] = mul <3 x i8> [[TMP2]], [[Y:%.*]] ; CHECK-NEXT: ret <3 x i8> [[R]] ; %splatx = shufflevector <3 x i8> %x, <3 x i8> undef, <3 x i32> @@ -1496,7 +1499,7 @@ define <3 x i8> @splat_assoc_mul(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ret <3 x i8> %r } -; Mismatched splat elements +; Negative test - mismatched splat elements define <3 x i8> @splat_assoc_or(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ; CHECK-LABEL: @splat_assoc_or( @@ -1513,7 +1516,7 @@ define <3 x i8> @splat_assoc_or(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ret <3 x i8> %r } -; Not associative +; Negative test - not associative define <2 x float> @splat_assoc_fdiv(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fdiv( @@ -1528,7 +1531,7 @@ define <2 x float> @splat_assoc_fdiv(<2 x float> %x, <2 x float> %y) { ret <2 x float> %r } -; Extra use +; Negative test - extra use define <2 x float> @splat_assoc_fadd(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fadd( @@ -1545,7 +1548,7 @@ define <2 x float> @splat_assoc_fadd(<2 x float> %x, <2 x float> %y) { ret <2 x float> %r } -; Narrowing splat +; Negative test - narrowing splat define <3 x i32> @splat_assoc_and(<4 x i32> %x, <3 
x i32> %y) { ; CHECK-LABEL: @splat_assoc_and( @@ -1560,7 +1563,7 @@ define <3 x i32> @splat_assoc_and(<4 x i32> %x, <3 x i32> %y) { ret <3 x i32> %r } -; Widening splat +; Negative test - widening splat define <5 x i32> @splat_assoc_xor(<4 x i32> %x, <5 x i32> %y) { ; CHECK-LABEL: @splat_assoc_xor( @@ -1575,7 +1578,7 @@ define <5 x i32> @splat_assoc_xor(<4 x i32> %x, <5 x i32> %y) { ret <5 x i32> %r } -; Opcode mismatch +; Negative test - opcode mismatch define <4 x i32> @splat_assoc_add_mul(<4 x i32> %x, <4 x i32> %y) { ; CHECK-LABEL: @splat_assoc_add_mul( diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll index 6bcf03f..e093ee4 100644 --- a/llvm/test/Transforms/LoopVectorize/induction.ll +++ b/llvm/test/Transforms/LoopVectorize/induction.ll @@ -427,7 +427,7 @@ for.end: ; UNROLL: %[[i1:.+]] = or i64 %index, 1 ; UNROLL: %[[i2:.+]] = or i64 %index, 2 ; UNROLL: %[[i3:.+]] = or i64 %index, 3 -; UNROLL: %step.add3 = add <2 x i32> %vec.ind2, +; UNROLL: %[[add:.+]]= add <2 x i32> %[[splat:.+]], ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %index, i32 1 ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i1]], i32 1 ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i2]], i32 1 -- 2.7.4