[SLP] make checks for cmp+select min/max more explicit
authorSanjay Patel <spatel@rotateright.com>
Fri, 9 Jul 2021 15:56:26 +0000 (11:56 -0400)
committerSanjay Patel <spatel@rotateright.com>
Fri, 9 Jul 2021 16:43:43 +0000 (12:43 -0400)
This is NFC-intended currently (so no test diffs). The motivation
is to eventually allow matching for poison-safe logical-and and
logical-or (these are in the form of a select-of-bools).
( https://llvm.org/PR41312 )

Those patterns will not have all of the same constraints as min/max
in the form of cmp+sel. We may also end up removing the cmp+sel
min/max matching entirely (if we canonicalize to intrinsics), so
this will make that step easier.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

index 2ec3215..7c49a95 100644 (file)
@@ -7288,6 +7288,11 @@ class HorizontalReduction {
   /// The type of reduction operation.
   RecurKind RdxKind;
 
+  static bool isCmpSelMinMax(Instruction *I) {
+    return match(I, m_Select(m_Cmp(), m_Value(), m_Value())) &&
+           RecurrenceDescriptor::isMinMaxRecurrenceKind(getRdxKind(I));
+  }
+
   /// Checks if instruction is associative and can be vectorized.
   static bool isVectorizable(RecurKind Kind, Instruction *I) {
     if (Kind == RecurKind::None)
@@ -7503,19 +7508,19 @@ class HorizontalReduction {
 
   /// Get the index of the first operand.
   static unsigned getFirstOperandIndex(Instruction *I) {
-    return isa<SelectInst>(I) ? 1 : 0;
+    return isCmpSelMinMax(I) ? 1 : 0;
   }
 
   /// Total number of operands in the reduction operation.
   static unsigned getNumberOfOperands(Instruction *I) {
-    return isa<SelectInst>(I) ? 3 : 2;
+    return isCmpSelMinMax(I) ? 3 : 2;
   }
 
   /// Checks if the instruction is in basic block \p BB.
   /// For a min/max reduction check that both compare and select are in \p BB.
-  static bool hasSameParent(Instruction *I, BasicBlock *BB, bool IsRedOp) {
-    auto *Sel = dyn_cast<SelectInst>(I);
-    if (IsRedOp && Sel) {
+  static bool hasSameParent(Instruction *I, BasicBlock *BB) {
+    if (isCmpSelMinMax(I)) {
+      auto *Sel = cast<SelectInst>(I);
       auto *Cmp = cast<Instruction>(Sel->getCondition());
       return Sel->getParent() == BB && Cmp->getParent() == BB;
     }
@@ -7523,10 +7528,10 @@ class HorizontalReduction {
   }
 
   /// Expected number of uses for reduction operations/reduced values.
-  static bool hasRequiredNumberOfUses(bool MatchCmpSel, Instruction *I) {
-    // SelectInst must be used twice while the condition op must have single
-    // use only.
-    if (MatchCmpSel) {
+  static bool hasRequiredNumberOfUses(bool IsCmpSelMinMax, Instruction *I) {
+    if (IsCmpSelMinMax) {
+      // SelectInst must be used twice while the condition op must have single
+      // use only.
       if (auto *Sel = dyn_cast<SelectInst>(I))
         return Sel->hasNUses(2) && Sel->getCondition()->hasOneUse();
       return I->hasNUses(2);
@@ -7538,7 +7543,7 @@ class HorizontalReduction {
 
   /// Initializes the list of reduction operations.
   void initReductionOps(Instruction *I) {
-    if (isa<SelectInst>(I))
+    if (isCmpSelMinMax(I))
       ReductionOps.assign(2, ReductionOpsType());
     else
       ReductionOps.assign(1, ReductionOpsType());
@@ -7546,9 +7551,9 @@ class HorizontalReduction {
 
   /// Add all reduction operations for the reduction instruction \p I.
   void addReductionOps(Instruction *I) {
-    if (auto *Sel = dyn_cast<SelectInst>(I)) {
-      ReductionOps[0].emplace_back(Sel->getCondition());
-      ReductionOps[1].emplace_back(Sel);
+    if (isCmpSelMinMax(I)) {
+      ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition());
+      ReductionOps[1].emplace_back(I);
     } else {
       ReductionOps[0].emplace_back(I);
     }
@@ -7675,8 +7680,8 @@ public:
       // ultimate reduction.
       const bool IsRdxInst = EdgeRdxKind == RdxKind;
       if (EdgeInst != Phi && EdgeInst != Inst &&
-          hasSameParent(EdgeInst, Inst->getParent(), IsRdxInst) &&
-          hasRequiredNumberOfUses(isa<SelectInst>(Inst), EdgeInst) &&
+          hasSameParent(EdgeInst, Inst->getParent()) &&
+          hasRequiredNumberOfUses(isCmpSelMinMax(Inst), EdgeInst) &&
           (!LeafOpcode || LeafOpcode == EdgeInst->getOpcode() || IsRdxInst)) {
         if (IsRdxInst) {
           // We need to be able to reassociate the reduction operations.
@@ -7837,7 +7842,7 @@ public:
       // Emit a reduction. If the root is a select (min/max idiom), the insert
       // point is the compare condition of that select.
       Instruction *RdxRootInst = cast<Instruction>(ReductionRoot);
-      if (isa<SelectInst>(RdxRootInst))
+      if (isCmpSelMinMax(RdxRootInst))
         Builder.SetInsertPoint(getCmpForMinMaxReduction(RdxRootInst));
       else
         Builder.SetInsertPoint(RdxRootInst);