[InstCombine] Add more complex folds for extractelement + stepvector
authorDavid Sherwood <david.sherwood@arm.com>
Mon, 19 Jul 2021 14:18:05 +0000 (15:18 +0100)
committerDavid Sherwood <david.sherwood@arm.com>
Tue, 10 Aug 2021 08:17:21 +0000 (09:17 +0100)
I have updated cheapToScalarize to also consider the case when
extracting lanes from a stepvector intrinsic. This required removing
the existing 'bool IsConstantExtractIndex' and passing in the actual
index as a Value instead. We do this because we need to know if the
index is <= known minimum number of elements returned by the stepvector
intrinsic. Effectively, when extracting lane X from a stepvector we
know the value returned is also X.

New tests added here:

  Transforms/InstCombine/vscale_extractelement.ll

Differential Revision: https://reviews.llvm.org/D106358

llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
llvm/test/Transforms/InstCombine/vscale_extractelement.ll

index 32b1537..7a4a80d 100644 (file)
@@ -52,20 +52,29 @@ STATISTIC(NumAggregateReconstructionsSimplified,
           "original aggregate");
 
 /// Return true if the value is cheaper to scalarize than it is to leave as a
-/// vector operation. IsConstantExtractIndex indicates whether we are extracting
-/// one known element from a vector constant.
+/// vector operation. If the extract index \p EI is a constant integer then
+/// some operations may be cheap to scalarize.
 ///
 /// FIXME: It's possible to create more instructions than previously existed.
-static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) {
+static bool cheapToScalarize(Value *V, Value *EI) {
+  ConstantInt *CEI = dyn_cast<ConstantInt>(EI);
+
   // If we can pick a scalar constant value out of a vector, that is free.
   if (auto *C = dyn_cast<Constant>(V))
-    return IsConstantExtractIndex || C->getSplatValue();
+    return CEI || C->getSplatValue();
+
+  if (CEI && match(V, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
+    ElementCount EC = cast<VectorType>(V->getType())->getElementCount();
+    // Index needs to be lower than the minimum size of the vector, because
+    // for scalable vector, the vector size is known at run time.
+    return CEI->getValue().ult(EC.getKnownMinValue());
+  }
 
   // An insertelement to the same constant index as our extract will simplify
   // to the scalar inserted element. An insertelement to a different constant
   // index is irrelevant to our extract.
   if (match(V, m_InsertElt(m_Value(), m_Value(), m_ConstantInt())))
-    return IsConstantExtractIndex;
+    return CEI;
 
   if (match(V, m_OneUse(m_Load(m_Value()))))
     return true;
@@ -75,14 +84,12 @@ static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) {
 
   Value *V0, *V1;
   if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1)))))
-    if (cheapToScalarize(V0, IsConstantExtractIndex) ||
-        cheapToScalarize(V1, IsConstantExtractIndex))
+    if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
 
   CmpInst::Predicate UnusedPred;
   if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1)))))
-    if (cheapToScalarize(V0, IsConstantExtractIndex) ||
-        cheapToScalarize(V1, IsConstantExtractIndex))
+    if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
 
   return false;
@@ -119,7 +126,8 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
   // and that it is a binary operation which is cheap to scalarize.
   // otherwise return nullptr.
   if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) ||
-      !(isa<BinaryOperator>(PHIUser)) || !cheapToScalarize(PHIUser, true))
+      !(isa<BinaryOperator>(PHIUser)) ||
+      !cheapToScalarize(PHIUser, EI.getIndexOperand()))
     return nullptr;
 
   // Create a scalar PHI node that will replace the vector PHI node
@@ -415,7 +423,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   // TODO come up with a n-ary matcher that subsumes both unary and
   // binary matchers.
   UnaryOperator *UO;
-  if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, IndexC)) {
+  if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, Index)) {
     // extelt (unop X), Index --> unop (extelt X, Index)
     Value *X = UO->getOperand(0);
     Value *E = Builder.CreateExtractElement(X, Index);
@@ -423,7 +431,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   }
 
   BinaryOperator *BO;
-  if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, IndexC)) {
+  if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index)) {
     // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index)
     Value *X = BO->getOperand(0), *Y = BO->getOperand(1);
     Value *E0 = Builder.CreateExtractElement(X, Index);
@@ -434,7 +442,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   Value *X, *Y;
   CmpInst::Predicate Pred;
   if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
-      cheapToScalarize(SrcVec, IndexC)) {
+      cheapToScalarize(SrcVec, Index)) {
     // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
     Value *E0 = Builder.CreateExtractElement(X, Index);
     Value *E1 = Builder.CreateExtractElement(Y, Index);
index 302f845..e946b4e 100644 (file)
@@ -243,6 +243,35 @@ entry:
   ret i8 %1
 }
 
+; Check that we can extract more complex cases where the stepvector is
+; involved in a binary operation prior to the lane being extracted.
+
+define i64 @ext_lane0_from_add_with_stepvec(i64 %i) {
+; CHECK-LABEL: @ext_lane0_from_add_with_stepvec(
+; CHECK-NEXT:    ret i64 [[I:%.*]]
+;
+  %tmp = insertelement <vscale x 2 x i64> poison, i64 %i, i32 0
+  %splatofi = shufflevector <vscale x 2 x i64> %tmp, <vscale x 2 x i64> poison, <vscale x  2 x i32> zeroinitializer
+  %stepvec = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %add = add <vscale x 2 x i64> %splatofi, %stepvec
+  %res = extractelement <vscale x 2 x i64> %add, i32 0
+  ret i64 %res
+}
+
+define i1 @ext_lane1_from_cmp_with_stepvec(i64 %i) {
+; CHECK-LABEL: @ext_lane1_from_cmp_with_stepvec(
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i64 [[I:%.*]], 1
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %tmp = insertelement <vscale x 2 x i64> poison, i64 %i, i32 0
+  %splatofi = shufflevector <vscale x 2 x i64> %tmp, <vscale x 2 x i64> poison, <vscale x  2 x i32> zeroinitializer
+  %stepvec = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %cmp = icmp eq <vscale x 2 x i64> %splatofi, %stepvec
+  %res = extractelement <vscale x 2 x i1> %cmp, i32 1
+  ret i1 %res
+}
+
+declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
 declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
 declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
 declare <vscale x 512 x i8> @llvm.experimental.stepvector.nxv512i8()