[RISCV] Relax another one use restriction in performSRACombine.
authorCraig Topper <craig.topper@sifive.com>
Thu, 4 Aug 2022 21:29:14 +0000 (14:29 -0700)
committerCraig Topper <craig.topper@sifive.com>
Thu, 4 Aug 2022 21:32:31 +0000 (14:32 -0700)
When folding (sra (add (shl X, 32), C1), 32 - C) -> (shl (sext_inreg (add X, C1), i32), C)
it's possible that the add is used by multiple sras. We should
allow the combine if all the SRAs will eventually be updated.

After transforming all of the sras, the shls will share a single
(sext_inreg (add X, C1), i32).

This pattern occurs if an sra with 32 is used as index in multiple
GEPs with different scales. The shl from the GEPs will be combined
with the sra before we get a chance to match the sra pattern.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/test/CodeGen/RISCV/rv64i-shift-sext.ll

index 227f2a1..fca42d7 100644 (file)
@@ -8874,8 +8874,6 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
   // We might have an ADD or SUB between the SRA and SHL.
   bool IsAdd = N0.getOpcode() == ISD::ADD;
   if ((IsAdd || N0.getOpcode() == ISD::SUB)) {
-    if (!N0.hasOneUse())
-      return SDValue();
     // Other operand needs to be a constant we can modify.
     AddC = dyn_cast<ConstantSDNode>(N0.getOperand(IsAdd ? 1 : 0));
     if (!AddC)
@@ -8885,6 +8883,16 @@ static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
     if (AddC->getAPIntValue().countTrailingZeros() < 32)
       return SDValue();
 
+    // All users should be a shift by constant less than or equal to 32. This
+    // ensures we'll do this optimization for each of them to produce an
+    // add/sub+sext_inreg they can all share.
+    for (SDNode *U : N0->uses()) {
+      if (U->getOpcode() != ISD::SRA ||
+          !isa<ConstantSDNode>(U->getOperand(1)) ||
+          cast<ConstantSDNode>(U->getOperand(1))->getZExtValue() > 32)
+        return SDValue();
+    }
+
     Shl = N0.getOperand(IsAdd ? 0 : 1);
   } else {
     // Not an ADD or SUB.
index ad1df83..55620af 100644 (file)
@@ -196,3 +196,27 @@ define i8 @test13(i8* %0, i64 %1) {
   %12 = add i8 %7, %11
   ret i8 %12
 }
+
+define signext i32 @test14(i8* %0, i32* %1, i64 %2) {
+; RV64I-LABEL: test14:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    li a3, 1
+; RV64I-NEXT:    subw a2, a3, a2
+; RV64I-NEXT:    add a0, a0, a2
+; RV64I-NEXT:    lbu a0, 0(a0)
+; RV64I-NEXT:    slli a2, a2, 2
+; RV64I-NEXT:    add a1, a1, a2
+; RV64I-NEXT:    lw a1, 0(a1)
+; RV64I-NEXT:    addw a0, a0, a1
+; RV64I-NEXT:    ret
+  %4 = mul i64 %2, -4294967296
+  %5 = add i64 %4, 4294967296 ; 1 << 32
+  %6 = ashr exact i64 %5, 32
+  %7 = getelementptr inbounds i8, i8* %0, i64 %6
+  %8 = load i8, i8* %7, align 4
+  %9 = zext i8 %8 to i32
+  %10 = getelementptr inbounds i32, i32* %1, i64 %6
+  %11 = load i32, i32* %10, align 4
+  %12 = add i32 %9, %11
+  ret i32 %12
+}