[AArch64][SVE] Add API for conversion between SVE predicate pattern and element numbe...
authorJun Ma <JunMa@linux.alibaba.com>
Wed, 25 Aug 2021 09:25:39 +0000 (17:25 +0800)
committerJun Ma <JunMa@linux.alibaba.com>
Fri, 27 Aug 2021 12:03:48 +0000 (20:03 +0800)
This patch solely moves convert operation between SVE predicate pattern
and element number into two small functions. It's pre-commit patch for optimize
pture with known sve register width.

Differential Revision: https://reviews.llvm.org/D108705

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h

index 3ca8fcf..f01b14a 100644 (file)
@@ -18022,38 +18022,9 @@ static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
          "Expected legal fixed length vector!");
 
-  int PgPattern;
-  switch (VT.getVectorNumElements()) {
-  default:
-    llvm_unreachable("unexpected element count for SVE predicate");
-  case 1:
-    PgPattern = AArch64SVEPredPattern::vl1;
-    break;
-  case 2:
-    PgPattern = AArch64SVEPredPattern::vl2;
-    break;
-  case 4:
-    PgPattern = AArch64SVEPredPattern::vl4;
-    break;
-  case 8:
-    PgPattern = AArch64SVEPredPattern::vl8;
-    break;
-  case 16:
-    PgPattern = AArch64SVEPredPattern::vl16;
-    break;
-  case 32:
-    PgPattern = AArch64SVEPredPattern::vl32;
-    break;
-  case 64:
-    PgPattern = AArch64SVEPredPattern::vl64;
-    break;
-  case 128:
-    PgPattern = AArch64SVEPredPattern::vl128;
-    break;
-  case 256:
-    PgPattern = AArch64SVEPredPattern::vl256;
-    break;
-  }
+  unsigned PgPattern =
+      getSVEPredPatternFromNumElements(VT.getVectorNumElements());
+  assert(PgPattern && "Unexpected element count for SVE predicate");
 
   // TODO: For vectors that are exactly getMaxSVEVectorSizeInBits big, we can
   // use AArch64SVEPredPattern::all, which can enable the use of unpredicated
index 367c980..ebd44b4 100644 (file)
@@ -593,39 +593,11 @@ static Optional<Instruction *> instCombineSVELast(InstCombiner &IC,
       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
 
   // Can the intrinsic's predicate be converted to a known constant index?
-  unsigned Idx;
-  switch (PTruePattern) {
-  default:
+  unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
+  if (!MinNumElts)
     return None;
-  case AArch64SVEPredPattern::vl1:
-    Idx = 0;
-    break;
-  case AArch64SVEPredPattern::vl2:
-    Idx = 1;
-    break;
-  case AArch64SVEPredPattern::vl3:
-    Idx = 2;
-    break;
-  case AArch64SVEPredPattern::vl4:
-    Idx = 3;
-    break;
-  case AArch64SVEPredPattern::vl5:
-    Idx = 4;
-    break;
-  case AArch64SVEPredPattern::vl6:
-    Idx = 5;
-    break;
-  case AArch64SVEPredPattern::vl7:
-    Idx = 6;
-    break;
-  case AArch64SVEPredPattern::vl8:
-    Idx = 7;
-    break;
-  case AArch64SVEPredPattern::vl16:
-    Idx = 15;
-    break;
-  }
 
+  unsigned Idx = MinNumElts - 1;
   // Increment the index if extracting the element after the last active
   // predicate element.
   if (IsAfter)
@@ -678,26 +650,9 @@ instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
     return IC.replaceInstUsesWith(II, VScale);
   }
 
-  unsigned MinNumElts = 0;
-  switch (Pattern) {
-  default:
-    return None;
-  case AArch64SVEPredPattern::vl1:
-  case AArch64SVEPredPattern::vl2:
-  case AArch64SVEPredPattern::vl3:
-  case AArch64SVEPredPattern::vl4:
-  case AArch64SVEPredPattern::vl5:
-  case AArch64SVEPredPattern::vl6:
-  case AArch64SVEPredPattern::vl7:
-  case AArch64SVEPredPattern::vl8:
-    MinNumElts = Pattern;
-    break;
-  case AArch64SVEPredPattern::vl16:
-    MinNumElts = 16;
-    break;
-  }
+  unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
 
-  return NumElts >= MinNumElts
+  return MinNumElts && NumElts >= MinNumElts
              ? Optional<Instruction *>(IC.replaceInstUsesWith(
                    II, ConstantInt::get(II.getType(), MinNumElts)))
              : None;
index d168c2a..5555c4b 100644 (file)
@@ -454,6 +454,60 @@ namespace AArch64SVEPredPattern {
 #include "AArch64GenSystemOperands.inc"
 }
 
+/// Return the number of active elements for VL1 to VL256 predicate pattern,
+/// zero for all other patterns.
+inline unsigned getNumElementsFromSVEPredPattern(unsigned Pattern) {
+  switch (Pattern) {
+  default:
+    return 0;
+  case AArch64SVEPredPattern::vl1:
+  case AArch64SVEPredPattern::vl2:
+  case AArch64SVEPredPattern::vl3:
+  case AArch64SVEPredPattern::vl4:
+  case AArch64SVEPredPattern::vl5:
+  case AArch64SVEPredPattern::vl6:
+  case AArch64SVEPredPattern::vl7:
+  case AArch64SVEPredPattern::vl8:
+    return Pattern;
+  case AArch64SVEPredPattern::vl16:
+    return 16;
+  case AArch64SVEPredPattern::vl32:
+    return 32;
+  case AArch64SVEPredPattern::vl64:
+    return 64;
+  case AArch64SVEPredPattern::vl128:
+    return 128;
+  case AArch64SVEPredPattern::vl256:
+    return 256;
+  }
+}
+
+/// Return specific VL predicate pattern based on the number of elements.
+inline unsigned getSVEPredPatternFromNumElements(unsigned MinNumElts) {
+  switch (MinNumElts) {
+  default:
+    llvm_unreachable("unexpected element count for SVE predicate");
+  case 1:
+    return AArch64SVEPredPattern::vl1;
+  case 2:
+    return AArch64SVEPredPattern::vl2;
+  case 4:
+    return AArch64SVEPredPattern::vl4;
+  case 8:
+    return AArch64SVEPredPattern::vl8;
+  case 16:
+    return AArch64SVEPredPattern::vl16;
+  case 32:
+    return AArch64SVEPredPattern::vl32;
+  case 64:
+    return AArch64SVEPredPattern::vl64;
+  case 128:
+    return AArch64SVEPredPattern::vl128;
+  case 256:
+    return AArch64SVEPredPattern::vl256;
+  }
+}
+
 namespace AArch64ExactFPImm {
   struct ExactFPImm {
     const char *Name;