From 973685fc784a937d39293be972e95c2c4ec4c97e Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Mon, 29 Jun 2020 11:24:26 +0100 Subject: [PATCH] [TargetLowering] Add DemandedElts arg to ShrinkDemandedConstant Pre-commit for D82257, this adds a DemandedElts arg to ShrinkDemandedConstant/targetShrinkDemandedConstant which will allow future patches to (optionally) add vector support. --- llvm/include/llvm/CodeGen/TargetLowering.h | 11 ++++++-- llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp | 36 ++++++++++++++++-------- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 7 +++-- llvm/lib/Target/AArch64/AArch64ISelLowering.h | 3 +- llvm/lib/Target/ARM/ARMISelLowering.cpp | 9 +++--- llvm/lib/Target/ARM/ARMISelLowering.h | 4 +-- llvm/lib/Target/X86/X86ISelLowering.cpp | 7 +++-- llvm/lib/Target/X86/X86ISelLowering.h | 3 +- 8 files changed, 52 insertions(+), 28 deletions(-) diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index f920d30..bbc507e 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3254,13 +3254,20 @@ public: /// constant integer. If so, check to see if there are any bits set in the /// constant that are not demanded. If so, shrink the constant and return /// true. - bool ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, + bool ShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits, + const APInt &DemandedElts, + TargetLoweringOpt &TLO) const; + + /// Helper wrapper around ShrinkDemandedConstant, demanding all elements. + bool ShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits, TargetLoweringOpt &TLO) const; // Target hook to do target-specific const optimization, which is called by // ShrinkDemandedConstant. This function should return true if the target // doesn't want ShrinkDemandedConstant to further optimize the constant. - virtual bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded, + virtual bool targetShrinkDemandedConstant(SDValue Op, + const APInt &DemandedBits, + const APInt &DemandedElts, TargetLoweringOpt &TLO) const { return false; } diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp index fab77fb..f6e34e8 100644 --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -483,13 +483,15 @@ TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const { /// If the specified instruction has a constant integer operand and there are /// bits set in that constant that are not demanded, then clear those bits and /// return true. -bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, +bool TargetLowering::ShrinkDemandedConstant(SDValue Op, + const APInt &DemandedBits, + const APInt &DemandedElts, TargetLoweringOpt &TLO) const { SDLoc DL(Op); unsigned Opcode = Op.getOpcode(); // Do target-specific constant optimization. - if (targetShrinkDemandedConstant(Op, Demanded, TLO)) + if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return TLO.New.getNode(); // FIXME: ISD::SELECT, ISD::SELECT_CC @@ -505,12 +507,12 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, // If this is a 'not' op, don't touch it because that's a canonical form. const APInt &C = Op1C->getAPIntValue(); - if (Opcode == ISD::XOR && Demanded.isSubsetOf(C)) + if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C)) return false; - if (!C.isSubsetOf(Demanded)) { + if (!C.isSubsetOf(DemandedBits)) { EVT VT = Op.getValueType(); - SDValue NewC = TLO.DAG.getConstant(Demanded & C, DL, VT); + SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT); SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC); return TLO.CombineTo(Op, NewOp); } @@ -522,6 +524,16 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded, return false; } +bool TargetLowering::ShrinkDemandedConstant(SDValue Op, + const APInt &DemandedBits, + TargetLoweringOpt &TLO) const { + EVT VT = Op.getValueType(); + APInt DemandedElts = VT.isVector() + ? APInt::getAllOnesValue(VT.getVectorNumElements()) + : APInt(1, 1); + return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO); +} + /// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free. /// This uses isZExtFree and ZERO_EXTEND for the widening cast, but it could be /// generalized for targets with other types of implicit widening casts. @@ -1173,7 +1185,8 @@ bool TargetLowering::SimplifyDemandedBits( // If any of the set bits in the RHS are known zero on the LHS, shrink // the constant. - if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits, + DemandedElts, TLO)) return true; // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its @@ -1221,7 +1234,8 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero)) return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT)); // If the RHS is a constant, see if we can simplify it. - if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts, + TLO)) return true; // If the operation can be done in a smaller type, do so. if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) @@ -1264,7 +1278,7 @@ bool TargetLowering::SimplifyDemandedBits( if (DemandedBits.isSubsetOf(Known.One | Known2.Zero)) return TLO.CombineTo(Op, Op1); // If the RHS is a constant, see if we can simplify it. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return true; // If the operation can be done in a smaller type, do so. if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO)) @@ -1338,7 +1352,7 @@ bool TargetLowering::SimplifyDemandedBits( return TLO.CombineTo(Op, New); } // If we can't turn this into a 'not', try to shrink the constant. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return true; } } @@ -1357,7 +1371,7 @@ bool TargetLowering::SimplifyDemandedBits( assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return true; // Only known if known in both the LHS and RHS. @@ -1375,7 +1389,7 @@ bool TargetLowering::SimplifyDemandedBits( assert(!Known2.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. - if (ShrinkDemandedConstant(Op, DemandedBits, TLO)) + if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO)) return true; // Only known if known in both the LHS and RHS. diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 13863bf..71df3c6 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1162,7 +1162,8 @@ static bool optimizeLogicalImm(SDValue Op, unsigned Size, uint64_t Imm, } bool AArch64TargetLowering::targetShrinkDemandedConstant( - SDValue Op, const APInt &Demanded, TargetLoweringOpt &TLO) const { + SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, + TargetLoweringOpt &TLO) const { // Delay this optimization to as late as possible. if (!TLO.LegalOps) return false; @@ -1179,7 +1180,7 @@ bool AArch64TargetLowering::targetShrinkDemandedConstant( "i32 or i64 is expected after legalization."); // Exit early if we demand all bits. - if (Demanded.countPopulation() == Size) + if (DemandedBits.countPopulation() == Size) return false; unsigned NewOpc; @@ -1200,7 +1201,7 @@ bool AArch64TargetLowering::targetShrinkDemandedConstant( if (!C) return false; uint64_t Imm = C->getZExtValue(); - return optimizeLogicalImm(Op, Size, Imm, Demanded, TLO, NewOpc); + return optimizeLogicalImm(Op, Size, Imm, DemandedBits, TLO, NewOpc); } /// computeKnownBitsForTargetNode - Determine which of the bits specified in diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 63a9f65..6f7079c 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -403,7 +403,8 @@ public: return MVT::getIntegerVT(64); } - bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded, + bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits, + const APInt &DemandedElts, TargetLoweringOpt &TLO) const override; MVT getScalarShiftAmountTy(const DataLayout &DL, EVT) const override; diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 98b28d1..3b1b704 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -17028,10 +17028,9 @@ void ARMTargetLowering::computeKnownBitsForTargetNode(const SDValue Op, } } -bool -ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op, - const APInt &DemandedAPInt, - TargetLoweringOpt &TLO) const { +bool ARMTargetLowering::targetShrinkDemandedConstant( + SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts, + TargetLoweringOpt &TLO) const { // Delay optimization, so we don't have to deal with illegal types, or block // optimizations. if (!TLO.LegalOps) @@ -17056,7 +17055,7 @@ ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op, unsigned Mask = C->getZExtValue(); - unsigned Demanded = DemandedAPInt.getZExtValue(); + unsigned Demanded = DemandedBits.getZExtValue(); unsigned ShrunkMask = Mask & Demanded; unsigned ExpandedMask = Mask | ~Demanded; diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h index 38a4da3..8b1f418 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.h +++ b/llvm/lib/Target/ARM/ARMISelLowering.h @@ -453,10 +453,10 @@ class VectorType; const SelectionDAG &DAG, unsigned Depth) const override; - bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded, + bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits, + const APInt &DemandedElts, TargetLoweringOpt &TLO) const override; - bool ExpandInlineAsm(CallInst *CI) const override; ConstraintType getConstraintType(StringRef Constraint) const override; diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index d943d75..2e570e0 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -33218,7 +33218,8 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI, bool X86TargetLowering::targetShrinkDemandedConstant(SDValue Op, - const APInt &Demanded, + const APInt &DemandedBits, + const APInt &DemandedElts, TargetLoweringOpt &TLO) const { // Only optimize Ands to prevent shrinking a constant that could be // matched by movzx. @@ -33241,7 +33242,7 @@ X86TargetLowering::targetShrinkDemandedConstant(SDValue Op, const APInt &Mask = C->getAPIntValue(); // Clear all non-demanded bits initially. - APInt ShrunkMask = Mask & Demanded; + APInt ShrunkMask = Mask & DemandedBits; // Find the width of the shrunk mask. unsigned Width = ShrunkMask.getActiveBits(); @@ -33265,7 +33266,7 @@ X86TargetLowering::targetShrinkDemandedConstant(SDValue Op, // Make sure the new mask can be represented by a combination of mask bits // and non-demanded bits. - if (!ZeroExtendMask.isSubsetOf(Mask | ~Demanded)) + if (!ZeroExtendMask.isSubsetOf(Mask | ~DemandedBits)) return false; // Replace the constant with the zero extend mask. diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h index 560da44..ad76c55 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -1036,7 +1036,8 @@ namespace llvm { EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Context, EVT VT) const override; - bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded, + bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits, + const APInt &DemandedElts, TargetLoweringOpt &TLO) const override; /// Determine which of the bits specified in Mask are known to be either -- 2.7.4