From 4830fa18aac6950d479a413c995c38fff56ac42c Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Sat, 22 Oct 2022 14:25:17 -0700
Subject: [PATCH] [RISCV] Make sure we always call tryShrinkShlLogicImm for
 ISD::AND during isel.

There was an early out that prevented us from calling this for
(and (sext_inreg (shl X, C1), i32), C2).
---
 llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 313 ++++++++++++++--------------
 llvm/test/CodeGen/RISCV/narrow-shl-cst.ll   |   3 +-
 llvm/test/CodeGen/RISCV/shift-and.ll        |   3 +-
 3 files changed, 160 insertions(+), 159 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
index 1ff1ea1..d62098d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -858,178 +858,181 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
     SDValue N0 = Node->getOperand(0);
 
     bool LeftShift = N0.getOpcode() == ISD::SHL;
-    if (!LeftShift && N0.getOpcode() != ISD::SRL)
-      break;
+    if (LeftShift || N0.getOpcode() == ISD::SRL) {
+      auto *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
+      if (!C)
+        break;
+      unsigned C2 = C->getZExtValue();
+      unsigned XLen = Subtarget->getXLen();
+      assert((C2 > 0 && C2 < XLen) && "Unexpected shift amount!");
 
-    auto *C = dyn_cast<ConstantSDNode>(N0.getOperand(1));
-    if (!C)
-      break;
-    unsigned C2 = C->getZExtValue();
-    unsigned XLen = Subtarget->getXLen();
-    assert((C2 > 0 && C2 < XLen) && "Unexpected shift amount!");
+      uint64_t C1 = N1C->getZExtValue();
 
-    uint64_t C1 = N1C->getZExtValue();
+      // Keep track of whether this is a c.andi. If we can't use c.andi, the
+      // shift pair might offer more compression opportunities.
+      // TODO: We could check for C extension here, but we don't have many lit
+      // tests with the C extension enabled so not checking gets better
+      // coverage.
+      // TODO: What if ANDI faster than shift?
+      bool IsCANDI = isInt<6>(N1C->getSExtValue());
 
-    // Keep track of whether this is a c.andi. If we can't use c.andi, the
-    // shift pair might offer more compression opportunities.
-    // TODO: We could check for C extension here, but we don't have many lit
-    // tests with the C extension enabled so not checking gets better coverage.
-    // TODO: What if ANDI faster than shift?
-    bool IsCANDI = isInt<6>(N1C->getSExtValue());
-
-    // Clear irrelevant bits in the mask.
-    if (LeftShift)
-      C1 &= maskTrailingZeros<uint64_t>(C2);
-    else
-      C1 &= maskTrailingOnes<uint64_t>(XLen - C2);
-
-    // Some transforms should only be done if the shift has a single use or
-    // the AND would become (srli (slli X, 32), 32)
-    bool OneUseOrZExtW = N0.hasOneUse() || C1 == UINT64_C(0xFFFFFFFF);
-
-    SDValue X = N0.getOperand(0);
-
-    // Turn (and (srl x, c2) c1) -> (srli (slli x, c3-c2), c3) if c1 is a mask
-    // with c3 leading zeros.
-    if (!LeftShift && isMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-      if (C2 < Leading) {
-        // If the number of leading zeros is C2+32 this can be SRLIW.
-        if (C2 + 32 == Leading) {
-          SDNode *SRLIW = CurDAG->getMachineNode(
-              RISCV::SRLIW, DL, VT, X, CurDAG->getTargetConstant(C2, DL, VT));
-          ReplaceNode(Node, SRLIW);
-          return;
-        }
-
-        // (and (srl (sexti32 Y), c2), c1) -> (srliw (sraiw Y, 31), c3 - 32) if
-        // c1 is a mask with c3 leading zeros and c2 >= 32 and c3-c2==1.
-        //
-        // This pattern occurs when (i32 (srl (sra 31), c3 - 32)) is type
-        // legalized and goes through DAG combine.
-        if (C2 >= 32 && (Leading - C2) == 1 && N0.hasOneUse() &&
-            X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
-            cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32) {
-          SDNode *SRAIW =
-              CurDAG->getMachineNode(RISCV::SRAIW, DL, VT, X.getOperand(0),
-                                     CurDAG->getTargetConstant(31, DL, VT));
-          SDNode *SRLIW = CurDAG->getMachineNode(
-              RISCV::SRLIW, DL, VT, SDValue(SRAIW, 0),
-              CurDAG->getTargetConstant(Leading - 32, DL, VT));
-          ReplaceNode(Node, SRLIW);
-          return;
-        }
-
-        // (srli (slli x, c3-c2), c3).
-        // Skip if we could use (zext.w (sraiw X, C2)).
-        bool Skip = Subtarget->hasStdExtZba() && Leading == 32 &&
-                    X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
-                    cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32;
-        // Also Skip if we can use bexti.
-        Skip |= Subtarget->hasStdExtZbs() && Leading == XLen - 1;
-        if (OneUseOrZExtW && !Skip) {
-          SDNode *SLLI = CurDAG->getMachineNode(
-              RISCV::SLLI, DL, VT, X,
-              CurDAG->getTargetConstant(Leading - C2, DL, VT));
-          SDNode *SRLI = CurDAG->getMachineNode(
-              RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
-              CurDAG->getTargetConstant(Leading, DL, VT));
-          ReplaceNode(Node, SRLI);
-          return;
-        }
-      }
-    }
-
-    // Turn (and (shl x, c2), c1) -> (srli (slli c2+c3), c3) if c1 is a mask
-    // shifted by c2 bits with c3 leading zeros.
-    if (LeftShift && isShiftedMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-
-      if (C2 + Leading < XLen &&
-          C1 == (maskTrailingOnes<uint64_t>(XLen - (C2 + Leading)) << C2)) {
-        // Use slli.uw when possible.
-        if ((XLen - (C2 + Leading)) == 32 && Subtarget->hasStdExtZba()) {
-          SDNode *SLLI_UW = CurDAG->getMachineNode(
-              RISCV::SLLI_UW, DL, VT, X, CurDAG->getTargetConstant(C2, DL, VT));
-          ReplaceNode(Node, SLLI_UW);
-          return;
-        }
-
-        // (srli (slli c2+c3), c3)
-        if (OneUseOrZExtW && !IsCANDI) {
-          SDNode *SLLI = CurDAG->getMachineNode(
-              RISCV::SLLI, DL, VT, X,
-              CurDAG->getTargetConstant(C2 + Leading, DL, VT));
-          SDNode *SRLI = CurDAG->getMachineNode(
-              RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
-              CurDAG->getTargetConstant(Leading, DL, VT));
-          ReplaceNode(Node, SRLI);
-          return;
-        }
-      }
-    }
-
-    // Turn (and (shr x, c2), c1) -> (slli (srli x, c2+c3), c3) if c1 is a
-    // shifted mask with c2 leading zeros and c3 trailing zeros.
-    if (!LeftShift && isShiftedMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-      unsigned Trailing = countTrailingZeros(C1);
-      if (Leading == C2 && C2 + Trailing < XLen && OneUseOrZExtW && !IsCANDI) {
-        unsigned SrliOpc = RISCV::SRLI;
-        // If the input is zexti32 we should use SRLIW.
-        if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
-            X.getConstantOperandVal(1) == UINT64_C(0xFFFFFFFF)) {
-          SrliOpc = RISCV::SRLIW;
-          X = X.getOperand(0);
-        }
-        SDNode *SRLI = CurDAG->getMachineNode(
-            SrliOpc, DL, VT, X,
-            CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
-      }
-      // If the leading zero count is C2+32, we can use SRLIW instead of SRLI.
-      if (Leading > 32 && (Leading - 32) == C2 && C2 + Trailing < 32 &&
-          OneUseOrZExtW && !IsCANDI) {
-        SDNode *SRLIW = CurDAG->getMachineNode(
-            RISCV::SRLIW, DL, VT, X,
-            CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
-      }
-    }
-
-    // Turn (and (shl x, c2), c1) -> (slli (srli x, c3-c2), c3) if c1 is a
-    // shifted mask with no leading zeros and c3 trailing zeros.
-    if (LeftShift && isShiftedMask_64(C1)) {
-      unsigned Leading = XLen - (64 - countLeadingZeros(C1));
-      unsigned Trailing = countTrailingZeros(C1);
-      if (Leading == 0 && C2 < Trailing && OneUseOrZExtW && !IsCANDI) {
-        SDNode *SRLI = CurDAG->getMachineNode(
-            RISCV::SRLI, DL, VT, X,
-            CurDAG->getTargetConstant(Trailing - C2, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
-      }
-      // If we have (32-C2) leading zeros, we can use SRLIW instead of SRLI.
-      if (C2 < Trailing && Leading + C2 == 32 && OneUseOrZExtW && !IsCANDI) {
-        SDNode *SRLIW = CurDAG->getMachineNode(
-            RISCV::SRLIW, DL, VT, X,
-            CurDAG->getTargetConstant(Trailing - C2, DL, VT));
-        SDNode *SLLI =
-            CurDAG->getMachineNode(RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
-                                   CurDAG->getTargetConstant(Trailing, DL, VT));
-        ReplaceNode(Node, SLLI);
-        return;
-      }
-    }
+      // Clear irrelevant bits in the mask.
+      if (LeftShift)
+        C1 &= maskTrailingZeros<uint64_t>(C2);
+      else
+        C1 &= maskTrailingOnes<uint64_t>(XLen - C2);
+
+      // Some transforms should only be done if the shift has a single use or
+      // the AND would become (srli (slli X, 32), 32)
+      bool OneUseOrZExtW = N0.hasOneUse() || C1 == UINT64_C(0xFFFFFFFF);
+
+      SDValue X = N0.getOperand(0);
+
+      // Turn (and (srl x, c2) c1) -> (srli (slli x, c3-c2), c3) if c1 is a mask
+      // with c3 leading zeros.
+      if (!LeftShift && isMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+        if (C2 < Leading) {
+          // If the number of leading zeros is C2+32 this can be SRLIW.
+          if (C2 + 32 == Leading) {
+            SDNode *SRLIW = CurDAG->getMachineNode(
+                RISCV::SRLIW, DL, VT, X, CurDAG->getTargetConstant(C2, DL, VT));
+            ReplaceNode(Node, SRLIW);
+            return;
+          }
+
+          // (and (srl (sexti32 Y), c2), c1) -> (srliw (sraiw Y, 31), c3 - 32)
+          // if c1 is a mask with c3 leading zeros and c2 >= 32 and c3-c2==1.
+          //
+          // This pattern occurs when (i32 (srl (sra 31), c3 - 32)) is type
+          // legalized and goes through DAG combine.
+          if (C2 >= 32 && (Leading - C2) == 1 && N0.hasOneUse() &&
+              X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+              cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32) {
+            SDNode *SRAIW =
+                CurDAG->getMachineNode(RISCV::SRAIW, DL, VT, X.getOperand(0),
+                                       CurDAG->getTargetConstant(31, DL, VT));
+            SDNode *SRLIW = CurDAG->getMachineNode(
+                RISCV::SRLIW, DL, VT, SDValue(SRAIW, 0),
+                CurDAG->getTargetConstant(Leading - 32, DL, VT));
+            ReplaceNode(Node, SRLIW);
+            return;
+          }
+
+          // (srli (slli x, c3-c2), c3).
+          // Skip if we could use (zext.w (sraiw X, C2)).
+          bool Skip = Subtarget->hasStdExtZba() && Leading == 32 &&
+                      X.getOpcode() == ISD::SIGN_EXTEND_INREG &&
+                      cast<VTSDNode>(X.getOperand(1))->getVT() == MVT::i32;
+          // Also Skip if we can use bexti.
+          Skip |= Subtarget->hasStdExtZbs() && Leading == XLen - 1;
+          if (OneUseOrZExtW && !Skip) {
+            SDNode *SLLI = CurDAG->getMachineNode(
+                RISCV::SLLI, DL, VT, X,
+                CurDAG->getTargetConstant(Leading - C2, DL, VT));
+            SDNode *SRLI = CurDAG->getMachineNode(
+                RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
+                CurDAG->getTargetConstant(Leading, DL, VT));
+            ReplaceNode(Node, SRLI);
+            return;
+          }
+        }
+      }
+
+      // Turn (and (shl x, c2), c1) -> (srli (slli c2+c3), c3) if c1 is a mask
+      // shifted by c2 bits with c3 leading zeros.
+      if (LeftShift && isShiftedMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+
+        if (C2 + Leading < XLen &&
+            C1 == (maskTrailingOnes<uint64_t>(XLen - (C2 + Leading)) << C2)) {
+          // Use slli.uw when possible.
+          if ((XLen - (C2 + Leading)) == 32 && Subtarget->hasStdExtZba()) {
+            SDNode *SLLI_UW =
+                CurDAG->getMachineNode(RISCV::SLLI_UW, DL, VT, X,
+                                       CurDAG->getTargetConstant(C2, DL, VT));
+            ReplaceNode(Node, SLLI_UW);
+            return;
+          }
+
+          // (srli (slli c2+c3), c3)
+          if (OneUseOrZExtW && !IsCANDI) {
+            SDNode *SLLI = CurDAG->getMachineNode(
+                RISCV::SLLI, DL, VT, X,
+                CurDAG->getTargetConstant(C2 + Leading, DL, VT));
+            SDNode *SRLI = CurDAG->getMachineNode(
+                RISCV::SRLI, DL, VT, SDValue(SLLI, 0),
+                CurDAG->getTargetConstant(Leading, DL, VT));
+            ReplaceNode(Node, SRLI);
+            return;
+          }
+        }
+      }
+
+      // Turn (and (shr x, c2), c1) -> (slli (srli x, c2+c3), c3) if c1 is a
+      // shifted mask with c2 leading zeros and c3 trailing zeros.
+      if (!LeftShift && isShiftedMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+        unsigned Trailing = countTrailingZeros(C1);
+        if (Leading == C2 && C2 + Trailing < XLen && OneUseOrZExtW &&
+            !IsCANDI) {
+          unsigned SrliOpc = RISCV::SRLI;
+          // If the input is zexti32 we should use SRLIW.
+          if (X.getOpcode() == ISD::AND &&
+              isa<ConstantSDNode>(X.getOperand(1)) &&
+              X.getConstantOperandVal(1) == UINT64_C(0xFFFFFFFF)) {
+            SrliOpc = RISCV::SRLIW;
+            X = X.getOperand(0);
+          }
+          SDNode *SRLI = CurDAG->getMachineNode(
+              SrliOpc, DL, VT, X,
+              CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
+          return;
+        }
+        // If the leading zero count is C2+32, we can use SRLIW instead of SRLI.
+        if (Leading > 32 && (Leading - 32) == C2 && C2 + Trailing < 32 &&
+            OneUseOrZExtW && !IsCANDI) {
+          SDNode *SRLIW = CurDAG->getMachineNode(
+              RISCV::SRLIW, DL, VT, X,
+              CurDAG->getTargetConstant(C2 + Trailing, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
+          return;
+        }
+      }
+
+      // Turn (and (shl x, c2), c1) -> (slli (srli x, c3-c2), c3) if c1 is a
+      // shifted mask with no leading zeros and c3 trailing zeros.
+      if (LeftShift && isShiftedMask_64(C1)) {
+        unsigned Leading = XLen - (64 - countLeadingZeros(C1));
+        unsigned Trailing = countTrailingZeros(C1);
+        if (Leading == 0 && C2 < Trailing && OneUseOrZExtW && !IsCANDI) {
+          SDNode *SRLI = CurDAG->getMachineNode(
+              RISCV::SRLI, DL, VT, X,
+              CurDAG->getTargetConstant(Trailing - C2, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLI, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
+          return;
+        }
+        // If we have (32-C2) leading zeros, we can use SRLIW instead of SRLI.
+        if (C2 < Trailing && Leading + C2 == 32 && OneUseOrZExtW && !IsCANDI) {
+          SDNode *SRLIW = CurDAG->getMachineNode(
+              RISCV::SRLIW, DL, VT, X,
+              CurDAG->getTargetConstant(Trailing - C2, DL, VT));
+          SDNode *SLLI = CurDAG->getMachineNode(
+              RISCV::SLLI, DL, VT, SDValue(SRLIW, 0),
+              CurDAG->getTargetConstant(Trailing, DL, VT));
+          ReplaceNode(Node, SLLI);
+          return;
+        }
+      }
+    }
diff --git a/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll b/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll
index bd99b1c..7434f62 100644
--- a/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll
+++ b/llvm/test/CodeGen/RISCV/narrow-shl-cst.ll
@@ -189,9 +189,8 @@ define signext i32 @test11(i32 signext %x) nounwind {
 ;
 ; RV64-LABEL: test11:
 ; RV64:       # %bb.0:
+; RV64-NEXT:    andi a0, a0, -241
 ; RV64-NEXT:    slliw a0, a0, 17
-; RV64-NEXT:    lui a1, 1040864
-; RV64-NEXT:    and a0, a0, a1
 ; RV64-NEXT:    ret
   %or = shl i32 %x, 17
   %shl = and i32 %or, -31588352
diff --git a/llvm/test/CodeGen/RISCV/shift-and.ll b/llvm/test/CodeGen/RISCV/shift-and.ll
index 288c2cd..525ef62 100644
--- a/llvm/test/CodeGen/RISCV/shift-and.ll
+++ b/llvm/test/CodeGen/RISCV/shift-and.ll
@@ -92,9 +92,8 @@ define i32 @test5(i32 %x) {
 ;
 ; RV64I-LABEL: test5:
 ; RV64I:       # %bb.0:
+; RV64I-NEXT:    andi a0, a0, -1024
 ; RV64I-NEXT:    slliw a0, a0, 6
-; RV64I-NEXT:    lui a1, 1048560
-; RV64I-NEXT:    and a0, a0, a1
 ; RV64I-NEXT:    ret
   %a = shl i32 %x, 6
   %b = and i32 %a, -65536
-- 
2.7.4
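
The payoff visible in both test updates is a simple bit-level identity: when the AND mask is the left-shifted form of a small constant, the constant can be applied before the shift, where it fits the signed 12-bit immediate of andi and the lui/and pair is no longer needed. The sketch below is not part of the patch; it only spot-checks that identity for the constants in test11 above (in i32, -31588352 is -241 shifted left by 17), and the function names are made up for illustration.

// Standalone sanity check (assumed names, not from the patch).
#include <cassert>
#include <cstdint>

// Mask after the shift: 0xFE1E0000 is the i32 bit pattern of -31588352,
// which takes lui+and to materialize on RISC-V.
static uint32_t maskAfterShift(uint32_t x) {
  return (x << 17) & UINT32_C(0xFE1E0000);
}

// Mask before the shift: 0xFFFFFF0F is -241, inside the andi immediate
// range [-2048, 2047], so no lui is required.
static uint32_t maskBeforeShift(uint32_t x) {
  return (x & UINT32_C(0xFFFFFF0F)) << 17;
}

int main() {
  // Sample the 32-bit space; the identity holds for every x because
  // 0xFFFFFF0F << 17 (mod 2^32) equals 0xFE1E0000.
  for (uint64_t i = 0; i <= UINT32_MAX; i += 0x10001) {
    uint32_t x = static_cast<uint32_t>(i);
    assert(maskAfterShift(x) == maskBeforeShift(x));
  }
  return 0;
}

The same arithmetic accounts for the shift-and.ll change: -65536 arithmetically shifted right by 6 is -1024, again an andi-sized immediate.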