From 3ea0774b13a538759aa1a68f30130d18ddb0d3f2 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 30 Mar 2020 10:36:21 -0700 Subject: [PATCH] [ConstantFold][NFC] Compile time optimization for large vectors Optimize the common case of a splat vector constant. For large vectors, going through all elements is expensive. For splat/broadcast cases we can skip going through all elements. Differential Revision: https://reviews.llvm.org/D76664 --- llvm/include/llvm/IR/Constants.h | 7 +++- llvm/lib/Analysis/ValueTracking.cpp | 8 +++- llvm/lib/IR/ConstantFold.cpp | 44 +++++++++++++++++++++- llvm/lib/IR/Constants.cpp | 10 ++++- llvm/lib/IR/Instructions.cpp | 6 ++- .../InstCombine/InstCombineSimplifyDemanded.cpp | 18 +++++++++ 6 files changed, 87 insertions(+), 6 deletions(-) diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h index d0ea1ad..a345795 100644 --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -766,7 +766,12 @@ class ConstantDataVector final : public ConstantDataSequential { friend class ConstantDataSequential; explicit ConstantDataVector(Type *ty, const char *Data) - : ConstantDataSequential(ty, ConstantDataVectorVal, Data) {} + : ConstantDataSequential(ty, ConstantDataVectorVal, Data), + IsSplatSet(false) {} + // Cache whether or not the constant is a splat. 
+ mutable bool IsSplatSet : 1; + mutable bool IsSplat : 1; + bool isSplatData() const; public: ConstantDataVector(const ConstantDataVector &) = delete; diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index ab6afa1..0075130 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -174,7 +174,13 @@ static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf, int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements(); int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements(); DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts); - + if (DemandedElts.isNullValue()) + return true; + // Simple case of a shuffle with zeroinitializer. + if (isa(Shuf->getMask())) { + DemandedLHS.setBit(0); + return true; + } for (int i = 0; i != NumMaskElts; ++i) { if (!DemandedElts[i]) continue; diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index 3e2e74c..07292e5 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -60,6 +60,11 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) { return nullptr; Type *DstEltTy = DstTy->getElementType(); + // Fast path for splatted constants. + if (Constant *Splat = CV->getSplatValue()) { + return ConstantVector::getSplat(DstTy->getVectorElementCount(), + ConstantExpr::getBitCast(Splat, DstEltTy)); + } SmallVector Result; Type *Ty = IntegerType::get(CV->getContext(), 32); @@ -577,9 +582,15 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, if ((isa(V) || isa(V)) && DestTy->isVectorTy() && DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) { - SmallVector res; VectorType *DestVecTy = cast(DestTy); Type *DstEltTy = DestVecTy->getElementType(); + // Fast path for splatted constants. 
+ if (Constant *Splat = V->getSplatValue()) { + return ConstantVector::getSplat( + DestTy->getVectorElementCount(), + ConstantExpr::getCast(opc, Splat, DstEltTy)); + } + SmallVector res; Type *Ty = IntegerType::get(V->getContext(), 32); for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) { Constant *C = @@ -878,6 +889,14 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1, // Don't break the bitcode reader hack. if (isa(Mask)) return nullptr; + // If the mask is all zeros this is a splat, no need to go through all + // elements. + if (isa(Mask) && !MaskEltCount.Scalable) { + Type *Ty = IntegerType::get(V1->getContext(), 32); + Constant *Elt = + ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0)); + return ConstantVector::getSplat(MaskEltCount, Elt); + } // Do not iterate on scalable vector. The num of elements is unknown at // compile-time. VectorType *ValTy = cast(V1->getType()); @@ -993,10 +1012,15 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) { // compile-time. if (IsScalableVector) return nullptr; + Type *Ty = IntegerType::get(VTy->getContext(), 32); + // Fast path for splatted constants. + if (Constant *Splat = C->getSplatValue()) { + Constant *Elt = ConstantExpr::get(Opcode, Splat); + return ConstantVector::getSplat(VTy->getElementCount(), Elt); + } // Fold each element and create a vector constant from those constants. SmallVector Result; - Type *Ty = IntegerType::get(VTy->getContext(), 32); for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *ExtractIdx = ConstantInt::get(Ty, i); Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx); @@ -1357,6 +1381,16 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, // compile-time. if (IsScalableVector) return nullptr; + // Fast path for splatted constants. 
+ if (Constant *C2Splat = C2->getSplatValue()) { + if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) + return UndefValue::get(VTy); + if (Constant *C1Splat = C1->getSplatValue()) { + return ConstantVector::getSplat( + VTy->getVectorElementCount(), + ConstantExpr::get(Opcode, C1Splat, C2Splat)); + } + } // Fold each element and create a vector constant from those constants. SmallVector Result; @@ -1975,6 +2009,12 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred, // compile-time. if (C1->getType()->getVectorIsScalable()) return nullptr; + // Fast path for splatted constants. + if (Constant *C1Splat = C1->getSplatValue()) + if (Constant *C2Splat = C2->getSplatValue()) + return ConstantVector::getSplat( + C1->getType()->getVectorElementCount(), + ConstantExpr::getCompare(pred, C1Splat, C2Splat)); // If we can constant fold the comparison of each element, constant fold // the whole vector comparison. diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index bde4c07..e001b52 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2891,7 +2891,7 @@ bool ConstantDataSequential::isCString() const { return Str.drop_back().find(0) == StringRef::npos; } -bool ConstantDataVector::isSplat() const { +bool ConstantDataVector::isSplatData() const { const char *Base = getRawDataValues().data(); // Compare elements 1+ to the 0'th element. @@ -2903,6 +2903,14 @@ bool ConstantDataVector::isSplat() const { return true; } +bool ConstantDataVector::isSplat() const { + if (!IsSplatSet) { + IsSplatSet = true; + IsSplat = isSplatData(); + } + return IsSplat; +} + Constant *ConstantDataVector::getSplatValue() const { // If they're all the same, return the 0th one as a representative. return isSplat() ? 
getElementAsConstant(0) : nullptr; diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index 0884a24..f6748c8 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -1958,7 +1958,11 @@ void ShuffleVectorInst::getShuffleMask(const Constant *Mask, assert(!Mask->getType()->getVectorElementCount().Scalable && "Length of scalable vectors unknown at compile time"); unsigned NumElts = Mask->getType()->getVectorNumElements(); - + if (isa(Mask)) { + Result.resize(NumElts, 0); + return; + } + Result.reserve(NumElts); if (auto *CDS = dyn_cast(Mask)) { for (unsigned i = 0; i != NumElts; ++i) Result.push_back(CDS->getElementAsInteger(i)); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index be5135e..90b0053 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -1387,6 +1387,24 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, "Expected shuffle operands to have same type"); unsigned OpWidth = Shuffle->getOperand(0)->getType()->getVectorNumElements(); + // Handle trivial case of a splat. Only check the first element of LHS + // operand. + if (isa(Shuffle->getMask()) && + DemandedElts.isAllOnesValue()) { + if (!isa(I->getOperand(1))) { + I->setOperand(1, UndefValue::get(I->getOperand(1)->getType())); + MadeChange = true; + } + APInt LeftDemanded(OpWidth, 1); + APInt LHSUndefElts(OpWidth, 0); + simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); + if (LHSUndefElts[0]) + UndefElts = EltMask; + else + UndefElts.clearAllBits(); + break; + } + APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); for (unsigned i = 0; i < VWidth; i++) { if (DemandedElts[i]) { -- 2.7.4