From 124c93c528758071fccfce68f6b633081a19c226 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Mon, 16 Nov 2020 09:22:42 -0800 Subject: [PATCH] [RISCV] When matching SROIW, check all 64 bits of the OR mask We need to make sure the upper 32 bits are all ones to ensure the result is properly sign extended. Previously we only checked the lower 32 bits of the mask. I've also added a check that the shift amount is less than 32. Without that the original code asserts inside maskLeadingOnes if the SROI check is removed or the SROIW pattern is checked first. I've refactored the code to use early outs to reduce nesting. I've also updated SLOIW matching with the same changes, but I couldn't find a broken test case with the existing code. Differential Revision: https://reviews.llvm.org/D90961 --- llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp | 92 +++++++++++++++-------------- llvm/test/CodeGen/RISCV/rv64Zbb.ll | 11 +++- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 54219e9..765775c 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -314,62 +314,66 @@ bool RISCVDAGToDAGISel::SelectSLLIUW(SDValue N, SDValue &RS1, SDValue &Shamt) { // and then we check that VC1, the mask used to fill with ones, is compatible // with VC2, the shamt: // -// VC1 == maskTrailingOnes<uint32_t>(VC2) +// VC2 < 32 +// VC1 == maskTrailingOnes<uint64_t>(VC2) bool RISCVDAGToDAGISel::SelectSLOIW(SDValue N, SDValue &RS1, SDValue &Shamt) { - if (Subtarget->getXLenVT() == MVT::i64 && - N.getOpcode() == ISD::SIGN_EXTEND_INREG && - cast<VTSDNode>(N.getOperand(1))->getVT() == MVT::i32) { - if (N.getOperand(0).getOpcode() == ISD::OR) { - SDValue Or = N.getOperand(0); - if (Or.getOperand(0).getOpcode() == ISD::SHL) { - SDValue Shl = Or.getOperand(0); - if (isa<ConstantSDNode>(Shl.getOperand(1)) && - isa<ConstantSDNode>(Or.getOperand(1))) { - uint32_t VC1 = Or.getConstantOperandVal(1); - uint32_t VC2 = 
Shl.getConstantOperandVal(1); - if (VC1 == maskTrailingOnes<uint32_t>(VC2)) { - RS1 = Shl.getOperand(0); - Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), - Shl.getOperand(1).getValueType()); - return true; - } - } - } - } - } - return false; + assert(Subtarget->is64Bit() && "SLOIW should only be matched on RV64"); + if (N.getOpcode() != ISD::SIGN_EXTEND_INREG || + cast<VTSDNode>(N.getOperand(1))->getVT() != MVT::i32) + return false; + + SDValue Or = N.getOperand(0); + + if (Or.getOpcode() != ISD::OR || !isa<ConstantSDNode>(Or.getOperand(1))) + return false; + + SDValue Shl = Or.getOperand(0); + if (Shl.getOpcode() != ISD::SHL || !isa<ConstantSDNode>(Shl.getOperand(1))) + return false; + + uint64_t VC1 = Or.getConstantOperandVal(1); + uint64_t VC2 = Shl.getConstantOperandVal(1); + + if (VC2 >= 32 || VC1 != maskTrailingOnes<uint64_t>(VC2)) + return false; + + RS1 = Shl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Shl.getOperand(1).getValueType()); + return true; } // Check that it is a SROIW (Shift Right Ones Immediate i32 on RV64). 
// We first check that it is the right node tree: // -// (OR (SHL RS1, VC2), VC1) +// (OR (SRL RS1, VC2), VC1) // // and then we check that VC1, the mask used to fill with ones, is compatible // with VC2, the shamt: // -// VC1 == maskLeadingOnes<uint32_t>(VC2) - +// VC2 < 32 +// VC1 == maskTrailingZeros<uint64_t>(32 - VC2) +// bool RISCVDAGToDAGISel::SelectSROIW(SDValue N, SDValue &RS1, SDValue &Shamt) { - if (N.getOpcode() == ISD::OR && Subtarget->getXLenVT() == MVT::i64) { - SDValue Or = N; - if (Or.getOperand(0).getOpcode() == ISD::SRL) { - SDValue Srl = Or.getOperand(0); - if (isa<ConstantSDNode>(Srl.getOperand(1)) && - isa<ConstantSDNode>(Or.getOperand(1))) { - uint32_t VC1 = Or.getConstantOperandVal(1); - uint32_t VC2 = Srl.getConstantOperandVal(1); - if (VC1 == maskLeadingOnes<uint32_t>(VC2)) { - RS1 = Srl.getOperand(0); - Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), - Srl.getOperand(1).getValueType()); - return true; - } - } - } - } - return false; + assert(Subtarget->is64Bit() && "SROIW should only be matched on RV64"); + if (N.getOpcode() != ISD::OR || !isa<ConstantSDNode>(N.getOperand(1))) + return false; + + SDValue Srl = N.getOperand(0); + if (Srl.getOpcode() != ISD::SRL || !isa<ConstantSDNode>(Srl.getOperand(1))) + return false; + + uint64_t VC1 = N.getConstantOperandVal(1); + uint64_t VC2 = Srl.getConstantOperandVal(1); + + if (VC2 >= 32 || VC1 != maskTrailingZeros<uint64_t>(32 - VC2)) + return false; + + RS1 = Srl.getOperand(0); + Shamt = CurDAG->getTargetConstant(VC2, SDLoc(N), + Srl.getOperand(1).getValueType()); + return true; } // Check that it is a RORIW (i32 Right Rotate Immediate on RV64). diff --git a/llvm/test/CodeGen/RISCV/rv64Zbb.ll b/llvm/test/CodeGen/RISCV/rv64Zbb.ll index a1d0b8a..66985c5 100644 --- a/llvm/test/CodeGen/RISCV/rv64Zbb.ll +++ b/llvm/test/CodeGen/RISCV/rv64Zbb.ll @@ -166,7 +166,6 @@ define signext i32 @sroi_i32(i32 signext %a) nounwind { ; This is similar to the type legalized version of sroiw but the mask is 0 in ; the upper bits instead of 1 so the result is not sign extended. Make sure we ; don't match it to sroiw. 
-; FIXME: We're matching it to sroiw. define i64 @sroiw_bug(i64 %a) nounwind { ; RV64I-LABEL: sroiw_bug: ; RV64I: # %bb.0: @@ -178,12 +177,18 @@ define i64 @sroiw_bug(i64 %a) nounwind { ; ; RV64IB-LABEL: sroiw_bug: ; RV64IB: # %bb.0: -; RV64IB-NEXT: sroiw a0, a0, 1 +; RV64IB-NEXT: srli a0, a0, 1 +; RV64IB-NEXT: addi a1, zero, 1 +; RV64IB-NEXT: slli a1, a1, 31 +; RV64IB-NEXT: or a0, a0, a1 ; RV64IB-NEXT: ret ; ; RV64IBB-LABEL: sroiw_bug: ; RV64IBB: # %bb.0: -; RV64IBB-NEXT: sroiw a0, a0, 1 +; RV64IBB-NEXT: srli a0, a0, 1 +; RV64IBB-NEXT: addi a1, zero, 1 +; RV64IBB-NEXT: slli a1, a1, 31 +; RV64IBB-NEXT: or a0, a0, a1 ; RV64IBB-NEXT: ret %neg = lshr i64 %a, 1 %neg12 = or i64 %neg, 2147483648 -- 2.7.4