[InstCombine] reverse canonicalization of add --> or to allow more shuffle folding
authorSanjay Patel <spatel@rotateright.com>
Mon, 2 Jul 2018 17:42:29 +0000 (17:42 +0000)
committerSanjay Patel <spatel@rotateright.com>
Mon, 2 Jul 2018 17:42:29 +0000 (17:42 +0000)
This extends D48485 to allow another pair of binops (add/or) to be combined either
with or without a leading shuffle:
or X, C --> add X, C (when X and C have no common bits set)

Here, we need value tracking to determine that the 'or' can be reversed into an 'add',
and we've added general infrastructure to allow extending to other opcodes or moving
to where other passes could use that functionality.

Differential Revision: https://reviews.llvm.org/D48662

llvm-svn: 336128

llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
llvm/test/Transforms/InstCombine/shuffle_select.ll

index dd579a6..b06889f 100644 (file)
@@ -1140,9 +1140,52 @@ static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI,
   return true;
 }
 
+/// These are the ingredients in an alternate form binary operator as described
+/// below.
+struct BinopElts {
+  BinaryOperator::BinaryOps Opcode;
+  Value *Op0;
+  Value *Op1;
+  BinopElts(BinaryOperator::BinaryOps Opc = (BinaryOperator::BinaryOps)0,
+            Value *V0 = nullptr, Value *V1 = nullptr) :
+      Opcode(Opc), Op0(V0), Op1(V1) {}
+  operator bool() const { return Opcode != 0; }
+};
+
+/// Binops may be transformed into binops with different opcodes and operands.
+/// Reverse the usual canonicalization to enable folds with the non-canonical
+/// form of the binop. If a transform is possible, return the elements of the
+/// new binop. If not, return invalid elements.
+static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) {
+  Value *BO0 = BO->getOperand(0), *BO1 = BO->getOperand(1);
+  Type *Ty = BO->getType();
+  switch (BO->getOpcode()) {
+    case Instruction::Shl: {
+      // shl X, C --> mul X, (1 << C)
+      Constant *C;
+      if (match(BO1, m_Constant(C))) {
+        Constant *ShlOne = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C);
+        return { Instruction::Mul, BO0, ShlOne };
+      }
+      break;
+    }
+    case Instruction::Or: {
+      // or X, C --> add X, C (when X and C have no common bits set)
+      const APInt *C;
+      if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL))
+        return { Instruction::Add, BO0, BO1 };
+      break;
+    }
+    default:
+      break;
+  }
+  return {};
+}
+
+/// Try to fold shuffles that are the equivalent of a vector select.
 static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf,
-                                      InstCombiner::BuilderTy &Builder) {
-  // Folds under here require the equivalent of a vector select.
+                                      InstCombiner::BuilderTy &Builder,
+                                      const DataLayout &DL) {
   if (!Shuf.isSelect())
     return nullptr;
 
@@ -1168,19 +1211,19 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf,
   BinaryOperator::BinaryOps Opc1 = B1->getOpcode();
   bool DropNSW = false;
   if (ConstantsAreOp1 && Opc0 != Opc1) {
-    // If we have multiply and shift-left-by-constant, convert the shift:
-    // shl X, C --> mul X, 1 << C
     // TODO: We drop "nsw" if shift is converted into multiply because it may
     // not be correct when the shift amount is BitWidth - 1. We could examine
     // each vector element to determine if it is safe to keep that flag.
-    if (Opc0 == Instruction::Mul && Opc1 == Instruction::Shl) {
-      C1 = ConstantExpr::getShl(ConstantInt::get(C1->getType(), 1), C1);
-      Opc1 = Instruction::Mul;
-      DropNSW = true;
-    } else if (Opc0 == Instruction::Shl && Opc1 == Instruction::Mul) {
-      C0 = ConstantExpr::getShl(ConstantInt::get(C0->getType(), 1), C0);
-      Opc0 = Instruction::Mul;
+    if (Opc0 == Instruction::Shl || Opc1 == Instruction::Shl)
       DropNSW = true;
+    if (BinopElts AltB0 = getAlternateBinop(B0, DL)) {
+      assert(isa<Constant>(AltB0.Op1) && "Expecting constant with alt binop");
+      Opc0 = AltB0.Opcode;
+      C0 = cast<Constant>(AltB0.Op1);
+    } else if (BinopElts AltB1 = getAlternateBinop(B1, DL)) {
+      assert(isa<Constant>(AltB1.Op1) && "Expecting constant with alt binop");
+      Opc1 = AltB1.Opcode;
+      C1 = cast<Constant>(AltB1.Op1);
     }
   }
 
@@ -1249,7 +1292,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
           LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI)))
     return replaceInstUsesWith(SVI, V);
 
-  if (Instruction *I = foldSelectShuffle(SVI, Builder))
+  if (Instruction *I = foldSelectShuffle(SVI, Builder, DL))
     return I;
 
   bool MadeChange = false;
index 241d5f2..3ec5630 100644 (file)
@@ -800,9 +800,7 @@ define <4 x i32> @shl_mul_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 define <4 x i32> @add_or(<4 x i32> %v) {
 ; CHECK-LABEL: @add_or(
 ; CHECK-NEXT:    [[V0:%.*]] = shl <4 x i32> [[V:%.*]], <i32 5, i32 5, i32 5, i32 5>
-; CHECK-NEXT:    [[T1:%.*]] = add <4 x i32> [[V0]], <i32 undef, i32 undef, i32 65536, i32 65537>
-; CHECK-NEXT:    [[T2:%.*]] = or <4 x i32> [[V0]], <i32 31, i32 31, i32 undef, i32 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[T3:%.*]] = add <4 x i32> [[V0]], <i32 31, i32 31, i32 65536, i32 65537>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %v0 = shl <4 x i32> %v, <i32 5, i32 5, i32 5, i32 5>                   ; clear the bottom bits
@@ -817,9 +815,7 @@ define <4 x i32> @add_or(<4 x i32> %v) {
 define <4 x i8> @or_add(<4 x i8> %v) {
 ; CHECK-LABEL: @or_add(
 ; CHECK-NEXT:    [[V0:%.*]] = lshr <4 x i8> [[V:%.*]], <i8 3, i8 3, i8 3, i8 3>
-; CHECK-NEXT:    [[T1:%.*]] = or <4 x i8> [[V0]], <i8 undef, i8 undef, i8 -64, i8 -64>
-; CHECK-NEXT:    [[T2:%.*]] = add nuw nsw <4 x i8> [[V0]], <i8 1, i8 2, i8 undef, i8 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i8> [[T1]], <4 x i8> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[T3:%.*]] = add nsw <4 x i8> [[V0]], <i8 1, i8 2, i8 -64, i8 -64>
 ; CHECK-NEXT:    ret <4 x i8> [[T3]]
 ;
   %v0 = lshr <4 x i8> %v, <i8 3, i8 3, i8 3, i8 3>          ; clear the top bits
@@ -829,6 +825,8 @@ define <4 x i8> @or_add(<4 x i8> %v) {
   ret <4 x i8> %t3
 }
 
+; Negative test: not all 'or' insts can be converted to 'add'.
+
 define <4 x i8> @or_add_not_enough_masking(<4 x i8> %v) {
 ; CHECK-LABEL: @or_add_not_enough_masking(
 ; CHECK-NEXT:    [[V0:%.*]] = lshr <4 x i8> [[V:%.*]], <i8 1, i8 1, i8 1, i8 1>
@@ -849,9 +847,8 @@ define <4 x i8> @or_add_not_enough_masking(<4 x i8> %v) {
 define <4 x i32> @add_or_2_vars(<4 x i32> %v, <4 x i32> %v1) {
 ; CHECK-LABEL: @add_or_2_vars(
 ; CHECK-NEXT:    [[V0:%.*]] = shl <4 x i32> [[V:%.*]], <i32 5, i32 5, i32 5, i32 5>
-; CHECK-NEXT:    [[T1:%.*]] = add <4 x i32> [[V1:%.*]], <i32 undef, i32 undef, i32 65536, i32 65537>
-; CHECK-NEXT:    [[T2:%.*]] = or <4 x i32> [[V0]], <i32 31, i32 31, i32 undef, i32 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V1:%.*]], <4 x i32> [[V0]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[T3:%.*]] = add <4 x i32> [[TMP1]], <i32 31, i32 31, i32 65536, i32 65537>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %v0 = shl <4 x i32> %v, <i32 5, i32 5, i32 5, i32 5>                   ; clear the bottom bits
@@ -864,9 +861,8 @@ define <4 x i32> @add_or_2_vars(<4 x i32> %v, <4 x i32> %v1) {
 define <4 x i8> @or_add_2_vars(<4 x i8> %v, <4 x i8> %v1) {
 ; CHECK-LABEL: @or_add_2_vars(
 ; CHECK-NEXT:    [[V0:%.*]] = lshr <4 x i8> [[V:%.*]], <i8 3, i8 3, i8 3, i8 3>
-; CHECK-NEXT:    [[T1:%.*]] = or <4 x i8> [[V0]], <i8 undef, i8 undef, i8 -64, i8 -64>
-; CHECK-NEXT:    [[T2:%.*]] = add nuw nsw <4 x i8> [[V1:%.*]], <i8 1, i8 2, i8 undef, i8 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i8> [[T1]], <4 x i8> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i8> [[V0]], <4 x i8> [[V1:%.*]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[T3:%.*]] = add <4 x i8> [[TMP1]], <i8 1, i8 2, i8 -64, i8 -64>
 ; CHECK-NEXT:    ret <4 x i8> [[T3]]
 ;
   %v0 = lshr <4 x i8> %v, <i8 3, i8 3, i8 3, i8 3>          ; clear the top bits