[ConstantFolding] refactor helper for vector reductions; NFC
authorSanjay Patel <spatel@rotateright.com>
Thu, 29 Apr 2021 16:07:51 +0000 (12:07 -0400)
committerSanjay Patel <spatel@rotateright.com>
Thu, 29 Apr 2021 16:09:22 +0000 (12:09 -0400)
We should handle other cases (undef/poison), so reduce
the duplication of repeated switches.

llvm/lib/Analysis/ConstantFolding.cpp

index dc55930..cfc1fc4 100644 (file)
@@ -1702,19 +1702,29 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), double V,
   return GetConstantFoldFPValue(V, Ty);
 }
 
-Constant *ConstantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
+Constant *constantFoldVectorReduce(Intrinsic::ID IID, Constant *Op) {
   FixedVectorType *VT = dyn_cast<FixedVectorType>(Op->getType());
   if (!VT)
     return nullptr;
-  ConstantInt *CI = dyn_cast<ConstantInt>(Op->getAggregateElement(0U));
-  if (!CI)
+
+  // This isn't strictly necessary, but handle the special/common case of zero:
+  // all integer reductions of a zero input produce zero.
+  if (isa<ConstantAggregateZero>(Op))
+    return ConstantInt::get(VT->getElementType(), 0);
+
+  // TODO: Handle undef and poison.
+  if (!isa<ConstantVector>(Op) && !isa<ConstantDataVector>(Op))
     return nullptr;
-  APInt Acc = CI->getValue();
 
-  for (unsigned I = 1; I < VT->getNumElements(); I++) {
-    if (!(CI = dyn_cast<ConstantInt>(Op->getAggregateElement(I))))
+  auto *EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(0U));
+  if (!EltC)
+    return nullptr;
+
+  APInt Acc = EltC->getValue();
+  for (unsigned I = 1, E = VT->getNumElements(); I != E; I++) {
+    if (!(EltC = dyn_cast<ConstantInt>(Op->getAggregateElement(I))))
       return nullptr;
-    const APInt &X = CI->getValue();
+    const APInt &X = EltC->getValue();
     switch (IID) {
     case Intrinsic::vector_reduce_add:
       Acc = Acc + X;
@@ -2241,20 +2251,20 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
     }
   }
 
-  if (isa<ConstantAggregateZero>(Operands[0])) {
-    switch (IntrinsicID) {
-    default: break;
-    case Intrinsic::vector_reduce_add:
-    case Intrinsic::vector_reduce_mul:
-    case Intrinsic::vector_reduce_and:
-    case Intrinsic::vector_reduce_or:
-    case Intrinsic::vector_reduce_xor:
-    case Intrinsic::vector_reduce_smin:
-    case Intrinsic::vector_reduce_smax:
-    case Intrinsic::vector_reduce_umin:
-    case Intrinsic::vector_reduce_umax:
-      return ConstantInt::get(Ty, 0);
-    }
+  switch (IntrinsicID) {
+  default: break;
+  case Intrinsic::vector_reduce_add:
+  case Intrinsic::vector_reduce_mul:
+  case Intrinsic::vector_reduce_and:
+  case Intrinsic::vector_reduce_or:
+  case Intrinsic::vector_reduce_xor:
+  case Intrinsic::vector_reduce_smin:
+  case Intrinsic::vector_reduce_smax:
+  case Intrinsic::vector_reduce_umin:
+  case Intrinsic::vector_reduce_umax:
+    if (Constant *C = constantFoldVectorReduce(IntrinsicID, Operands[0]))
+      return C;
+    break;
   }
 
   // Support ConstantVector in case we have an Undef in the top.
@@ -2263,18 +2273,6 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
     auto *Op = cast<Constant>(Operands[0]);
     switch (IntrinsicID) {
     default: break;
-    case Intrinsic::vector_reduce_add:
-    case Intrinsic::vector_reduce_mul:
-    case Intrinsic::vector_reduce_and:
-    case Intrinsic::vector_reduce_or:
-    case Intrinsic::vector_reduce_xor:
-    case Intrinsic::vector_reduce_smin:
-    case Intrinsic::vector_reduce_smax:
-    case Intrinsic::vector_reduce_umin:
-    case Intrinsic::vector_reduce_umax:
-      if (Constant *C = ConstantFoldVectorReduce(IntrinsicID, Op))
-        return C;
-      break;
     case Intrinsic::x86_sse_cvtss2si:
     case Intrinsic::x86_sse_cvtss2si64:
     case Intrinsic::x86_sse2_cvtsd2si: