SLPVectorizer: refactor the code that places extracts. Place the code that decides...
authorNadav Rotem <nrotem@apple.com>
Thu, 11 Jul 2013 04:54:05 +0000 (04:54 +0000)
committerNadav Rotem <nrotem@apple.com>
Thu, 11 Jul 2013 04:54:05 +0000 (04:54 +0000)
llvm-svn: 186058

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

index af6a07b..e25befb 100644 (file)
@@ -246,6 +246,7 @@ public:
     VectorizableTree.clear();
     ScalarToTreeEntry.clear();
     MustGather.clear();
+    ExternalUses.clear();
     MemBarrierIgnoreList.clear();
   }
 
@@ -365,6 +366,23 @@ private:
   /// A list of scalars that we found that we need to keep as scalars.
   ValueSet MustGather;
 
+  /// This POD struct describes one external user in the vectorized tree.
+  struct ExternalUser {
+    ExternalUser (Value *S, llvm::User *U, int L) :
+      Scalar(S), User(U), Lane(L){};
+    // Which scalar in our function.
+    Value *Scalar;
+    // Which user that uses the scalar.
+    llvm::User *User;
+    // Which lane does the scalar belong to.
+    int Lane;
+  };
+  typedef SmallVector<ExternalUser, 16> UserList;
+
+  /// A list of values that need to extracted out of the tree.
+  /// This list holds pairs of (Internal Scalar : External User).
+  UserList ExternalUses;
+
   /// A list of instructions to ignore while sinking
   /// memory instructions. This map must be reset between runs of getCost.
   ValueSet MemBarrierIgnoreList;
@@ -392,6 +410,43 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots) {
   if (!getSameType(Roots))
     return;
   buildTree_rec(Roots, 0);
+
+  // Collect the values that we need to extract from the tree.
+  for (int EIdx = 0, EE = VectorizableTree.size(); EIdx < EE; ++EIdx) {
+    TreeEntry *Entry = &VectorizableTree[EIdx];
+
+    // For each lane:
+    for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) {
+      Value *Scalar = Entry->Scalars[Lane];
+
+      // No need to handle users of gathered values.
+      if (Entry->NeedToGather)
+        continue;
+
+      for (Value::use_iterator User = Scalar->use_begin(),
+           UE = Scalar->use_end(); User != UE; ++User) {
+        DEBUG(dbgs() << "SLP: Checking user:" << **User << ".\n");
+
+        bool Gathered = MustGather.count(*User);
+
+        // Skip in-tree scalars that become vectors.
+        if (ScalarToTreeEntry.count(*User) && !Gathered) {
+          DEBUG(dbgs() << "SLP: \tInternal user will be removed:" <<
+                **User << ".\n");
+          int Idx = ScalarToTreeEntry[*User]; (void) Idx;
+          assert(!VectorizableTree[Idx].NeedToGather && "Bad state");
+          continue;
+        }
+
+        if (!isa<Instruction>(*User))
+          continue;
+
+        DEBUG(dbgs() << "SLP: Need to extract:" << **User << " from lane " <<
+              Lane << " from " << *Scalar << ".\n");
+        ExternalUses.push_back(ExternalUser(Scalar, *User, Lane));
+      }
+    }
+  }
 }
 
 
@@ -843,14 +898,32 @@ int BoUpSLP::getTreeCost() {
   DEBUG(dbgs() << "SLP: Calculating cost for tree of size " <<
         VectorizableTree.size() << ".\n");
 
+  if (!VectorizableTree.size()) {
+    assert(!ExternalUses.size() && "We should not have any external users");
+    return 0;
+  }
+
+  unsigned BundleWidth = VectorizableTree[0].Scalars.size();
+
   for (unsigned i = 0, e = VectorizableTree.size(); i != e; ++i) {
     int C = getEntryCost(&VectorizableTree[i]);
     DEBUG(dbgs() << "SLP: Adding cost " << C << " for bundle that starts with "
           << *VectorizableTree[i].Scalars[0] << " .\n");
     Cost += C;
   }
-  DEBUG(dbgs() << "SLP: Total Cost " << Cost << ".\n");
-  return  Cost;
+
+  int ExtractCost = 0;
+  for (UserList::iterator I = ExternalUses.begin(), E = ExternalUses.end();
+       I != E; ++I) {
+
+    VectorType *VecTy = VectorType::get(I->Scalar->getType(), BundleWidth);
+    ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
+                                           I->Lane);
+  }
+
+
+  DEBUG(dbgs() << "SLP: Total Cost " << Cost + ExtractCost<< ".\n");
+  return  Cost + ExtractCost;
 }
 
 int BoUpSLP::getGatherCost(Type *Ty) {
@@ -1006,8 +1079,26 @@ Value *BoUpSLP::Gather(ArrayRef<Value *> VL, VectorType *Ty) {
   // Generate the 'InsertElement' instruction.
   for (unsigned i = 0; i < Ty->getNumElements(); ++i) {
     Vec = Builder.CreateInsertElement(Vec, VL[i], Builder.getInt32(i));
-    if (Instruction *I = dyn_cast<Instruction>(Vec))
-      GatherSeq.insert(I);
+    if (Instruction *Insrt = dyn_cast<Instruction>(Vec)) {
+      GatherSeq.insert(Insrt);
+
+      // Add to our 'need-to-extract' list.
+      if (ScalarToTreeEntry.count(VL[i])) {
+        int Idx = ScalarToTreeEntry[VL[i]];
+        TreeEntry *E = &VectorizableTree[Idx];
+        // Find which lane we need to extract.
+        int FoundLane = -1;
+        for (unsigned Lane = 0, LE = VL.size(); Lane != LE; ++Lane) {
+          // Is this the lane of the scalar that we are looking for ?
+          if (E->Scalars[Lane] == VL[i]) {
+            FoundLane = Lane;
+            break;
+          }
+        }
+        assert(FoundLane >= 0 && "Could not find the correct lane");
+        ExternalUses.push_back(ExternalUser(VL[i], Insrt, FoundLane));
+      }
+    }
   }
 
   return Vec;
@@ -1222,6 +1313,42 @@ void BoUpSLP::vectorizeTree() {
   Builder.SetInsertPoint(F->getEntryBlock().begin());
   vectorizeTree(&VectorizableTree[0]);
 
+  DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n");
+
+  // Extract all of the elements with the external uses.
+  for (UserList::iterator it = ExternalUses.begin(), e = ExternalUses.end();
+       it != e; ++it) {
+    Value *Scalar = it->Scalar;
+    llvm::User *User = it->User;
+    if (std::find(Scalar->use_begin(), Scalar->use_end(), User) ==
+        Scalar->use_end())
+      continue;
+    assert(ScalarToTreeEntry.count(Scalar) && "Invalid scalar");
+
+    int Idx = ScalarToTreeEntry[Scalar];
+    TreeEntry *E = &VectorizableTree[Idx];
+    assert(!E->NeedToGather && "Extracting from a gather list");
+
+    Value *Vec = E->VectorizedValue;
+    assert(Vec && "Can't find vectorizable value");
+
+    // Generate extracts for out-of-tree users.
+    // Find the insertion point for the extractelement lane.
+    Instruction *Loc = 0;
+    if (PHINode *PN = dyn_cast<PHINode>(Vec)) {
+      Loc = PN->getParent()->getFirstInsertionPt();
+    } else if (Instruction *Iv = dyn_cast<Instruction>(Vec)){
+      Loc = ++((BasicBlock::iterator)*Iv);
+    } else {
+      Loc = F->getEntryBlock().begin();
+    }
+
+    Builder.SetInsertPoint(Loc);
+    Value *Ex = Builder.CreateExtractElement(Vec, Builder.getInt32(it->Lane));
+    User->replaceUsesOfWith(Scalar, Ex);
+    DEBUG(dbgs() << "SLP: Replaced:" << *User << ".\n");
+  }
+
   // For each vectorized value:
   for (int EIdx = 0, EE = VectorizableTree.size(); EIdx < EE; ++EIdx) {
     TreeEntry *Entry = &VectorizableTree[EIdx];
@@ -1237,43 +1364,6 @@ void BoUpSLP::vectorizeTree() {
       Value *Vec = Entry->VectorizedValue;
       assert(Vec && "Can't find vectorizable value");
 
-      SmallVector<User*, 16> Users(Scalar->use_begin(), Scalar->use_end());
-
-      for (SmallVector<User*, 16>::iterator User = Users.begin(),
-           UE = Users.end(); User != UE; ++User) {
-        DEBUG(dbgs() << "SLP: \tupdating user  " << **User << ".\n");
-
-        bool Gathered = MustGather.count(*User);
-
-        // Skip in-tree scalars that become vectors.
-        if (ScalarToTreeEntry.count(*User) && !Gathered) {
-          DEBUG(dbgs() << "SLP: \tUser will be removed soon:" <<
-                **User << ".\n");
-          int Idx = ScalarToTreeEntry[*User]; (void) Idx;
-          assert(!VectorizableTree[Idx].NeedToGather && "bad state ?");
-          continue;
-        }
-
-        if (!isa<Instruction>(*User))
-          continue;
-
-        // Generate extracts for out-of-tree users.
-        // Find the insertion point for the extractelement lane.
-        Instruction *Loc = 0;
-        if (PHINode *PN = dyn_cast<PHINode>(Vec)) {
-          Loc = PN->getParent()->getFirstInsertionPt();
-        } else if (Instruction *Iv = dyn_cast<Instruction>(Vec)){
-          Loc = ++((BasicBlock::iterator)*Iv);
-        } else {
-          Loc = F->getEntryBlock().begin();
-        }
-
-        Builder.SetInsertPoint(Loc);
-        Value *Ex = Builder.CreateExtractElement(Vec, Builder.getInt32(Lane));
-        (*User)->replaceUsesOfWith(Scalar, Ex);
-        DEBUG(dbgs() << "SLP: \tupdated user:" << **User << ".\n");
-      }
-
       Type *Ty = Scalar->getType();
       if (!Ty->isVoidTy()) {
         for (Value::use_iterator User = Scalar->use_begin(), UE = Scalar->use_end();