[Analysis] reduce code for matching min/max; NFC
authorSanjay Patel <spatel@rotateright.com>
Thu, 31 Dec 2020 21:45:33 +0000 (16:45 -0500)
committerSanjay Patel <spatel@rotateright.com>
Thu, 31 Dec 2020 22:19:37 +0000 (17:19 -0500)
This might also make it easier to adapt if we want
to match min/max intrinsics rather than cmp+sel idioms.

The 'const' part is to potentially avoid confusion
in calling code. There's some surprising and possibly
wrong behavior related to matching min/max reductions
differently than other reductions.

llvm/include/llvm/Analysis/IVDescriptors.h
llvm/lib/Analysis/IVDescriptors.cpp

index e736adf..30216e2 100644 (file)
@@ -96,15 +96,15 @@ public:
         : IsRecurrence(true), PatternLastInst(I), MinMaxKind(K),
           UnsafeAlgebraInst(UAI) {}
 
-    bool isRecurrence() { return IsRecurrence; }
+    bool isRecurrence() const { return IsRecurrence; }
 
-    bool hasUnsafeAlgebra() { return UnsafeAlgebraInst != nullptr; }
+    bool hasUnsafeAlgebra() const { return UnsafeAlgebraInst != nullptr; }
 
-    Instruction *getUnsafeAlgebraInst() { return UnsafeAlgebraInst; }
+    Instruction *getUnsafeAlgebraInst() const { return UnsafeAlgebraInst; }
 
-    MinMaxRecurrenceKind getMinMaxKind() { return MinMaxKind; }
+    MinMaxRecurrenceKind getMinMaxKind() const { return MinMaxKind; }
 
-    Instruction *getPatternInst() { return PatternLastInst; }
+    Instruction *getPatternInst() const { return PatternLastInst; }
 
   private:
     // Is this instruction a recurrence candidate.
@@ -134,10 +134,11 @@ public:
   /// Returns true if all uses of the instruction I is within the Set.
   static bool areAllUsesIn(Instruction *I, SmallPtrSetImpl<Instruction *> &Set);
 
-  /// Returns a struct describing if the instruction if the instruction is a
+  /// Returns a struct describing if the instruction is a
   /// Select(ICmp(X, Y), X, Y) instruction pattern corresponding to a min(X, Y)
-  /// or max(X, Y).
-  static InstDesc isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev);
+  /// or max(X, Y). \p Prev is specifies the description of an already processed
+  /// select instruction, so its corresponding cmp can be matched to it.
+  static InstDesc isMinMaxSelectCmpPattern(Instruction *I, const InstDesc &Prev);
 
   /// Returns a struct describing if the instruction is a
   /// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
index d975651..eac6f3c 100644 (file)
@@ -456,53 +456,42 @@ bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind,
   return true;
 }
 
-/// Returns true if the instruction is a Select(ICmp(X, Y), X, Y) instruction
-/// pattern corresponding to a min(X, Y) or max(X, Y).
 RecurrenceDescriptor::InstDesc
-RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev) {
-
-  assert((isa<ICmpInst>(I) || isa<FCmpInst>(I) || isa<SelectInst>(I)) &&
-         "Expect a select instruction");
-  Instruction *Cmp = nullptr;
-  SelectInst *Select = nullptr;
+RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I,
+                                               const InstDesc &Prev) {
+  assert((isa<CmpInst>(I) || isa<SelectInst>(I)) &&
+         "Expected a cmp or select instruction");
 
   // We must handle the select(cmp()) as a single instruction. Advance to the
   // select.
-  if ((Cmp = dyn_cast<ICmpInst>(I)) || (Cmp = dyn_cast<FCmpInst>(I))) {
-    if (!Cmp->hasOneUse() || !(Select = dyn_cast<SelectInst>(*I->user_begin())))
-      return InstDesc(false, I);
-    return InstDesc(Select, Prev.getMinMaxKind());
+  CmpInst::Predicate Pred;
+  if (match(I, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
+    if (auto *Select = dyn_cast<SelectInst>(*I->user_begin()))
+      return InstDesc(Select, Prev.getMinMaxKind());
   }
 
-  // Only handle single use cases for now.
-  if (!(Select = dyn_cast<SelectInst>(I)))
+  // Only match select with single use cmp condition.
+  if (!match(I, m_Select(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), m_Value(),
+                         m_Value())))
     return InstDesc(false, I);
-  if (!(Cmp = dyn_cast<ICmpInst>(I->getOperand(0))) &&
-      !(Cmp = dyn_cast<FCmpInst>(I->getOperand(0))))
-    return InstDesc(false, I);
-  if (!Cmp->hasOneUse())
-    return InstDesc(false, I);
-
-  Value *CmpLeft;
-  Value *CmpRight;
 
   // Look for a min/max pattern.
-  if (m_UMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_UIntMin);
-  else if (m_UMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_UIntMax);
-  else if (m_SMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_SIntMax);
-  else if (m_SMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_SIntMin);
-  else if (m_OrdFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_FloatMin);
-  else if (m_OrdFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_FloatMax);
-  else if (m_UnordFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_FloatMin);
-  else if (m_UnordFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select))
-    return InstDesc(Select, MRK_FloatMax);
+  if (match(I, m_UMin(m_Value(), m_Value())))
+    return InstDesc(I, MRK_UIntMin);
+  if (match(I, m_UMax(m_Value(), m_Value())))
+    return InstDesc(I, MRK_UIntMax);
+  if (match(I, m_SMax(m_Value(), m_Value())))
+    return InstDesc(I, MRK_SIntMax);
+  if (match(I, m_SMin(m_Value(), m_Value())))
+    return InstDesc(I, MRK_SIntMin);
+  if (match(I, m_OrdFMin(m_Value(), m_Value())))
+    return InstDesc(I, MRK_FloatMin);
+  if (match(I, m_OrdFMax(m_Value(), m_Value())))
+    return InstDesc(I, MRK_FloatMax);
+  if (match(I, m_UnordFMin(m_Value(), m_Value())))
+    return InstDesc(I, MRK_FloatMin);
+  if (match(I, m_UnordFMax(m_Value(), m_Value())))
+    return InstDesc(I, MRK_FloatMax);
 
   return InstDesc(false, I);
 }