SLPVectorizer: Add support for vectorizing trees that start at compare instructions.
authorNadav Rotem <nrotem@apple.com>
Mon, 15 Apr 2013 04:25:27 +0000 (04:25 +0000)
committerNadav Rotem <nrotem@apple.com>
Mon, 15 Apr 2013 04:25:27 +0000 (04:25 +0000)
llvm-svn: 179504

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll [new file with mode: 0644]

index d94b2b2..ea33801 100644 (file)
@@ -100,7 +100,7 @@ struct SLPVectorizer : public BasicBlockPass {
     return true;
   }
 
-  bool tryToVectorizeCandidate(BinaryOperator *V,  BoUpSLP &R) {
+  bool tryToVectorize(BinaryOperator *V,  BoUpSLP &R) {
     if (!V) return false;
     // Try to vectorize V.
     if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
@@ -142,25 +142,42 @@ struct SLPVectorizer : public BasicBlockPass {
     bool Changed = false;
     for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; ++it) {
       if (isa<DbgInfoIntrinsic>(it)) continue;
-      PHINode *P = dyn_cast<PHINode>(it);
-      if (!P) return Changed;
-      // Check that the PHI is a reduction PHI.
-      if (P->getNumIncomingValues() != 2) return Changed;
-      Value *Rdx = (P->getIncomingBlock(0) == BB ? P->getIncomingValue(0) :
-                   (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) : 0));
-      // Check if this is a Binary Operator.
-      BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
-      if (!BI) continue;
-
-      Value *Inst = BI->getOperand(0);
-      if (Inst == P) Inst = BI->getOperand(1);
-      Changed |= tryToVectorizeCandidate(dyn_cast<BinaryOperator>(Inst), R);
+
+      // Try to vectorize reductions that use PHINodes.
+      if (PHINode *P = dyn_cast<PHINode>(it)) {
+        // Check that the PHI is a reduction PHI.
+        if (P->getNumIncomingValues() != 2) return Changed;
+        Value *Rdx = (P->getIncomingBlock(0) == BB ? P->getIncomingValue(0) :
+                     (P->getIncomingBlock(1) == BB ? P->getIncomingValue(1) :
+                      0));
+        // Check if this is a Binary Operator.
+        BinaryOperator *BI = dyn_cast_or_null<BinaryOperator>(Rdx);
+        if (!BI)
+          continue;
+
+        Value *Inst = BI->getOperand(0);
+        if (Inst == P) Inst = BI->getOperand(1);
+        Changed |= tryToVectorize(dyn_cast<BinaryOperator>(Inst), R);
+        continue;
+      }
+
+      // Try to vectorize trees that start at compare instructions.
+      if (CmpInst *CI = dyn_cast<CmpInst>(it)) {
+        if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) {
+          Changed |= true;
+          continue;
+        }
+        for (int i = 0; i < 2; ++i)
+          if (BinaryOperator *BI = dyn_cast<BinaryOperator>(CI->getOperand(i)))
+            Changed |= tryToVectorize(BI, R);
+        continue;
+      }
     }
 
     return Changed;
   }
 
-  bool rollStoreChains(BoUpSLP &R) {
+  bool vectorizeStoreChains(BoUpSLP &R) {
     bool Changed = false;
     // Attempt to sort and vectorize each of the store-groups.
     for (StoreListMap::iterator it = StoreRefs.begin(), e = StoreRefs.end();
@@ -192,17 +209,19 @@ struct SLPVectorizer : public BasicBlockPass {
     // he store instructions.
     BoUpSLP R(&BB, SE, DL, TTI, AA);
 
+    // Vectorize trees that end at reductions.
     bool Changed = vectorizeReductions(&BB, R);
 
-    if (!collectStores(&BB, R))
-      return Changed;
+    // Vectorize trees that end at stores.
+    if (collectStores(&BB, R)) {
+      DEBUG(dbgs()<<"SLP: Found stores to vectorize.\n");
+      Changed |= vectorizeStoreChains(R);
+    }
 
-    if (rollStoreChains(R)) {
-      DEBUG(dbgs()<<"SLP: vectorized in \""<<BB.getParent()->getName()<<"\"\n");
+    if (Changed) {
+      DEBUG(dbgs()<<"SLP: vectorized \""<<BB.getParent()->getName()<<"\"\n");
       DEBUG(verifyFunction(*BB.getParent()));
-      Changed |= true;
     }
-
     return Changed;
   }
 
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll b/llvm/test/Transforms/SLPVectorizer/X86/compare-reduce.ll
new file mode 100644 (file)
index 0000000..05f8e61
--- /dev/null
@@ -0,0 +1,53 @@
+; RUN: opt < %s -basicaa -slp-vectorizer -dce -S -mtriple=x86_64-apple-macosx10.8.0 -mcpu=corei7-avx | FileCheck %s
+
+target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64-S128"
+target triple = "x86_64-apple-macosx10.7.0"
+
+@.str = private unnamed_addr constant [6 x i8] c"bingo\00", align 1
+
+;CHECK: @reduce_compare
+;CHECK: load <2 x double>
+;CHECK: fmul <2 x double>
+;CHECK: fmul <2 x double>
+;CHECK: fadd <2 x double>
+;CHECK: extractelement
+;CHECK: extractelement
+;CHECK: ret
+define void @reduce_compare(double* nocapture %A, i32 %n) {
+entry:
+  %conv = sitofp i32 %n to double
+  br label %for.body
+
+for.body:                                         ; preds = %for.inc, %entry
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.inc ]
+  %0 = shl nsw i64 %indvars.iv, 1
+  %arrayidx = getelementptr inbounds double* %A, i64 %0
+  %1 = load double* %arrayidx, align 8
+  %mul1 = fmul double %conv, %1
+  %mul2 = fmul double %mul1, 7.000000e+00
+  %add = fadd double %mul2, 5.000000e+00
+  %2 = or i64 %0, 1
+  %arrayidx6 = getelementptr inbounds double* %A, i64 %2
+  %3 = load double* %arrayidx6, align 8
+  %mul8 = fmul double %conv, %3
+  %mul9 = fmul double %mul8, 4.000000e+00
+  %add10 = fadd double %mul9, 9.000000e+00
+  %cmp11 = fcmp ogt double %add, %add10
+  br i1 %cmp11, label %if.then, label %for.inc
+
+if.then:                                          ; preds = %for.body
+  %call = tail call i32 (i8*, ...)* @printf(i8* getelementptr inbounds ([6 x i8]* @.str, i64 0, i64 0))
+  br label %for.inc
+
+for.inc:                                          ; preds = %for.body, %if.then
+  %indvars.iv.next = add i64 %indvars.iv, 1
+  %lftr.wideiv = trunc i64 %indvars.iv.next to i32
+  %exitcond = icmp eq i32 %lftr.wideiv, 100
+  br i1 %exitcond, label %for.end, label %for.body
+
+for.end:                                          ; preds = %for.inc
+  ret void
+}
+
+declare i32 @printf(i8* nocapture, ...)
+