[RISCV] Use a ComplexPattern to merge the PatFrags for removing unneeded masks on...
authorCraig Topper <craig.topper@sifive.com>
Fri, 12 Feb 2021 22:01:28 +0000 (14:01 -0800)
committerCraig Topper <craig.topper@sifive.com>
Fri, 12 Feb 2021 22:03:23 +0000 (14:03 -0800)
Rather than having patterns with and without an AND, use a
ComplexPattern to handle both cases.

Reduces the isel table by about 700 bytes.

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
llvm/lib/Target/RISCV/RISCVInstrInfo.td

index 05620fc..57f037f 100644 (file)
@@ -925,19 +925,35 @@ bool RISCVDAGToDAGISel::SelectRVVBaseAddr(SDValue Addr, SDValue &Base) {
   return true;
 }
 
-// Helper to detect unneeded and instructions on shift amounts. Called
-// from PatFrags in tablegen.
-bool RISCVDAGToDAGISel::isUnneededShiftMask(SDNode *N, unsigned Width) const {
-  assert(N->getOpcode() == ISD::AND && "Unexpected opcode");
-  assert(Width >= 5 && N->getValueSizeInBits(0) >= (1ULL << Width) &&
-         "Unexpected width");
-  const APInt &Val = N->getConstantOperandAPInt(1);
-
-  if (Val.countTrailingOnes() >= Width)
-    return true;
+bool RISCVDAGToDAGISel::selectShiftMask(SDValue N, unsigned ShiftWidth,
+                                        SDValue &ShAmt) {
+  // Shift instructions on RISCV only read the lower 5 or 6 bits of the shift
+  // amount. If there is an AND on the shift amount, we can bypass it if it
+  // doesn't affect any of those bits.
+  if (N.getOpcode() == ISD::AND && isa<ConstantSDNode>(N.getOperand(1))) {
+    const APInt &AndMask = N->getConstantOperandAPInt(1);
+
+    // Since the max shift amount is a power of 2 we can subtract 1 to make a
+    // mask that covers the bits needed to represent all shift amounts.
+    assert(isPowerOf2_32(ShiftWidth) && "Unexpected max shift amount!");
+    APInt ShMask(AndMask.getBitWidth(), ShiftWidth - 1);
+
+    if (ShMask.isSubsetOf(AndMask)) {
+      ShAmt = N.getOperand(0);
+      return true;
+    }
+
+    // SimplifyDemandedBits may have optimized the mask so try restoring any
+    // bits that are known zero.
+    KnownBits Known = CurDAG->computeKnownBits(N->getOperand(0));
+    if (ShMask.isSubsetOf(AndMask | Known.Zero)) {
+      ShAmt = N.getOperand(0);
+      return true;
+    }
+  }
 
-  APInt Mask = Val | CurDAG->computeKnownBits(N->getOperand(0)).Zero;
-  return Mask.countTrailingOnes() >= Width;
+  ShAmt = N;
+  return true;
 }
 
 // Match (srl (and val, mask), imm) where the result would be a
index 1e9dba3..0bf5ada 100644 (file)
@@ -46,7 +46,13 @@ public:
   bool SelectAddrFI(SDValue Addr, SDValue &Base);
   bool SelectRVVBaseAddr(SDValue Addr, SDValue &Base);
 
-  bool isUnneededShiftMask(SDNode *N, unsigned Width) const;
+  bool selectShiftMask(SDValue N, unsigned ShiftWidth, SDValue &ShAmt);
+  bool selectShiftMaskXLen(SDValue N, SDValue &ShAmt) {
+    return selectShiftMask(N, Subtarget->getXLen(), ShAmt);
+  }
+  bool selectShiftMask32(SDValue N, SDValue &ShAmt) {
+    return selectShiftMask(N, 32, ShAmt);
+  }
 
   bool MatchSRLIW(SDNode *N) const;
   bool MatchSLLIUW(SDNode *N) const;
index 20aaa82..707cd04 100644 (file)
@@ -895,21 +895,15 @@ def : PatGprUimmLog2XLen<sra, SRAI>;
 // typically introduced when the legalizer promotes the shift amount and
 // zero-extends it). For RISC-V, the mask is unnecessary as shifts in the base
 // ISA only read the least significant 5 bits (RV32I) or 6 bits (RV64I).
-def shiftMaskXLen : PatFrag<(ops node:$lhs), (and node:$lhs, imm), [{
-  return isUnneededShiftMask(N, Subtarget->is64Bit() ? 6 : 5);
-}]>;
-def shiftMask32 : PatFrag<(ops node:$lhs), (and node:$lhs, imm), [{
-  return isUnneededShiftMask(N, 5);
-}]>;
+def shiftMaskXLen : ComplexPattern<XLenVT, 1, "selectShiftMaskXLen", [], [], 0>;
+def shiftMask32   : ComplexPattern<i64, 1, "selectShiftMask32", [], [], 0>;
 
 class shiftop<SDPatternOperator operator>
-    : PatFrags<(ops node:$val, node:$count),
-               [(operator node:$val, node:$count),
-                (operator node:$val, (shiftMaskXLen node:$count))]>;
+    : PatFrag<(ops node:$val, node:$count),
+              (operator node:$val, (XLenVT (shiftMaskXLen node:$count)))>;
 class shiftopw<SDPatternOperator operator>
-    : PatFrags<(ops node:$val, node:$count),
-               [(operator node:$val, node:$count),
-                (operator node:$val, (shiftMask32 node:$count))]>;
+    : PatFrag<(ops node:$val, node:$count),
+              (operator node:$val, (i64 (shiftMask32 node:$count)))>;
 
 def : PatGprGpr<shiftop<shl>, SLL>;
 def : PatGprGpr<shiftop<srl>, SRL>;