[LV] Create RT checks once VF/IC are selected, track scalar cost.
authorFlorian Hahn <flo@fhahn.com>
Fri, 24 Jun 2022 15:42:11 +0000 (17:42 +0200)
committerFlorian Hahn <flo@fhahn.com>
Fri, 24 Jun 2022 15:42:11 +0000 (17:42 +0200)
This patch updates LV to generate runtime after the VF & IC are selected. It
allows deciding whether to vectorize with runtime checks or not based on
their cost compared to the vector loop.

It also updates VectorizationFactor to include the scalar cost.

Reviewed By: lebedev.ri, dmgreen

Differential Revision: https://reviews.llvm.org/D75981

llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

index 9cbbf96..342727d 100644 (file)
@@ -188,12 +188,16 @@ struct VectorizationFactor {
   /// Cost of the loop with that width.
   InstructionCost Cost;
 
-  VectorizationFactor(ElementCount Width, InstructionCost Cost)
-      : Width(Width), Cost(Cost) {}
+  /// Cost of the scalar loop.
+  InstructionCost ScalarCost;
+
+  VectorizationFactor(ElementCount Width, InstructionCost Cost,
+                      InstructionCost ScalarCost)
+      : Width(Width), Cost(Cost), ScalarCost(ScalarCost) {}
 
   /// Width 1 means no vectorization, cost 0 means uncomputed cost.
   static VectorizationFactor Disabled() {
-    return {ElementCount::getFixed(1), 0};
+    return {ElementCount::getFixed(1), 0, 0};
   }
 
   bool operator==(const VectorizationFactor &rhs) const {
index ccdc17f..5e7a762 100644 (file)
@@ -5298,7 +5298,8 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(
   assert(VFCandidates.count(ElementCount::getFixed(1)) &&
          "Expected Scalar VF to be a candidate");
 
-  const VectorizationFactor ScalarCost(ElementCount::getFixed(1), ExpectedCost);
+  const VectorizationFactor ScalarCost(ElementCount::getFixed(1), ExpectedCost,
+                                       ExpectedCost);
   VectorizationFactor ChosenFactor = ScalarCost;
 
   bool ForceVectorization = Hints->getForce() == LoopVectorizeHints::FK_Enabled;
@@ -5316,7 +5317,7 @@ VectorizationFactor LoopVectorizationCostModel::selectVectorizationFactor(
       continue;
 
     VectorizationCostTy C = expectedCost(i, &InvalidCosts);
-    VectorizationFactor Candidate(i, C.first);
+    VectorizationFactor Candidate(i, C.first, ScalarCost.ScalarCost);
 
 #ifndef NDEBUG
     unsigned AssumedMinimumVscale = 1;
@@ -5509,7 +5510,7 @@ LoopVectorizationCostModel::selectEpilogueVectorizationFactor(
     LLVM_DEBUG(dbgs() << "LEV: Epilogue vectorization factor is forced.\n";);
     ElementCount ForcedEC = ElementCount::getFixed(EpilogueVectorizationForceVF);
     if (LVP.hasPlanWithVF(ForcedEC))
-      return {ForcedEC, 0};
+      return {ForcedEC, 0, 0};
     else {
       LLVM_DEBUG(
           dbgs()
@@ -7432,7 +7433,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
     if (VPlanBuildStressTest)
       return VectorizationFactor::Disabled();
 
-    return {VF, 0 /*Cost*/};
+    return {VF, 0 /*Cost*/, 0 /* ScalarCost */};
   }
 
   LLVM_DEBUG(
@@ -7483,7 +7484,7 @@ LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
       CM.collectInLoopReductions();
       buildVPlansWithVPRecipes(UserVF, UserVF);
       LLVM_DEBUG(printPlans(dbgs()));
-      return {{UserVF, 0}};
+      return {{UserVF, 0, 0}};
     } else
       reportVectorizationInfo("UserVF ignored because of invalid costs.",
                               "InvalidCost", ORE, OrigLoop);
@@ -10411,6 +10412,8 @@ bool LoopVectorizePass::processLoop(Loop *L) {
   VectorizationFactor VF = VectorizationFactor::Disabled();
   unsigned IC = 1;
 
+  GeneratedRTChecks Checks(*PSE.getSE(), DT, LI,
+                           F->getParent()->getDataLayout());
   if (MaybeVF) {
     if (LVP.requiresTooManyRuntimeChecks()) {
       ORE->emit([&]() {
@@ -10427,6 +10430,12 @@ bool LoopVectorizePass::processLoop(Loop *L) {
     VF = *MaybeVF;
     // Select the interleave count.
     IC = CM.selectInterleaveCount(VF.Width, *VF.Cost.getValue());
+
+    unsigned SelectedIC = std::max(IC, UserIC);
+    //  Optimistically generate runtime checks if they are needed. Drop them if
+    //  they turn out to not be profitable.
+    if (VF.Width.isVector() || SelectedIC > 1)
+      Checks.Create(L, *LVL.getLAI(), PSE.getPredicate(), VF.Width, SelectedIC);
   }
 
   // Identify the diagnostic messages that should be produced.
@@ -10514,14 +10523,6 @@ bool LoopVectorizePass::processLoop(Loop *L) {
   bool DisableRuntimeUnroll = false;
   MDNode *OrigLoopID = L->getLoopID();
   {
-    // Optimistically generate runtime checks. Drop them if they turn out to not
-    // be profitable. Limit the scope of Checks, so the cleanup happens
-    // immediately after vector codegeneration is done.
-    GeneratedRTChecks Checks(*PSE.getSE(), DT, LI,
-                             F->getParent()->getDataLayout());
-    if (!VF.Width.isScalar() || IC > 1)
-      Checks.Create(L, *LVL.getLAI(), PSE.getPredicate(), VF.Width, IC);
-
     using namespace ore;
     if (!VectorizeLoop) {
       assert(IC > 1 && "interleave count should not be 1 or 0");