[InstCombine] reassociate splatted vector ops
authorSanjay Patel <spatel@rotateright.com>
Mon, 3 Feb 2020 13:55:43 +0000 (08:55 -0500)
committerSanjay Patel <spatel@rotateright.com>
Mon, 3 Feb 2020 14:08:36 +0000 (09:08 -0500)
bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp

This patch depends on the splat analysis enhancement in D73549.
See the test with comment:
; Negative test - mismatched splat elements
...as the motivation for that first patch.

The motivating case for reassociating splatted ops is shown in PR42174:
https://bugs.llvm.org/show_bug.cgi?id=42174

In that example, a slight change in order-of-associative math results
in a big difference in IR and codegen. This patch gets all of the
unnecessary shuffles out of the way, but doesn't address the potential
scalarization (see D50992 or D73480 for that).

Differential Revision: https://reviews.llvm.org/D73703

llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
llvm/test/Transforms/InstCombine/vec_shuffle.ll
llvm/test/Transforms/LoopVectorize/induction.ll

index bded1d7..f11fa27 100644 (file)
@@ -60,6 +60,7 @@
 #include "llvm/Analysis/TargetFolder.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
 #include "llvm/IR/Constant.h"
@@ -1683,6 +1684,54 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) {
     }
   }
 
+  // Try to reassociate to sink a splat shuffle after a binary operation.
+  if (Inst.isAssociative() && Inst.isCommutative()) {
+    // Canonicalize shuffle operand as LHS.
+    if (auto *ShufR = dyn_cast<ShuffleVectorInst>(RHS))
+      std::swap(LHS, RHS);
+
+    Value *X;
+    Constant *MaskC;
+    const APInt *SplatIndex;
+    BinaryOperator *BO;
+    if (!match(LHS, m_OneUse(m_ShuffleVector(m_Value(X), m_Undef(),
+                                             m_Constant(MaskC)))) ||
+        !match(MaskC, m_APIntAllowUndef(SplatIndex)) ||
+        X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(BO))) ||
+        BO->getOpcode() != Opcode)
+      return nullptr;
+
+    Value *Y, *OtherOp;
+    if (isSplatValue(BO->getOperand(0), SplatIndex->getZExtValue())) {
+      Y = BO->getOperand(0);
+      OtherOp = BO->getOperand(1);
+    } else if (isSplatValue(BO->getOperand(1), SplatIndex->getZExtValue())) {
+      Y = BO->getOperand(1);
+      OtherOp = BO->getOperand(0);
+    } else {
+      return nullptr;
+    }
+
+    // X and Y are splatted values, so perform the binary operation on those
+    // values followed by a splat followed by the 2nd binary operation:
+    // bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp
+    Value *NewBO = Builder.CreateBinOp(Opcode, X, Y);
+    UndefValue *Undef = UndefValue::get(Inst.getType());
+    Constant *NewMask = ConstantInt::get(MaskC->getType(), *SplatIndex);
+    Value *NewSplat = Builder.CreateShuffleVector(NewBO, Undef, NewMask);
+    Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp);
+
+    // Intersect FMF on both new binops. Other (poison-generating) flags are
+    // dropped to be safe.
+    if (isa<FPMathOperator>(R)) {
+      R->copyFastMathFlags(&Inst);
+      R->andIRFlags(BO);
+    }
+    if (auto *NewInstBO = dyn_cast<BinaryOperator>(NewBO))
+      NewInstBO->copyIRFlags(R);
+    return R;
+  }
+
   return nullptr;
 }
 
index abbf1df..e806fac 100644 (file)
@@ -1457,9 +1457,9 @@ define <4 x float> @insert_subvector_crash_invalid_mask_elt(<2 x float> %x, <4 x
 
 define <4 x i32> @splat_assoc_add(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @splat_assoc_add(
-; CHECK-NEXT:    [[SPLATX:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> undef, <4 x i32> zeroinitializer
-; CHECK-NEXT:    [[A:%.*]] = add <4 x i32> [[Y:%.*]], <i32 317426, i32 317426, i32 317426, i32 317426>
-; CHECK-NEXT:    [[R:%.*]] = add <4 x i32> [[SPLATX]], [[A]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add <4 x i32> [[X:%.*]], <i32 317426, i32 undef, i32 undef, i32 undef>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[R:%.*]] = add <4 x i32> [[TMP2]], [[Y:%.*]]
 ; CHECK-NEXT:    ret <4 x i32> [[R]]
 ;
   %splatx = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> zeroinitializer
@@ -1468,11 +1468,13 @@ define <4 x i32> @splat_assoc_add(<4 x i32> %x, <4 x i32> %y) {
   ret <4 x i32> %r
 }
 
+; Non-zero splat index; commute operands; FMF intersect
+
 define <2 x float> @splat_assoc_fmul(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @splat_assoc_fmul(
-; CHECK-NEXT:    [[SPLATX:%.*]] = shufflevector <2 x float> [[X:%.*]], <2 x float> undef, <2 x i32> <i32 1, i32 1>
-; CHECK-NEXT:    [[A:%.*]] = fmul reassoc nsz <2 x float> [[Y:%.*]], <float 3.000000e+00, float 3.000000e+00>
-; CHECK-NEXT:    [[R:%.*]] = fmul reassoc nnan nsz <2 x float> [[A]], [[SPLATX]]
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc nsz <2 x float> [[X:%.*]], <float undef, float 3.000000e+00>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x float> [[TMP1]], <2 x float> undef, <2 x i32> <i32 1, i32 1>
+; CHECK-NEXT:    [[R:%.*]] = fmul reassoc nsz <2 x float> [[TMP2]], [[Y:%.*]]
 ; CHECK-NEXT:    ret <2 x float> [[R]]
 ;
   %splatx = shufflevector <2 x float> %x, <2 x float> undef, <2 x i32> <i32 1, i32 1>
@@ -1481,12 +1483,13 @@ define <2 x float> @splat_assoc_fmul(<2 x float> %x, <2 x float> %y) {
   ret <2 x float> %r
 }
 
+; Two splat shuffles; drop poison-generating flags
+
 define <3 x i8> @splat_assoc_mul(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) {
 ; CHECK-LABEL: @splat_assoc_mul(
-; CHECK-NEXT:    [[SPLATX:%.*]] = shufflevector <3 x i8> [[X:%.*]], <3 x i8> undef, <3 x i32> <i32 2, i32 2, i32 2>
-; CHECK-NEXT:    [[SPLATZ:%.*]] = shufflevector <3 x i8> [[Z:%.*]], <3 x i8> undef, <3 x i32> <i32 2, i32 2, i32 2>
-; CHECK-NEXT:    [[A:%.*]] = mul nsw <3 x i8> [[SPLATZ]], [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = mul <3 x i8> [[A]], [[SPLATX]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul <3 x i8> [[X:%.*]], [[Z:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <3 x i8> [[TMP1]], <3 x i8> undef, <3 x i32> <i32 2, i32 2, i32 2>
+; CHECK-NEXT:    [[R:%.*]] = mul <3 x i8> [[TMP2]], [[Y:%.*]]
 ; CHECK-NEXT:    ret <3 x i8> [[R]]
 ;
   %splatx = shufflevector <3 x i8> %x, <3 x i8> undef, <3 x i32> <i32 2, i32 2, i32 2>
@@ -1496,7 +1499,7 @@ define <3 x i8> @splat_assoc_mul(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) {
   ret <3 x i8> %r
 }
 
-; Mismatched splat elements
+; Negative test - mismatched splat elements
 
 define <3 x i8> @splat_assoc_or(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) {
 ; CHECK-LABEL: @splat_assoc_or(
@@ -1513,7 +1516,7 @@ define <3 x i8> @splat_assoc_or(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) {
   ret <3 x i8> %r
 }
 
-; Not associative
+; Negative test - not associative
 
 define <2 x float> @splat_assoc_fdiv(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @splat_assoc_fdiv(
@@ -1528,7 +1531,7 @@ define <2 x float> @splat_assoc_fdiv(<2 x float> %x, <2 x float> %y) {
   ret <2 x float> %r
 }
 
-; Extra use
+; Negative test - extra use
 
 define <2 x float> @splat_assoc_fadd(<2 x float> %x, <2 x float> %y) {
 ; CHECK-LABEL: @splat_assoc_fadd(
@@ -1545,7 +1548,7 @@ define <2 x float> @splat_assoc_fadd(<2 x float> %x, <2 x float> %y) {
   ret <2 x float> %r
 }
 
-; Narrowing splat
+; Negative test - narrowing splat
 
 define <3 x i32> @splat_assoc_and(<4 x i32> %x, <3 x i32> %y) {
 ; CHECK-LABEL: @splat_assoc_and(
@@ -1560,7 +1563,7 @@ define <3 x i32> @splat_assoc_and(<4 x i32> %x, <3 x i32> %y) {
   ret <3 x i32> %r
 }
 
-; Widening splat
+; Negative test - widening splat
 
 define <5 x i32> @splat_assoc_xor(<4 x i32> %x, <5 x i32> %y) {
 ; CHECK-LABEL: @splat_assoc_xor(
@@ -1575,7 +1578,7 @@ define <5 x i32> @splat_assoc_xor(<4 x i32> %x, <5 x i32> %y) {
   ret <5 x i32> %r
 }
 
-; Opcode mismatch
+; Negative test - opcode mismatch
 
 define <4 x i32> @splat_assoc_add_mul(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @splat_assoc_add_mul(
index 6bcf03f..e093ee4 100644 (file)
@@ -427,7 +427,7 @@ for.end:
 ; UNROLL:   %[[i1:.+]] = or i64 %index, 1
 ; UNROLL:   %[[i2:.+]] = or i64 %index, 2
 ; UNROLL:   %[[i3:.+]] = or i64 %index, 3
-; UNROLL:   %step.add3 = add <2 x i32> %vec.ind2, <i32 2, i32 2>
+; UNROLL:   %[[add:.+]]= add <2 x i32> %[[splat:.+]], <i32 2, i32 undef>
 ; UNROLL:   getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %index, i32 1
 ; UNROLL:   getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i1]], i32 1
 ; UNROLL:   getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i2]], i32 1