[X86] Rewrite how X86PartialReduction finds candidates to consider optimizing.
authorCraig Topper <craig.topper@gmail.com>
Sun, 31 May 2020 19:39:14 +0000 (12:39 -0700)
committerCraig Topper <craig.topper@gmail.com>
Sun, 31 May 2020 19:53:01 +0000 (12:53 -0700)
Previously we walked the users of any vector binop looking for
more binops with the same opcode or phis that eventually ended up
in a reduction. While this is simple it also means visiting the
same nodes many times since we'll do a forward walk for each
BinaryOperator in the chain. It was also far more general than what
we have tests for or expect to see.

This patch replaces the algorithm with a new method that starts at
extract elements looking for a horizontal reduction. Once we find
a reduction we walk through backwards through phis and adds to
collect leaves that we can consider for rewriting.

We only consider single use adds and phis. Except for a special
case if the Add is used by a phi that forms a loop back to the
Add. Including other single use Adds to support unrolled loops.

Ultimately, I want to narrow the Adds, Phis, and final reduction
based on the partial reduction we're doing. I still haven't
figured out exactly what that looks like yet. But restricting
the types of graphs we expect to handle seemed like a good first
step. As does having all the leaves and the reduction at once.

Differential Revision: https://reviews.llvm.org/D79971

llvm/lib/Target/X86/X86PartialReduction.cpp
llvm/test/CodeGen/X86/madd.ll
llvm/test/CodeGen/X86/sad.ll

index 16108bd..65caeab 100644 (file)
@@ -49,11 +49,8 @@ public:
   }
 
 private:
-  bool tryMAddPattern(BinaryOperator *BO);
-  bool tryMAddReplacement(Value *Op, BinaryOperator *Add);
-
-  bool trySADPattern(BinaryOperator *BO);
-  bool trySADReplacement(Value *Op, BinaryOperator *Add);
+  bool tryMAddReplacement(Instruction *Op);
+  bool trySADReplacement(Instruction *Op);
 };
 }
 
@@ -66,139 +63,24 @@ char X86PartialReduction::ID = 0;
 INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
                 "X86 Partial Reduction", false, false)
 
-static bool isVectorReductionOp(const BinaryOperator &BO) {
-  if (!BO.getType()->isVectorTy())
+bool X86PartialReduction::tryMAddReplacement(Instruction *Op) {
+  if (!ST->hasSSE2())
     return false;
 
-  unsigned Opcode = BO.getOpcode();
-
-  switch (Opcode) {
-  case Instruction::Add:
-  case Instruction::Mul:
-  case Instruction::And:
-  case Instruction::Or:
-  case Instruction::Xor:
-    break;
-  case Instruction::FAdd:
-  case Instruction::FMul:
-    if (auto *FPOp = dyn_cast<FPMathOperator>(&BO))
-      if (FPOp->getFastMathFlags().isFast())
-        break;
-    LLVM_FALLTHROUGH;
-  default:
+  // Need at least 8 elements.
+  if (cast<VectorType>(Op->getType())->getNumElements() < 8)
     return false;
-  }
 
-  unsigned ElemNum = cast<VectorType>(BO.getType())->getNumElements();
-  // Ensure the reduction size is a power of 2.
-  if (!isPowerOf2_32(ElemNum))
+  // Element type should be i32.
+  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
     return false;
 
-  unsigned ElemNumToReduce = ElemNum;
-
-  // Do DFS search on the def-use chain from the given instruction. We only
-  // allow four kinds of operations during the search until we reach the
-  // instruction that extracts the first element from the vector:
-  //
-  //   1. The reduction operation of the same opcode as the given instruction.
-  //
-  //   2. PHI node.
-  //
-  //   3. ShuffleVector instruction together with a reduction operation that
-  //      does a partial reduction.
-  //
-  //   4. ExtractElement that extracts the first element from the vector, and we
-  //      stop searching the def-use chain here.
-  //
-  // 3 & 4 above perform a reduction on all elements of the vector. We push defs
-  // from 1-3 to the stack to continue the DFS. The given instruction is not
-  // a reduction operation if we meet any other instructions other than those
-  // listed above.
-
-  SmallVector<const User *, 16> UsersToVisit{&BO};
-  SmallPtrSet<const User *, 16> Visited;
-  bool ReduxExtracted = false;
-
-  while (!UsersToVisit.empty()) {
-    auto User = UsersToVisit.back();
-    UsersToVisit.pop_back();
-    if (!Visited.insert(User).second)
-      continue;
-
-    for (const auto *U : User->users()) {
-      auto *Inst = dyn_cast<Instruction>(U);
-      if (!Inst)
-        return false;
-
-      if (Inst->getOpcode() == Opcode || isa<PHINode>(U)) {
-        if (auto *FPOp = dyn_cast<FPMathOperator>(Inst))
-          if (!isa<PHINode>(FPOp) && !FPOp->getFastMathFlags().isFast())
-            return false;
-        UsersToVisit.push_back(U);
-      } else if (auto *ShufInst = dyn_cast<ShuffleVectorInst>(U)) {
-        // Detect the following pattern: A ShuffleVector instruction together
-        // with a reduction that do partial reduction on the first and second
-        // ElemNumToReduce / 2 elements, and store the result in
-        // ElemNumToReduce / 2 elements in another vector.
-
-        unsigned ResultElements = ShufInst->getType()->getNumElements();
-        if (ResultElements < ElemNum)
-          return false;
-
-        if (ElemNumToReduce == 1)
-          return false;
-        if (!isa<UndefValue>(U->getOperand(1)))
-          return false;
-        for (unsigned i = 0; i < ElemNumToReduce / 2; ++i)
-          if (ShufInst->getMaskValue(i) != int(i + ElemNumToReduce / 2))
-            return false;
-        for (unsigned i = ElemNumToReduce / 2; i < ElemNum; ++i)
-          if (ShufInst->getMaskValue(i) != -1)
-            return false;
-
-        // There is only one user of this ShuffleVector instruction, which
-        // must be a reduction operation.
-        if (!U->hasOneUse())
-          return false;
-
-        auto *U2 = dyn_cast<BinaryOperator>(*U->user_begin());
-        if (!U2 || U2->getOpcode() != Opcode)
-          return false;
-
-        // Check operands of the reduction operation.
-        if ((U2->getOperand(0) == U->getOperand(0) && U2->getOperand(1) == U) ||
-            (U2->getOperand(1) == U->getOperand(0) && U2->getOperand(0) == U)) {
-          UsersToVisit.push_back(U2);
-          ElemNumToReduce /= 2;
-        } else
-          return false;
-      } else if (isa<ExtractElementInst>(U)) {
-        // At this moment we should have reduced all elements in the vector.
-        if (ElemNumToReduce != 1)
-          return false;
-
-        auto *Val = dyn_cast<ConstantInt>(U->getOperand(1));
-        if (!Val || !Val->isZero())
-          return false;
-
-        ReduxExtracted = true;
-      } else
-        return false;
-    }
-  }
-  return ReduxExtracted;
-}
-
-bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
-  BasicBlock *BB = Add->getParent();
-
-  auto *BO = dyn_cast<BinaryOperator>(Op);
-  if (!BO || BO->getOpcode() != Instruction::Mul || !BO->hasOneUse() ||
-      BO->getParent() != BB)
+  auto *Mul = dyn_cast<BinaryOperator>(Op);
+  if (!Mul || Mul->getOpcode() != Instruction::Mul)
     return false;
 
-  Value *LHS = BO->getOperand(0);
-  Value *RHS = BO->getOperand(1);
+  Value *LHS = Mul->getOperand(0);
+  Value *RHS = Mul->getOperand(1);
 
   // LHS and RHS should be only used once or if they are the same then only
   // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
@@ -219,7 +101,7 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
   auto CanShrinkOp = [&](Value *Op) {
     auto IsFreeTruncation = [&](Value *Op) {
       if (auto *Cast = dyn_cast<CastInst>(Op)) {
-        if (Cast->getParent() == BB &&
+        if (Cast->getParent() == Mul->getParent() &&
             (Cast->getOpcode() == Instruction::SExt ||
              Cast->getOpcode() == Instruction::ZExt) &&
             Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16)
@@ -232,16 +114,16 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
     // If the operation can be freely truncated and has enough sign bits we
     // can shrink.
     if (IsFreeTruncation(Op) &&
-        ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16)
+        ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
       return true;
 
     // SelectionDAG has limited support for truncating through an add or sub if
     // the inputs are freely truncatable.
     if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
-      if (BO->getParent() == BB &&
+      if (BO->getParent() == Mul->getParent() &&
           IsFreeTruncation(BO->getOperand(0)) &&
           IsFreeTruncation(BO->getOperand(1)) &&
-          ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16)
+          ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
         return true;
     }
 
@@ -252,7 +134,7 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
   if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS))
     return false;
 
-  IRBuilder<> Builder(Add);
+  IRBuilder<> Builder(Mul);
 
   auto *MulTy = cast<VectorType>(Op->getType());
   unsigned NumElts = MulTy->getNumElements();
@@ -266,8 +148,11 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
     EvenMask[i] = i * 2;
     OddMask[i] = i * 2 + 1;
   }
-  Value *EvenElts = Builder.CreateShuffleVector(BO, BO, EvenMask);
-  Value *OddElts = Builder.CreateShuffleVector(BO, BO, OddMask);
+  // Creating a new mul so the replaceAllUsesWith below doesn't replace the
+  // uses in the shuffles we're creating.
+  Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1));
+  Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask);
+  Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask);
   Value *MAdd = Builder.CreateAdd(EvenElts, OddElts);
 
   // Concatenate zeroes to extend back to the original type.
@@ -276,34 +161,21 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) {
   Value *Zero = Constant::getNullValue(MAdd->getType());
   Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask);
 
-  // Replaces the use of mul in the original Add with the pmaddwd and zeroes.
-  Add->replaceUsesOfWith(BO, Concat);
-  Add->setHasNoSignedWrap(false);
-  Add->setHasNoUnsignedWrap(false);
+  Mul->replaceAllUsesWith(Concat);
+  Mul->eraseFromParent();
 
   return true;
 }
 
-// Try to replace operans of this add with pmaddwd patterns.
-bool X86PartialReduction::tryMAddPattern(BinaryOperator *BO) {
+bool X86PartialReduction::trySADReplacement(Instruction *Op) {
   if (!ST->hasSSE2())
     return false;
 
-  // Need at least 8 elements.
-  if (cast<VectorType>(BO->getType())->getNumElements() < 8)
-    return false;
-
-  // Element type should be i32.
-  if (!cast<VectorType>(BO->getType())->getElementType()->isIntegerTy(32))
+  // TODO: There's nothing special about i32, any integer type above i16 should
+  // work just as well.
+  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
     return false;
 
-  bool Changed = false;
-  Changed |= tryMAddReplacement(BO->getOperand(0), BO);
-  Changed |= tryMAddReplacement(BO->getOperand(1), BO);
-  return Changed;
-}
-
-bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
   // Operand should be a select.
   auto *SI = dyn_cast<SelectInst>(Op);
   if (!SI)
@@ -337,7 +209,7 @@ bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
   if (!Op0 || !Op1)
     return false;
 
-  IRBuilder<> Builder(Add);
+  IRBuilder<> Builder(SI);
 
   auto *OpTy = cast<VectorType>(Op->getType());
   unsigned NumElts = OpTy->getNumElements();
@@ -355,7 +227,7 @@ bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
     IntrinsicNumElts = 16;
   }
 
-  Function *PSADBWFn = Intrinsic::getDeclaration(Add->getModule(), IID);
+  Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID);
 
   if (NumElts < 16) {
     // Pad input with zeroes.
@@ -419,27 +291,155 @@ bool X86PartialReduction::trySADReplacement(Value *Op, BinaryOperator *Add) {
     Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask);
   }
 
-  // Replaces the uses of Op in Add with the new sequence.
-  Add->replaceUsesOfWith(Op, Ops[0]);
-  Add->setHasNoSignedWrap(false);
-  Add->setHasNoUnsignedWrap(false);
+  SI->replaceAllUsesWith(Ops[0]);
+  SI->eraseFromParent();
 
   return true;
 }
 
-bool X86PartialReduction::trySADPattern(BinaryOperator *BO) {
-  if (!ST->hasSSE2())
-    return false;
+// Walk backwards from the ExtractElementInst and determine if it is the end of
+// a horizontal reduction. Return the input to the reduction if we find one.
+static Value *matchAddReduction(const ExtractElementInst &EE) {
+  // Make sure we're extracting index 0.
+  auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
+  if (!Index || !Index->isNullValue())
+    return nullptr;
 
-  // TODO: There's nothing special about i32, any integer type above i16 should
-  // work just as well.
-  if (!cast<VectorType>(BO->getType())->getElementType()->isIntegerTy(32))
+  const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
+  if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
+    return nullptr;
+
+  unsigned NumElems = cast<VectorType>(BO->getType())->getNumElements();
+  // Ensure the reduction size is a power of 2.
+  if (!isPowerOf2_32(NumElems))
+    return nullptr;
+
+  const Value *Op = BO;
+  unsigned Stages = Log2_32(NumElems);
+  for (unsigned i = 0; i != Stages; ++i) {
+    const auto *BO = dyn_cast<BinaryOperator>(Op);
+    if (!BO || BO->getOpcode() != Instruction::Add)
+      return nullptr;
+
+    // If this isn't the first add, then it should only have 2 users, the
+    // shuffle and another add which we checked in the previous iteration.
+    if (i != 0 && !BO->hasNUses(2))
+      return nullptr;
+
+    Value *LHS = BO->getOperand(0);
+    Value *RHS = BO->getOperand(1);
+
+    auto *Shuffle = dyn_cast<ShuffleVectorInst>(LHS);
+    if (Shuffle) {
+      Op = RHS;
+    } else {
+      Shuffle = dyn_cast<ShuffleVectorInst>(RHS);
+      Op = LHS;
+    }
+
+    // The first operand of the shuffle should be the same as the other operand
+    // of the bin op.
+    if (!Shuffle || Shuffle->getOperand(0) != Op)
+      return nullptr;
+
+    // Verify the shuffle has the expected (at this stage of the pyramid) mask.
+    unsigned MaskEnd = 1 << i;
+    for (unsigned Index = 0; Index < MaskEnd; ++Index)
+      if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index))
+        return nullptr;
+  }
+
+  return const_cast<Value *>(Op);
+}
+
+// See if this BO is reachable from this Phi by walking forward through single
+// use BinaryOperators with the same opcode. If we get back then we know we've
+// found a loop and it is safe to step through this Add to find more leaves.
+static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) {
+  // The PHI itself should only have one use.
+  if (!Phi->hasOneUse())
     return false;
 
-  bool Changed = false;
-  Changed |= trySADReplacement(BO->getOperand(0), BO);
-  Changed |= trySADReplacement(BO->getOperand(1), BO);
-  return Changed;
+  Instruction *U = cast<Instruction>(*Phi->user_begin());
+  if (U == BO)
+    return true;
+
+  while (U->hasOneUse() && U->getOpcode() == BO->getOpcode())
+    U = cast<Instruction>(*U->user_begin());
+
+  return U == BO;
+}
+
+// Collect all the leaves of the tree of adds that feeds into the horizontal
+// reduction. Root is the Value that is used by the horizontal reduction.
+// We look through single use phis, single use adds, or adds that are used by
+// a phi that forms a loop with the add.
+static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) {
+  SmallPtrSet<Value *, 8> Visited;
+  SmallVector<Value *, 8> Worklist;
+  Worklist.push_back(Root);
+
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+     if (!Visited.insert(V).second)
+       continue;
+
+    if (auto *PN = dyn_cast<PHINode>(V)) {
+      // PHI node should have single use unless it is the root node, then it
+      // has 2 uses.
+      if (!PN->hasNUses(PN == Root ? 2 : 1))
+        break;
+
+      // Push incoming values to the worklist.
+      for (Value *InV : PN->incoming_values())
+        Worklist.push_back(InV);
+
+      continue;
+    }
+
+    if (auto *BO = dyn_cast<BinaryOperator>(V)) {
+      if (BO->getOpcode() == Instruction::Add) {
+        // Simple case. Single use, just push its operands to the worklist.
+        if (BO->hasNUses(BO == Root ? 2 : 1)) {
+          for (Value *Op : BO->operands())
+            Worklist.push_back(Op);
+          continue;
+        }
+
+        // If there is additional use, make sure it is an unvisited phi that
+        // gets us back to this node.
+        if (BO->hasNUses(BO == Root ? 3 : 2)) {
+          PHINode *PN = nullptr;
+          for (auto *U : Root->users())
+            if (auto *P = dyn_cast<PHINode>(U))
+              if (!Visited.count(P))
+                PN = P;
+
+          // If we didn't find a 2-input PHI then this isn't a case we can
+          // handle.
+          if (!PN || PN->getNumIncomingValues() != 2)
+            continue;
+
+          // Walk forward from this phi to see if it reaches back to this add.
+          if (!isReachableFromPHI(PN, BO))
+            continue;
+
+          // The phi forms a loop with this Add, push its operands.
+          for (Value *Op : BO->operands())
+            Worklist.push_back(Op);
+        }
+      }
+    }
+
+    // Not an add or phi, make it a leaf.
+    if (auto *I = dyn_cast<Instruction>(V)) {
+      if (!V->hasNUses(I == Root ? 2 : 1))
+        continue;
+
+      // Add this as a leaf.
+      Leaves.push_back(I);
+    }
+  }
 }
 
 bool X86PartialReduction::runOnFunction(Function &F) {
@@ -458,22 +458,29 @@ bool X86PartialReduction::runOnFunction(Function &F) {
   bool MadeChange = false;
   for (auto &BB : F) {
     for (auto &I : BB) {
-      auto *BO = dyn_cast<BinaryOperator>(&I);
-      if (!BO)
+      auto *EE = dyn_cast<ExtractElementInst>(&I);
+      if (!EE)
         continue;
 
-      if (!isVectorReductionOp(*BO))
+      // First find a reduction tree.
+      // FIXME: Do we need to handle other opcodes than Add?
+      Value *Root = matchAddReduction(*EE);
+      if (!Root)
         continue;
 
-      if (BO->getOpcode() == Instruction::Add) {
-        if (tryMAddPattern(BO)) {
+      SmallVector<Instruction *, 8> Leaves;
+      collectLeaves(Root, Leaves);
+
+      for (Instruction *I : Leaves) {
+        if (tryMAddReplacement(I)) {
           MadeChange = true;
           continue;
         }
-        if (trySADPattern(BO)) {
+
+        // Don't do SAD matching on the root node. SelectionDAG already
+        // has support for that and currently generates better code.
+        if (I != Root && trySADReplacement(I))
           MadeChange = true;
-          continue;
-        }
       }
     }
   }
index d6d04d9..6109bd2 100644 (file)
@@ -2657,9 +2657,9 @@ define i32 @madd_double_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>*
 ; AVX-LABEL: madd_double_reduction:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vmovdqu (%rdx), %xmm1
 ; AVX-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
-; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0
@@ -2720,9 +2720,9 @@ define i32 @madd_quad_reduction(<8 x i16>* %arg, <8 x i16>* %arg1, <8 x i16>* %a
 ; AVX-NEXT:    movq {{[0-9]+}}(%rsp), %r10
 ; AVX-NEXT:    movq {{[0-9]+}}(%rsp), %rax
 ; AVX-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vmovdqu (%rdx), %xmm1
 ; AVX-NEXT:    vpmaddwd (%rcx), %xmm1, %xmm1
-; AVX-NEXT:    vpmaddwd (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vmovdqu (%r8), %xmm2
 ; AVX-NEXT:    vpmaddwd (%r9), %xmm2, %xmm2
 ; AVX-NEXT:    vpaddd %xmm2, %xmm0, %xmm0
index 006dd3d..f55a580 100644 (file)
@@ -1061,9 +1061,9 @@ define i32 @sad_double_reduction(<16 x i8>* %arg, <16 x i8>* %arg1, <16 x i8>* %
 ; AVX-LABEL: sad_double_reduction:
 ; AVX:       # %bb.0: # %bb
 ; AVX-NEXT:    vmovdqu (%rdi), %xmm0
+; AVX-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vmovdqu (%rdx), %xmm1
 ; AVX-NEXT:    vpsadbw (%rcx), %xmm1, %xmm1
-; AVX-NEXT:    vpsadbw (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vpaddd %xmm0, %xmm1, %xmm0
 ; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm0[2,3,0,1]
 ; AVX-NEXT:    vpaddd %xmm1, %xmm0, %xmm0