[TargetLowering] Add DemandedElts arg to ShrinkDemandedConstant
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 29 Jun 2020 10:24:26 +0000 (11:24 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Mon, 29 Jun 2020 10:46:58 +0000 (11:46 +0100)
Pre-commit for D82257, this adds a DemandedElts arg to ShrinkDemandedConstant/targetShrinkDemandedConstant which will allow future patches to (optionally) add vector support.

llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/lib/Target/ARM/ARMISelLowering.h
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/lib/Target/X86/X86ISelLowering.h

index f920d30..bbc507e 100644 (file)
@@ -3254,13 +3254,20 @@ public:
   /// constant integer.  If so, check to see if there are any bits set in the
   /// constant that are not demanded.  If so, shrink the constant and return
   /// true.
-  bool ShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+  bool ShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                              const APInt &DemandedElts,
+                              TargetLoweringOpt &TLO) const;
+
+  /// Helper wrapper around ShrinkDemandedConstant, demanding all elements.
+  bool ShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
                               TargetLoweringOpt &TLO) const;
 
   // Target hook to do target-specific const optimization, which is called by
   // ShrinkDemandedConstant. This function should return true if the target
   // doesn't want ShrinkDemandedConstant to further optimize the constant.
-  virtual bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+  virtual bool targetShrinkDemandedConstant(SDValue Op,
+                                            const APInt &DemandedBits,
+                                            const APInt &DemandedElts,
                                             TargetLoweringOpt &TLO) const {
     return false;
   }
index fab77fb..f6e34e8 100644 (file)
@@ -483,13 +483,15 @@ TargetLowering::isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const {
 /// If the specified instruction has a constant integer operand and there are
 /// bits set in that constant that are not demanded, then clear those bits and
 /// return true.
-bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
+                                            const APInt &DemandedBits,
+                                            const APInt &DemandedElts,
                                             TargetLoweringOpt &TLO) const {
   SDLoc DL(Op);
   unsigned Opcode = Op.getOpcode();
 
   // Do target-specific constant optimization.
-  if (targetShrinkDemandedConstant(Op, Demanded, TLO))
+  if (targetShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
     return TLO.New.getNode();
 
   // FIXME: ISD::SELECT, ISD::SELECT_CC
@@ -505,12 +507,12 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
 
     // If this is a 'not' op, don't touch it because that's a canonical form.
     const APInt &C = Op1C->getAPIntValue();
-    if (Opcode == ISD::XOR && Demanded.isSubsetOf(C))
+    if (Opcode == ISD::XOR && DemandedBits.isSubsetOf(C))
       return false;
 
-    if (!C.isSubsetOf(Demanded)) {
+    if (!C.isSubsetOf(DemandedBits)) {
       EVT VT = Op.getValueType();
-      SDValue NewC = TLO.DAG.getConstant(Demanded & C, DL, VT);
+      SDValue NewC = TLO.DAG.getConstant(DemandedBits & C, DL, VT);
       SDValue NewOp = TLO.DAG.getNode(Opcode, DL, VT, Op.getOperand(0), NewC);
       return TLO.CombineTo(Op, NewOp);
     }
@@ -522,6 +524,16 @@ bool TargetLowering::ShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
   return false;
 }
 
+bool TargetLowering::ShrinkDemandedConstant(SDValue Op,
+                                            const APInt &DemandedBits,
+                                            TargetLoweringOpt &TLO) const {
+  EVT VT = Op.getValueType();
+  APInt DemandedElts = VT.isVector()
+                           ? APInt::getAllOnesValue(VT.getVectorNumElements())
+                           : APInt(1, 1);
+  return ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO);
+}
+
 /// Convert x+y to (VT)((SmallVT)x+(SmallVT)y) if the casts are free.
 /// This uses isZExtFree and ZERO_EXTEND for the widening cast, but it could be
 /// generalized for targets with other types of implicit widening casts.
@@ -1173,7 +1185,8 @@ bool TargetLowering::SimplifyDemandedBits(
 
       // If any of the set bits in the RHS are known zero on the LHS, shrink
       // the constant.
-      if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits, TLO))
+      if (ShrinkDemandedConstant(Op, ~LHSKnown.Zero & DemandedBits,
+                                 DemandedElts, TLO))
         return true;
 
       // Bitwise-not (xor X, -1) is a special case: we don't usually shrink its
@@ -1221,7 +1234,8 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isSubsetOf(Known.Zero | Known2.Zero))
       return TLO.CombineTo(Op, TLO.DAG.getConstant(0, dl, VT));
     // If the RHS is a constant, see if we can simplify it.
-    if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, ~Known2.Zero & DemandedBits, DemandedElts,
+                               TLO))
       return true;
     // If the operation can be done in a smaller type, do so.
     if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
@@ -1264,7 +1278,7 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isSubsetOf(Known.One | Known2.Zero))
       return TLO.CombineTo(Op, Op1);
     // If the RHS is a constant, see if we can simplify it.
-    if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
       return true;
     // If the operation can be done in a smaller type, do so.
     if (ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
@@ -1338,7 +1352,7 @@ bool TargetLowering::SimplifyDemandedBits(
           return TLO.CombineTo(Op, New);
         }
         // If we can't turn this into a 'not', try to shrink the constant.
-        if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
+        if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
           return true;
       }
     }
@@ -1357,7 +1371,7 @@ bool TargetLowering::SimplifyDemandedBits(
     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
 
     // If the operands are constants, see if we can simplify them.
-    if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
       return true;
 
     // Only known if known in both the LHS and RHS.
@@ -1375,7 +1389,7 @@ bool TargetLowering::SimplifyDemandedBits(
     assert(!Known2.hasConflict() && "Bits known to be one AND zero?");
 
     // If the operands are constants, see if we can simplify them.
-    if (ShrinkDemandedConstant(Op, DemandedBits, TLO))
+    if (ShrinkDemandedConstant(Op, DemandedBits, DemandedElts, TLO))
       return true;
 
     // Only known if known in both the LHS and RHS.
index 13863bf..71df3c6 100644 (file)
@@ -1162,7 +1162,8 @@ static bool optimizeLogicalImm(SDValue Op, unsigned Size, uint64_t Imm,
 }
 
 bool AArch64TargetLowering::targetShrinkDemandedConstant(
-    SDValue Op, const APInt &Demanded, TargetLoweringOpt &TLO) const {
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    TargetLoweringOpt &TLO) const {
   // Delay this optimization to as late as possible.
   if (!TLO.LegalOps)
     return false;
@@ -1179,7 +1180,7 @@ bool AArch64TargetLowering::targetShrinkDemandedConstant(
          "i32 or i64 is expected after legalization.");
 
   // Exit early if we demand all bits.
-  if (Demanded.countPopulation() == Size)
+  if (DemandedBits.countPopulation() == Size)
     return false;
 
   unsigned NewOpc;
@@ -1200,7 +1201,7 @@ bool AArch64TargetLowering::targetShrinkDemandedConstant(
   if (!C)
     return false;
   uint64_t Imm = C->getZExtValue();
-  return optimizeLogicalImm(Op, Size, Imm, Demanded, TLO, NewOpc);
+  return optimizeLogicalImm(Op, Size, Imm, DemandedBits, TLO, NewOpc);
 }
 
 /// computeKnownBitsForTargetNode - Determine which of the bits specified in
index 63a9f65..6f7079c 100644 (file)
@@ -403,7 +403,8 @@ public:
     return MVT::getIntegerVT(64);
   }
 
-  bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+  bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                                    const APInt &DemandedElts,
                                     TargetLoweringOpt &TLO) const override;
 
   MVT getScalarShiftAmountTy(const DataLayout &DL, EVT) const override;
index 98b28d1..3b1b704 100644 (file)
@@ -17028,10 +17028,9 @@ void ARMTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
   }
 }
 
-bool
-ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op,
-                                                const APInt &DemandedAPInt,
-                                                TargetLoweringOpt &TLO) const {
+bool ARMTargetLowering::targetShrinkDemandedConstant(
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    TargetLoweringOpt &TLO) const {
   // Delay optimization, so we don't have to deal with illegal types, or block
   // optimizations.
   if (!TLO.LegalOps)
@@ -17056,7 +17055,7 @@ ARMTargetLowering::targetShrinkDemandedConstant(SDValue Op,
 
   unsigned Mask = C->getZExtValue();
 
-  unsigned Demanded = DemandedAPInt.getZExtValue();
+  unsigned Demanded = DemandedBits.getZExtValue();
   unsigned ShrunkMask = Mask & Demanded;
   unsigned ExpandedMask = Mask | ~Demanded;
 
index 38a4da3..8b1f418 100644 (file)
@@ -453,10 +453,10 @@ class VectorType;
                                        const SelectionDAG &DAG,
                                        unsigned Depth) const override;
 
-    bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+    bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                                      const APInt &DemandedElts,
                                       TargetLoweringOpt &TLO) const override;
 
-
     bool ExpandInlineAsm(CallInst *CI) const override;
 
     ConstraintType getConstraintType(StringRef Constraint) const override;
index d943d75..2e570e0 100644 (file)
@@ -33218,7 +33218,8 @@ X86TargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
 
 bool
 X86TargetLowering::targetShrinkDemandedConstant(SDValue Op,
-                                                const APInt &Demanded,
+                                                const APInt &DemandedBits,
+                                                const APInt &DemandedElts,
                                                 TargetLoweringOpt &TLO) const {
   // Only optimize Ands to prevent shrinking a constant that could be
   // matched by movzx.
@@ -33241,7 +33242,7 @@ X86TargetLowering::targetShrinkDemandedConstant(SDValue Op,
   const APInt &Mask = C->getAPIntValue();
 
   // Clear all non-demanded bits initially.
-  APInt ShrunkMask = Mask & Demanded;
+  APInt ShrunkMask = Mask & DemandedBits;
 
   // Find the width of the shrunk mask.
   unsigned Width = ShrunkMask.getActiveBits();
@@ -33265,7 +33266,7 @@ X86TargetLowering::targetShrinkDemandedConstant(SDValue Op,
 
   // Make sure the new mask can be represented by a combination of mask bits
   // and non-demanded bits.
-  if (!ZeroExtendMask.isSubsetOf(Mask | ~Demanded))
+  if (!ZeroExtendMask.isSubsetOf(Mask | ~DemandedBits))
     return false;
 
   // Replace the constant with the zero extend mask.
index 560da44..ad76c55 100644 (file)
@@ -1036,7 +1036,8 @@ namespace llvm {
     EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Context,
                            EVT VT) const override;
 
-    bool targetShrinkDemandedConstant(SDValue Op, const APInt &Demanded,
+    bool targetShrinkDemandedConstant(SDValue Op, const APInt &DemandedBits,
+                                      const APInt &DemandedElts,
                                       TargetLoweringOpt &TLO) const override;
 
     /// Determine which of the bits specified in Mask are known to be either