[SLP][NFC] Fix PR58476: Fix compile time for reductions, NFC.
author Alexey Bataev <a.bataev@outlook.com>
Thu, 20 Oct 2022 19:54:32 +0000 (12:54 -0700)
committer Alexey Bataev <a.bataev@outlook.com>
Mon, 24 Oct 2022 17:13:24 +0000 (10:13 -0700)
Improve O(N^2) to O(N) in some cases and reduce the number of allocations
by reserving memory up front.
Also, improve the analysis of load reduction values to avoid analyzing
non-vectorizable cases.

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

index d60c9f6..583ba28 100644 (file)
@@ -2108,7 +2108,7 @@ public:
   }
   /// Checks if the provided list of reduced values was checked already for
   /// vectorization.
-  bool areAnalyzedReductionVals(ArrayRef<Value *> VL) {
+  bool areAnalyzedReductionVals(ArrayRef<Value *> VL) const {
     return AnalyzedReductionVals.contains(hash_value(VL));
   }
   /// Adds the list of reduced values to list of already checked values for the
@@ -3539,6 +3539,24 @@ namespace {
 enum class LoadsState { Gather, Vectorize, ScatterVectorize };
 } // anonymous namespace
 
+static bool arePointersCompatible(Value *Ptr1, Value *Ptr2,
+                                  bool CompareOpcodes = true) {
+  if (getUnderlyingObject(Ptr1) != getUnderlyingObject(Ptr2))
+    return false;
+  auto *GEP1 = dyn_cast<GetElementPtrInst>(Ptr1);
+  if (!GEP1)
+    return false;
+  auto *GEP2 = dyn_cast<GetElementPtrInst>(Ptr2);
+  if (!GEP2)
+    return false;
+  return GEP1->getNumOperands() == 2 && GEP2->getNumOperands() == 2 &&
+         ((isConstant(GEP1->getOperand(1)) &&
+           isConstant(GEP2->getOperand(1))) ||
+          !CompareOpcodes ||
+          getSameOpcode({GEP1->getOperand(1), GEP2->getOperand(1)})
+              .getOpcode());
+}
+
 /// Checks if the given array of loads can be represented as a vectorized,
 /// scatter or just simple gather.
 static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
@@ -3575,17 +3593,7 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
   // Check the order of pointer operands or that all pointers are the same.
   bool IsSorted = sortPtrAccesses(PointerOps, ScalarTy, DL, SE, Order);
   if (IsSorted || all_of(PointerOps, [&PointerOps](Value *P) {
-        if (getUnderlyingObject(P) != getUnderlyingObject(PointerOps.front()))
-          return false;
-        auto *GEP = dyn_cast<GetElementPtrInst>(P);
-        if (!GEP)
-          return false;
-        auto *GEP0 = cast<GetElementPtrInst>(PointerOps.front());
-        return GEP->getNumOperands() == 2 &&
-               ((isConstant(GEP->getOperand(1)) &&
-                 isConstant(GEP0->getOperand(1))) ||
-                getSameOpcode({GEP->getOperand(1), GEP0->getOperand(1)})
-                    .getOpcode());
+        return arePointersCompatible(P, PointerOps.front());
       })) {
     if (IsSorted) {
       Value *Ptr0;
@@ -4628,11 +4636,11 @@ static std::pair<size_t, size_t> generateKeySubkey(
   hash_code SubKey = hash_value(0);
   // Sort the loads by the distance between the pointers.
   if (auto *LI = dyn_cast<LoadInst>(V)) {
-    Key = hash_combine(hash_value(Instruction::Load), Key);
+    Key = hash_combine(LI->getType(), hash_value(Instruction::Load), Key);
     if (LI->isSimple())
       SubKey = hash_value(LoadsSubkeyGenerator(Key, LI));
     else
-      SubKey = hash_value(LI);
+      Key = SubKey = hash_value(LI);
   } else if (isVectorLikeInstWithConstOps(V)) {
     // Sort extracts by the vector operands.
     if (isa<ExtractElementInst, UndefValue>(V))
@@ -4660,7 +4668,7 @@ static std::pair<size_t, size_t> generateKeySubkey(
       if (isa<CastInst>(I)) {
         std::pair<size_t, size_t> OpVals =
             generateKeySubkey(I->getOperand(0), TLI, LoadsSubkeyGenerator,
-                              /*=AllowAlternate*/ true);
+                              /*AllowAlternate=*/true);
         Key = hash_combine(OpVals.first, Key);
         SubKey = hash_combine(OpVals.first, SubKey);
       }
@@ -4719,7 +4727,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
                                 &UserTreeIdx,
                                 this](const InstructionsState &S) {
     // Check that every instruction appears once in this bundle.
-    DenseMap<Value *, unsigned> UniquePositions;
+    DenseMap<Value *, unsigned> UniquePositions(VL.size());
     for (Value *V : VL) {
       if (isConstant(V)) {
         ReuseShuffleIndicies.emplace_back(
@@ -4877,7 +4885,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
        BB &&
        sortPtrAccesses(VL, UserTreeIdx.UserTE->getMainOp()->getType(), *DL, *SE,
                        SortedIndices));
-  if (allConstant(VL) || isSplat(VL) || !AreAllSameInsts ||
+  if (!AreAllSameInsts || allConstant(VL) || isSplat(VL) ||
       (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>(
            S.OpValue) &&
        !all_of(VL, isVectorLikeInstWithConstOps)) ||
@@ -4951,9 +4959,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
 
   // Special processing for sorted pointers for ScatterVectorize node with
   // constant indeces only.
-  if (AreAllSameInsts && !(S.getOpcode() && allSameBlock(VL)) &&
-      UserTreeIdx.UserTE &&
-      UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize) {
+  if (AreAllSameInsts && UserTreeIdx.UserTE &&
+      UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize &&
+      !(S.getOpcode() && allSameBlock(VL))) {
     assert(S.OpValue->getType()->isPointerTy() &&
            count_if(VL, [](Value *V) { return isa<GetElementPtrInst>(V); }) >=
                2 &&
@@ -11104,6 +11112,13 @@ class HorizontalReduction {
     return I->getOperand(getFirstOperandIndex(I) + 1);
   }
 
+  static bool isGoodForReduction(ArrayRef<Value *> Data) {
+    int Sz = Data.size();
+    auto *I = dyn_cast<Instruction>(Data.front());
+    return Sz > 1 || isConstant(Data.front()) ||
+           (I && !isa<LoadInst>(I) && isValidForAlternation(I->getOpcode()));
+  }
+
 public:
   HorizontalReduction() = default;
 
@@ -11199,6 +11214,9 @@ public:
     MapVector<size_t, MapVector<size_t, MapVector<Value *, unsigned>>>
         PossibleReducedVals;
     initReductionOps(Inst);
+    DenseMap<Value *, SmallVector<LoadInst *>> LoadsMap;
+    SmallSet<size_t, 2> LoadKeyUsed;
+    SmallPtrSet<Value *, 4> DoNotReverseVals;
     while (!Worklist.empty()) {
       Instruction *TreeN = Worklist.pop_back_val();
       SmallVector<Value *> Args;
@@ -11220,18 +11238,36 @@ public:
           size_t Key, Idx;
           std::tie(Key, Idx) = generateKeySubkey(
               V, &TLI,
-              [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) {
-                auto It = PossibleReducedVals.find(Key);
-                if (It != PossibleReducedVals.end()) {
-                  for (const auto &LoadData : It->second) {
-                    auto *RLI = cast<LoadInst>(LoadData.second.front().first);
-                    if (getPointersDiff(RLI->getType(),
-                                        RLI->getPointerOperand(), LI->getType(),
-                                        LI->getPointerOperand(), DL, SE,
-                                        /*StrictCheck=*/true))
-                      return hash_value(RLI->getPointerOperand());
+              [&](size_t Key, LoadInst *LI) {
+                Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
+                if (LoadKeyUsed.contains(Key)) {
+                  auto LIt = LoadsMap.find(Ptr);
+                  if (LIt != LoadsMap.end()) {
+                    for (LoadInst *RLI: LIt->second) {
+                      if (getPointersDiff(
+                              RLI->getType(), RLI->getPointerOperand(),
+                              LI->getType(), LI->getPointerOperand(), DL, SE,
+                              /*StrictCheck=*/true))
+                        return hash_value(RLI->getPointerOperand());
+                    }
+                    for (LoadInst *RLI : LIt->second) {
+                      if (arePointersCompatible(RLI->getPointerOperand(),
+                                                LI->getPointerOperand())) {
+                        hash_code SubKey = hash_value(RLI->getPointerOperand());
+                        DoNotReverseVals.insert(RLI);
+                        return SubKey;
+                      }
+                    }
+                    if (LIt->second.size() > 2) {
+                      hash_code SubKey =
+                          hash_value(LIt->second.back()->getPointerOperand());
+                      DoNotReverseVals.insert(LIt->second.back());
+                      return SubKey;
+                    }
                   }
                 }
+                LoadKeyUsed.insert(Key);
+                LoadsMap.try_emplace(Ptr).first->second.push_back(LI);
                 return hash_value(LI->getPointerOperand());
               },
               /*AllowAlternate=*/false);
@@ -11245,17 +11281,35 @@ public:
         size_t Key, Idx;
         std::tie(Key, Idx) = generateKeySubkey(
             TreeN, &TLI,
-            [&PossibleReducedVals, &DL, &SE](size_t Key, LoadInst *LI) {
-              auto It = PossibleReducedVals.find(Key);
-              if (It != PossibleReducedVals.end()) {
-                for (const auto &LoadData : It->second) {
-                  auto *RLI = cast<LoadInst>(LoadData.second.front().first);
-                  if (getPointersDiff(RLI->getType(), RLI->getPointerOperand(),
-                                      LI->getType(), LI->getPointerOperand(),
-                                      DL, SE, /*StrictCheck=*/true))
-                    return hash_value(RLI->getPointerOperand());
+            [&](size_t Key, LoadInst *LI) {
+              Value *Ptr = getUnderlyingObject(LI->getPointerOperand());
+              if (LoadKeyUsed.contains(Key)) {
+                auto LIt = LoadsMap.find(Ptr);
+                if (LIt != LoadsMap.end()) {
+                  for (LoadInst *RLI: LIt->second) {
+                    if (getPointersDiff(RLI->getType(),
+                                        RLI->getPointerOperand(), LI->getType(),
+                                        LI->getPointerOperand(), DL, SE,
+                                        /*StrictCheck=*/true))
+                      return hash_value(RLI->getPointerOperand());
+                  }
+                  for (LoadInst *RLI : LIt->second) {
+                    if (arePointersCompatible(RLI->getPointerOperand(),
+                                              LI->getPointerOperand())) {
+                      hash_code SubKey = hash_value(RLI->getPointerOperand());
+                      DoNotReverseVals.insert(RLI);
+                      return SubKey;
+                    }
+                  }
+                  if (LIt->second.size() > 2) {
+                    hash_code SubKey = hash_value(LIt->second.back()->getPointerOperand());
+                    DoNotReverseVals.insert(LIt->second.back());
+                    return SubKey;
+                  }
                 }
               }
+              LoadKeyUsed.insert(Key);
+              LoadsMap.try_emplace(Ptr).first->second.push_back(LI);
               return hash_value(LI->getPointerOperand());
             },
             /*AllowAlternate=*/false);
@@ -11281,9 +11335,27 @@ public:
       stable_sort(PossibleRedValsVect, [](const auto &P1, const auto &P2) {
         return P1.size() > P2.size();
       });
-      ReducedVals.emplace_back();
-      for (ArrayRef<Value *> Data : PossibleRedValsVect)
-        ReducedVals.back().append(Data.rbegin(), Data.rend());
+      int NewIdx = -1;
+      for (ArrayRef<Value *> Data : PossibleRedValsVect) {
+        if (isGoodForReduction(Data) ||
+            (isa<LoadInst>(Data.front()) && NewIdx >= 0 &&
+             isa<LoadInst>(ReducedVals[NewIdx].front()) &&
+             getUnderlyingObject(
+                 cast<LoadInst>(Data.front())->getPointerOperand()) ==
+                 getUnderlyingObject(cast<LoadInst>(ReducedVals[NewIdx].front())
+                                         ->getPointerOperand()))) {
+          if (NewIdx < 0) {
+            NewIdx = ReducedVals.size();
+            ReducedVals.emplace_back();
+          }
+          if (DoNotReverseVals.contains(Data.front()))
+            ReducedVals[NewIdx].append(Data.begin(), Data.end());
+          else
+            ReducedVals[NewIdx].append(Data.rbegin(), Data.rend());
+        } else {
+          ReducedVals.emplace_back().append(Data.rbegin(), Data.rend());
+        }
+      }
     }
     // Sort the reduced values by number of same/alternate opcode and/or pointer
     // operand.
@@ -11301,18 +11373,28 @@ public:
     // If there are a sufficient number of reduction values, reduce
     // to a nearby power-of-2. We can safely generate oversized
     // vectors and rely on the backend to split them to legal sizes.
-    unsigned NumReducedVals = std::accumulate(
-        ReducedVals.begin(), ReducedVals.end(), 0,
-        [](int Num, ArrayRef<Value *> Vals) { return Num + Vals.size(); });
-    if (NumReducedVals < ReductionLimit)
+    size_t NumReducedVals =
+        std::accumulate(ReducedVals.begin(), ReducedVals.end(), 0,
+                        [](size_t Num, ArrayRef<Value *> Vals) {
+                          if (!isGoodForReduction(Vals))
+                            return Num;
+                          return Num + Vals.size();
+                        });
+    if (NumReducedVals < ReductionLimit) {
+      for (ReductionOpsType &RdxOps : ReductionOps)
+        for (Value *RdxOp : RdxOps)
+          V.analyzedReductionRoot(cast<Instruction>(RdxOp));
       return nullptr;
+    }
 
     IRBuilder<> Builder(cast<Instruction>(ReductionRoot));
 
     // Track the reduced values in case if they are replaced by extractelement
     // because of the vectorization.
-    DenseMap<Value *, WeakTrackingVH> TrackedVals;
+    DenseMap<Value *, WeakTrackingVH> TrackedVals(
+        ReducedVals.size() * ReducedVals.front().size() + ExtraArgs.size());
     BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
+    ExternallyUsedValues.reserve(ExtraArgs.size() + 1);
     // The same extra argument may be used several times, so log each attempt
     // to use it.
     for (const std::pair<Instruction *, Value *> &Pair : ExtraArgs) {
@@ -11335,7 +11417,8 @@ public:
     // The reduction root is used as the insertion point for new instructions,
     // so set it as externally used to prevent it from being deleted.
     ExternallyUsedValues[ReductionRoot];
-    SmallDenseSet<Value *> IgnoreList;
+    SmallDenseSet<Value *> IgnoreList(ReductionOps.size() *
+                                      ReductionOps.front().size());
     for (ReductionOpsType &RdxOps : ReductionOps)
       for (Value *RdxOp : RdxOps) {
         if (!RdxOp)
@@ -11350,7 +11433,7 @@ public:
       for (Value *V : Candidates)
         TrackedVals.try_emplace(V, V);
 
-    DenseMap<Value *, unsigned> VectorizedVals;
+    DenseMap<Value *, unsigned> VectorizedVals(ReducedVals.size());
     Value *VectorizedTree = nullptr;
     bool CheckForReusedReductionOps = false;
     // Try to vectorize elements based on their type.
@@ -11358,7 +11441,8 @@ public:
       ArrayRef<Value *> OrigReducedVals = ReducedVals[I];
       InstructionsState S = getSameOpcode(OrigReducedVals);
       SmallVector<Value *> Candidates;
-      DenseMap<Value *, Value *> TrackedToOrig;
+      Candidates.reserve(2 * OrigReducedVals.size());
+      DenseMap<Value *, Value *> TrackedToOrig(2 * OrigReducedVals.size());
       for (unsigned Cnt = 0, Sz = OrigReducedVals.size(); Cnt < Sz; ++Cnt) {
         Value *RdxVal = TrackedVals.find(OrigReducedVals[Cnt])->second;
         // Check if the reduction value was not overriden by the extractelement
@@ -11483,18 +11567,14 @@ public:
                    });
         }
         // Number of uses of the candidates in the vector of values.
-        SmallDenseMap<Value *, unsigned> NumUses;
+        SmallDenseMap<Value *, unsigned> NumUses(Candidates.size());
         for (unsigned Cnt = 0; Cnt < Pos; ++Cnt) {
           Value *V = Candidates[Cnt];
-          if (NumUses.count(V) > 0)
-            continue;
-          NumUses[V] = std::count(VL.begin(), VL.end(), V);
+          ++NumUses.try_emplace(V, 0).first->getSecond();
         }
         for (unsigned Cnt = Pos + ReduxWidth; Cnt < NumReducedVals; ++Cnt) {
           Value *V = Candidates[Cnt];
-          if (NumUses.count(V) > 0)
-            continue;
-          NumUses[V] = std::count(VL.begin(), VL.end(), V);
+          ++NumUses.try_emplace(V, 0).first->getSecond();
         }
         // Gather externally used values.
         SmallPtrSet<Value *, 4> Visited;
@@ -11545,9 +11625,8 @@ public:
         }
         InstructionCost Cost = TreeCost + ReductionCost;
         LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost << " for reduction\n");
-        if (!Cost.isValid()) {
+        if (!Cost.isValid())
           return nullptr;
-        }
         if (Cost >= -SLPCostThreshold) {
           V.getORE()->emit([&]() {
             return OptimizationRemarkMissed(