SLPVectorizer: support slp-vectorization of PHINodes between basic blocks
authorNadav Rotem <nrotem@apple.com>
Tue, 25 Jun 2013 23:04:09 +0000 (23:04 +0000)
committerNadav Rotem <nrotem@apple.com>
Tue, 25 Jun 2013 23:04:09 +0000 (23:04 +0000)
llvm-svn: 184888

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/phi.ll [new file with mode: 0644]

index 88bcd90..9c8244b 100644 (file)
@@ -239,6 +239,10 @@ public:
   /// NOTICE: The vectorization methods also use this set.
   ValueSet MustGather;
 
+  /// Contains PHINodes that are being processed. We use this data structure
+  /// to stop cycles in the graph.
+  ValueSet VisitedPHIs;
+
   /// Contains a list of values that are used outside the current tree. This
   /// set must be reset between runs.
   SetVector<Value *> MultiUserVals;
@@ -457,13 +461,31 @@ void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
 
   // Mark instructions with multiple users.
   for (unsigned i = 0, e = VL.size(); i < e; ++i) {
+    if (PHINode *PN = dyn_cast<PHINode>(VL[i])) {
+      unsigned NumUses = 0;
+      // Check that PHINodes have only one external (non-self) use.
+      for (Value::use_iterator U = VL[i]->use_begin(), UE = VL[i]->use_end();
+           U != UE; ++U) {
+        // Don't count self uses.
+        if (*U == PN)
+          continue;
+        NumUses++;
+      }
+      if (NumUses > 1) {
+        DEBUG(dbgs() << "SLP: Adding PHI to MultiUserVals "
+              "because it has " << NumUses << " users:" << *PN << " \n");
+        MultiUserVals.insert(PN);
+      }
+      continue;
+    }
+
     Instruction *I = dyn_cast<Instruction>(VL[i]);
     // Remember to check if all of the users of this instruction are vectorized
     // within our tree. At depth zero we have no local users, only external
     // users that we don't care about.
     if (Depth && I && I->getNumUses() > 1) {
       DEBUG(dbgs() << "SLP: Adding to MultiUserVals "
-                      "because it has multiple users:" << *I << " \n");
+            "because it has " << I->getNumUses() << " users:" << *I << " \n");
       MultiUserVals.insert(I);
     }
   }
@@ -483,6 +505,24 @@ void FuncSLP::getTreeUses_rec(ArrayRef<Value *> VL, unsigned Depth) {
     return MustGather.insert(VL.begin(), VL.end());
 
   switch (Opcode) {
+  case Instruction::PHI: {
+    PHINode *PH = dyn_cast<PHINode>(VL0);
+
+    // Stop self cycles.
+    if (VisitedPHIs.count(PH))
+        return;
+
+    VisitedPHIs.insert(PH);
+    for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
+      ValueList Operands;
+      // Prepare the operand vector.
+      for (unsigned j = 0; j < VL.size(); ++j)
+        Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
+
+      getTreeUses_rec(Operands, Depth + 1);
+    }
+    return;
+  }
   case Instruction::ExtractElement: {
     VectorType *VecTy = VectorType::get(VL[0]->getType(), VL.size());
     // No need to follow ExtractElements that are going to be optimized away.
@@ -640,6 +680,35 @@ int FuncSLP::getTreeCost_rec(ArrayRef<Value *> VL, unsigned Depth) {
 
   Instruction *VL0 = cast<Instruction>(VL[0]);
   switch (Opcode) {
+  case Instruction::PHI: {
+    PHINode *PH = dyn_cast<PHINode>(VL0);
+
+    // Stop self cycles.
+    if (VisitedPHIs.count(PH))
+        return 0;
+
+    VisitedPHIs.insert(PH);
+    int TotalCost = 0;
+    // Calculate the cost of all of the operands.
+    for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {      
+      ValueList Operands;
+      // Prepare the operand vector.
+      for (unsigned j = 0; j < VL.size(); ++j)
+        Operands.push_back(cast<PHINode>(VL[j])->getIncomingValue(i));
+
+      int Cost = getTreeCost_rec(Operands, Depth + 1);
+      if (Cost == MAX_COST)
+        return MAX_COST;
+      TotalCost += TotalCost;
+    }
+
+    if (TotalCost > GatherCost) {
+      MustGather.insert(VL.begin(), VL.end());
+      return GatherCost;
+    }
+
+    return TotalCost;
+  }
   case Instruction::ExtractElement: {
     if (CanReuseExtract(VL, VL.size(), VecTy))
       return 0;
@@ -806,6 +875,7 @@ int FuncSLP::getTreeCost(ArrayRef<Value *> VL) {
   LaneMap.clear();
   MultiUserVals.clear();
   MustGather.clear();
+  VisitedPHIs.clear();
 
   if (!getSameBlock(VL))
     return MAX_COST;
@@ -990,6 +1060,30 @@ Value *FuncSLP::vectorizeTree_rec(ArrayRef<Value *> VL) {
   assert(Opcode == getSameOpcode(VL) && "Invalid opcode");
 
   switch (Opcode) {
+  case Instruction::PHI: {
+    PHINode *PH = dyn_cast<PHINode>(VL0);
+    Builder.SetInsertPoint(PH->getParent()->getFirstInsertionPt());
+    PHINode *NewPhi = Builder.CreatePHI(VecTy, PH->getNumIncomingValues());
+    VectorizedValues[VL0] = NewPhi;
+
+    for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) {
+      ValueList Operands;
+      BasicBlock *IBB = PH->getIncomingBlock(i);
+
+      // Prepare the operand vector.
+      for (unsigned j = 0; j < VL.size(); ++j)
+        Operands.push_back(cast<PHINode>(VL[j])->getIncomingValueForBlock(IBB));
+
+      Builder.SetInsertPoint(IBB->getTerminator());
+      Value *Vec = vectorizeTree_rec(Operands);
+      NewPhi->addIncoming(Vec, IBB);
+    }
+
+    assert(NewPhi->getNumIncomingValues() == PH->getNumIncomingValues() &&
+           "Invalid number of incoming values");
+    return NewPhi;
+  }
+
   case Instruction::ExtractElement: {
     if (CanReuseExtract(VL, VL.size(), VecTy))
       return VL0->getOperand(0);
@@ -1150,6 +1244,7 @@ Value *FuncSLP::vectorizeTree(ArrayRef<Value *> VL) {
     BlocksNumbers[it].forget();
   // Clear the state.
   MustGather.clear();
+  VisitedPHIs.clear();
   VectorizedValues.clear();
   MemBarrierIgnoreList.clear();
   return V;
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/phi.ll b/llvm/test/Transforms/SLPVectorizer/X86/phi.ll
new file mode 100644 (file)
index 0000000..af0b480
--- /dev/null
@@ -0,0 +1,46 @@
+; RUN: opt < %s -basicaa -slp-vectorizer -dce -S -mtriple=i386-apple-macosx10.8.0 -mcpu=corei7-avx | FileCheck %s
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128-n8:16:32-S128"
+target triple = "i386-apple-macosx10.9.0"
+
+;int foo(double *A, int k) {
+;  double A0;
+;  double A1;
+;  if (k) {
+;    A0 = 3;
+;    A1 = 5;
+;  } else {
+;    A0 = A[10];
+;    A1 = A[11];
+;  }
+;  A[0] = A0;
+;  A[1] = A1;
+;}
+
+
+;CHECK: i32 @foo
+;CHECK: load <2 x double>
+;CHECK: phi <2 x double>
+;CHECK: store <2 x double>
+;CHECK: ret i32 undef
+define i32 @foo(double* nocapture %A, i32 %k) {
+entry:
+  %tobool = icmp eq i32 %k, 0
+  br i1 %tobool, label %if.else, label %if.end
+
+if.else:                                          ; preds = %entry
+  %arrayidx = getelementptr inbounds double* %A, i64 10
+  %0 = load double* %arrayidx, align 8
+  %arrayidx1 = getelementptr inbounds double* %A, i64 11
+  %1 = load double* %arrayidx1, align 8
+  br label %if.end
+
+if.end:                                           ; preds = %entry, %if.else
+  %A0.0 = phi double [ %0, %if.else ], [ 3.000000e+00, %entry ]
+  %A1.0 = phi double [ %1, %if.else ], [ 5.000000e+00, %entry ]
+  store double %A0.0, double* %A, align 8
+  %arrayidx3 = getelementptr inbounds double* %A, i64 1
+  store double %A1.0, double* %arrayidx3, align 8
+  ret i32 undef
+}
+