[RISCV] Move allWUsers from RISCVInstrInfo to RISCVOptWInstrs.
authorCraig Topper <craig.topper@sifive.com>
Wed, 29 Mar 2023 22:09:51 +0000 (15:09 -0700)
committerCraig Topper <craig.topper@sifive.com>
Wed, 29 Mar 2023 22:13:09 +0000 (15:13 -0700)
It was only in RISCVInstrInfo because it was used by 2 passes, but those
passes have been merged in D147173.

Reviewed By: asb

Differential Revision: https://reviews.llvm.org/D147174

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
llvm/lib/Target/RISCV/RISCVInstrInfo.h
llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp

index 523895c..2ad5b81 100644 (file)
@@ -2614,226 +2614,6 @@ void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF,
   }
 }
 
-// Checks if all users only demand the lower \p OrigBits of the original
-// instruction's result.
-// TODO: handle multiple interdependent transformations
-bool RISCVInstrInfo::hasAllNBitUsers(const MachineInstr &OrigMI,
-                                     const MachineRegisterInfo &MRI,
-                                     unsigned OrigBits) const {
-
-  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
-  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
-
-  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
-
-  while (!Worklist.empty()) {
-    auto P = Worklist.pop_back_val();
-    const MachineInstr *MI = P.first;
-    unsigned Bits = P.second;
-
-    if (!Visited.insert(P).second)
-      continue;
-
-    // Only handle instructions with one def.
-    if (MI->getNumExplicitDefs() != 1)
-      return false;
-
-    for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
-      const MachineInstr *UserMI = UserOp.getParent();
-      unsigned OpIdx = UserOp.getOperandNo();
-
-      switch (UserMI->getOpcode()) {
-      default:
-        return false;
-
-      case RISCV::ADDIW:
-      case RISCV::ADDW:
-      case RISCV::DIVUW:
-      case RISCV::DIVW:
-      case RISCV::MULW:
-      case RISCV::REMUW:
-      case RISCV::REMW:
-      case RISCV::SLLIW:
-      case RISCV::SLLW:
-      case RISCV::SRAIW:
-      case RISCV::SRAW:
-      case RISCV::SRLIW:
-      case RISCV::SRLW:
-      case RISCV::SUBW:
-      case RISCV::ROLW:
-      case RISCV::RORW:
-      case RISCV::RORIW:
-      case RISCV::CLZW:
-      case RISCV::CTZW:
-      case RISCV::CPOPW:
-      case RISCV::SLLI_UW:
-      case RISCV::FMV_W_X:
-      case RISCV::FCVT_H_W:
-      case RISCV::FCVT_H_WU:
-      case RISCV::FCVT_S_W:
-      case RISCV::FCVT_S_WU:
-      case RISCV::FCVT_D_W:
-      case RISCV::FCVT_D_WU:
-        if (Bits >= 32)
-          break;
-        return false;
-      case RISCV::SEXT_B:
-      case RISCV::PACKH:
-        if (Bits >= 8)
-          break;
-        return false;
-      case RISCV::SEXT_H:
-      case RISCV::FMV_H_X:
-      case RISCV::ZEXT_H_RV32:
-      case RISCV::ZEXT_H_RV64:
-      case RISCV::PACKW:
-        if (Bits >= 16)
-          break;
-        return false;
-
-      case RISCV::PACK:
-        if (Bits >= (STI.getXLen() / 2))
-          break;
-        return false;
-
-      case RISCV::SRLI: {
-        // If we are shifting right by less than Bits, and users don't demand
-        // any bits that were shifted into [Bits-1:0], then we can consider this
-        // as an N-Bit user.
-        unsigned ShAmt = UserMI->getOperand(2).getImm();
-        if (Bits > ShAmt) {
-          Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
-          break;
-        }
-        return false;
-      }
-
-      // these overwrite higher input bits, otherwise the lower word of output
-      // depends only on the lower word of input. So check their uses read W.
-      case RISCV::SLLI:
-        if (Bits >= (STI.getXLen() - UserMI->getOperand(2).getImm()))
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      case RISCV::ANDI: {
-        uint64_t Imm = UserMI->getOperand(2).getImm();
-        if (Bits >= (unsigned)llvm::bit_width(Imm))
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      }
-      case RISCV::ORI: {
-        uint64_t Imm = UserMI->getOperand(2).getImm();
-        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      }
-
-      case RISCV::SLL:
-      case RISCV::BSET:
-      case RISCV::BCLR:
-      case RISCV::BINV:
-        // Operand 2 is the shift amount which uses log2(xlen) bits.
-        if (OpIdx == 2) {
-          if (Bits >= Log2_32(STI.getXLen()))
-            break;
-          return false;
-        }
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::SRA:
-      case RISCV::SRL:
-      case RISCV::ROL:
-      case RISCV::ROR:
-        // Operand 2 is the shift amount which uses 6 bits.
-        if (OpIdx == 2 && Bits >= Log2_32(STI.getXLen()))
-          break;
-        return false;
-
-      case RISCV::ADD_UW:
-      case RISCV::SH1ADD_UW:
-      case RISCV::SH2ADD_UW:
-      case RISCV::SH3ADD_UW:
-        // Operand 1 is implicitly zero extended.
-        if (OpIdx == 1 && Bits >= 32)
-          break;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::BEXTI:
-        if (UserMI->getOperand(2).getImm() >= Bits)
-          return false;
-        break;
-
-      case RISCV::SB:
-        // The first argument is the value to store.
-        if (OpIdx == 0 && Bits >= 8)
-          break;
-        return false;
-      case RISCV::SH:
-        // The first argument is the value to store.
-        if (OpIdx == 0 && Bits >= 16)
-          break;
-        return false;
-      case RISCV::SW:
-        // The first argument is the value to store.
-        if (OpIdx == 0 && Bits >= 32)
-          break;
-        return false;
-
-      // For these, lower word of output in these operations, depends only on
-      // the lower word of input. So, we check all uses only read lower word.
-      case RISCV::COPY:
-      case RISCV::PHI:
-
-      case RISCV::ADD:
-      case RISCV::ADDI:
-      case RISCV::AND:
-      case RISCV::MUL:
-      case RISCV::OR:
-      case RISCV::SUB:
-      case RISCV::XOR:
-      case RISCV::XORI:
-
-      case RISCV::ANDN:
-      case RISCV::BREV8:
-      case RISCV::CLMUL:
-      case RISCV::ORC_B:
-      case RISCV::ORN:
-      case RISCV::SH1ADD:
-      case RISCV::SH2ADD:
-      case RISCV::SH3ADD:
-      case RISCV::XNOR:
-      case RISCV::BSETI:
-      case RISCV::BCLRI:
-      case RISCV::BINVI:
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::PseudoCCMOVGPR:
-        // Either operand 4 or operand 5 is returned by this instruction. If
-        // only the lower word of the result is used, then only the lower word
-        // of operand 4 and 5 is used.
-        if (OpIdx != 4 && OpIdx != 5)
-          return false;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-
-      case RISCV::VT_MASKC:
-      case RISCV::VT_MASKCN:
-        if (OpIdx != 1)
-          return false;
-        Worklist.push_back(std::make_pair(UserMI, Bits));
-        break;
-      }
-    }
-  }
-
-  return true;
-}
-
 // Returns true if this is the sext.w pattern, addiw rd, rs1, 0.
 bool RISCV::isSEXT_W(const MachineInstr &MI) {
   return MI.getOpcode() == RISCV::ADDIW && MI.getOperand(1).isReg() &&
index 64e0bc0..01f112a 100644 (file)
@@ -227,17 +227,6 @@ public:
 
   std::optional<unsigned> getInverseOpcode(unsigned Opcode) const override;
 
-  // Returns true if all uses of OrigMI only depend on the lower \p NBits bits
-  // of its output.
-  bool hasAllNBitUsers(const MachineInstr &MI, const MachineRegisterInfo &MRI,
-                       unsigned NBits) const;
-  // Returns true if all uses of OrigMI only depend on the lower word of its
-  // output, so we can transform OrigMI to the corresponding W-version.
-  bool hasAllWUsers(const MachineInstr &MI,
-                    const MachineRegisterInfo &MRI) const {
-    return hasAllNBitUsers(MI, MRI, 32);
-  }
-
 protected:
   const RISCVSubtarget &STI;
 };
index 40fe5f9..7014755 100644 (file)
@@ -54,9 +54,9 @@ public:
 
   bool runOnMachineFunction(MachineFunction &MF) override;
   bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
-                         MachineRegisterInfo &MRI);
+                         const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
   bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
-                      MachineRegisterInfo &MRI);
+                      const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
@@ -76,6 +76,231 @@ FunctionPass *llvm::createRISCVOptWInstrsPass() {
   return new RISCVOptWInstrs();
 }
 
+// Checks if all users only demand the lower \p OrigBits of the original
+// instruction's result.
+// TODO: handle multiple interdependent transformations
+static bool hasAllNBitUsers(const MachineInstr &OrigMI,
+                            const RISCVSubtarget &ST,
+                            const MachineRegisterInfo &MRI, unsigned OrigBits) {
+
+  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
+  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
+
+  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));
+
+  while (!Worklist.empty()) {
+    auto P = Worklist.pop_back_val();
+    const MachineInstr *MI = P.first;
+    unsigned Bits = P.second;
+
+    if (!Visited.insert(P).second)
+      continue;
+
+    // Only handle instructions with one def.
+    if (MI->getNumExplicitDefs() != 1)
+      return false;
+
+    for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
+      const MachineInstr *UserMI = UserOp.getParent();
+      unsigned OpIdx = UserOp.getOperandNo();
+
+      switch (UserMI->getOpcode()) {
+      default:
+        return false;
+
+      case RISCV::ADDIW:
+      case RISCV::ADDW:
+      case RISCV::DIVUW:
+      case RISCV::DIVW:
+      case RISCV::MULW:
+      case RISCV::REMUW:
+      case RISCV::REMW:
+      case RISCV::SLLIW:
+      case RISCV::SLLW:
+      case RISCV::SRAIW:
+      case RISCV::SRAW:
+      case RISCV::SRLIW:
+      case RISCV::SRLW:
+      case RISCV::SUBW:
+      case RISCV::ROLW:
+      case RISCV::RORW:
+      case RISCV::RORIW:
+      case RISCV::CLZW:
+      case RISCV::CTZW:
+      case RISCV::CPOPW:
+      case RISCV::SLLI_UW:
+      case RISCV::FMV_W_X:
+      case RISCV::FCVT_H_W:
+      case RISCV::FCVT_H_WU:
+      case RISCV::FCVT_S_W:
+      case RISCV::FCVT_S_WU:
+      case RISCV::FCVT_D_W:
+      case RISCV::FCVT_D_WU:
+        if (Bits >= 32)
+          break;
+        return false;
+      case RISCV::SEXT_B:
+      case RISCV::PACKH:
+        if (Bits >= 8)
+          break;
+        return false;
+      case RISCV::SEXT_H:
+      case RISCV::FMV_H_X:
+      case RISCV::ZEXT_H_RV32:
+      case RISCV::ZEXT_H_RV64:
+      case RISCV::PACKW:
+        if (Bits >= 16)
+          break;
+        return false;
+
+      case RISCV::PACK:
+        if (Bits >= (ST.getXLen() / 2))
+          break;
+        return false;
+
+      case RISCV::SRLI: {
+        // If we are shifting right by less than Bits, and users don't demand
+        // any bits that were shifted into [Bits-1:0], then we can consider this
+        // as an N-Bit user.
+        unsigned ShAmt = UserMI->getOperand(2).getImm();
+        if (Bits > ShAmt) {
+          Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
+          break;
+        }
+        return false;
+      }
+
+      // these overwrite higher input bits, otherwise the lower word of output
+      // depends only on the lower word of input. So check their uses read W.
+      case RISCV::SLLI:
+        if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      case RISCV::ANDI: {
+        uint64_t Imm = UserMI->getOperand(2).getImm();
+        if (Bits >= (unsigned)llvm::bit_width(Imm))
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      }
+      case RISCV::ORI: {
+        uint64_t Imm = UserMI->getOperand(2).getImm();
+        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      }
+
+      case RISCV::SLL:
+      case RISCV::BSET:
+      case RISCV::BCLR:
+      case RISCV::BINV:
+        // Operand 2 is the shift amount which uses log2(xlen) bits.
+        if (OpIdx == 2) {
+          if (Bits >= Log2_32(ST.getXLen()))
+            break;
+          return false;
+        }
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::SRA:
+      case RISCV::SRL:
+      case RISCV::ROL:
+      case RISCV::ROR:
+        // Operand 2 is the shift amount which uses 6 bits.
+        if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
+          break;
+        return false;
+
+      case RISCV::ADD_UW:
+      case RISCV::SH1ADD_UW:
+      case RISCV::SH2ADD_UW:
+      case RISCV::SH3ADD_UW:
+        // Operand 1 is implicitly zero extended.
+        if (OpIdx == 1 && Bits >= 32)
+          break;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::BEXTI:
+        if (UserMI->getOperand(2).getImm() >= Bits)
+          return false;
+        break;
+
+      case RISCV::SB:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 8)
+          break;
+        return false;
+      case RISCV::SH:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 16)
+          break;
+        return false;
+      case RISCV::SW:
+        // The first argument is the value to store.
+        if (OpIdx == 0 && Bits >= 32)
+          break;
+        return false;
+
+      // For these, lower word of output in these operations, depends only on
+      // the lower word of input. So, we check all uses only read lower word.
+      case RISCV::COPY:
+      case RISCV::PHI:
+
+      case RISCV::ADD:
+      case RISCV::ADDI:
+      case RISCV::AND:
+      case RISCV::MUL:
+      case RISCV::OR:
+      case RISCV::SUB:
+      case RISCV::XOR:
+      case RISCV::XORI:
+
+      case RISCV::ANDN:
+      case RISCV::BREV8:
+      case RISCV::CLMUL:
+      case RISCV::ORC_B:
+      case RISCV::ORN:
+      case RISCV::SH1ADD:
+      case RISCV::SH2ADD:
+      case RISCV::SH3ADD:
+      case RISCV::XNOR:
+      case RISCV::BSETI:
+      case RISCV::BCLRI:
+      case RISCV::BINVI:
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::PseudoCCMOVGPR:
+        // Either operand 4 or operand 5 is returned by this instruction. If
+        // only the lower word of the result is used, then only the lower word
+        // of operand 4 and 5 is used.
+        if (OpIdx != 4 && OpIdx != 5)
+          return false;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+
+      case RISCV::VT_MASKC:
+      case RISCV::VT_MASKCN:
+        if (OpIdx != 1)
+          return false;
+        Worklist.push_back(std::make_pair(UserMI, Bits));
+        break;
+      }
+    }
+  }
+
+  return true;
+}
+
+static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
+                         const MachineRegisterInfo &MRI) {
+  return hasAllNBitUsers(OrigMI, ST, MRI, 32);
+}
+
 // This function returns true if the machine instruction always outputs a value
 // where bits 63:32 match bit 31.
 static bool isSignExtendingOpW(const MachineInstr &MI,
@@ -110,8 +335,8 @@ static bool isSignExtendingOpW(const MachineInstr &MI,
   return false;
 }
 
-static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI,
-                            const RISCVInstrInfo &TII,
+static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
+                            const MachineRegisterInfo &MRI,
                             SmallPtrSetImpl<MachineInstr *> &FixableDef) {
 
   SmallPtrSet<const MachineInstr *, 4> Visited;
@@ -300,7 +525,7 @@ static bool isSignExtendedW(Register SrcReg, const MachineRegisterInfo &MRI,
     case RISCV::LWU:
     case RISCV::MUL:
     case RISCV::SUB:
-      if (TII.hasAllWUsers(*MI, MRI)) {
+      if (hasAllWUsers(*MI, ST, MRI)) {
         FixableDef.insert(MI);
         break;
       }
@@ -335,6 +560,7 @@ static unsigned getWOp(unsigned Opcode) {
 
 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
                                         const RISCVInstrInfo &TII,
+                                        const RISCVSubtarget &ST,
                                         MachineRegisterInfo &MRI) {
   if (DisableSExtWRemoval)
     return false;
@@ -355,8 +581,8 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
       // If all users only use the lower bits, this sext.w is redundant.
       // Or if all definitions reaching MI sign-extend their output,
       // then sext.w is redundant.
-      if (!TII.hasAllWUsers(*MI, MRI) &&
-          !isSignExtendedW(SrcReg, MRI, TII, FixableDefs))
+      if (!hasAllWUsers(*MI, ST, MRI) &&
+          !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
         continue;
 
       Register DstReg = MI->getOperand(0).getReg();
@@ -388,6 +614,7 @@ bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
 
 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
                                      const RISCVInstrInfo &TII,
+                                     const RISCVSubtarget &ST,
                                      MachineRegisterInfo &MRI) {
   if (DisableStripWSuffix)
     return false;
@@ -406,7 +633,7 @@ bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
       case RISCV::SLLIW: Opc = RISCV::SLLI; break;
       }
 
-      if (TII.hasAllWUsers(MI, MRI)) {
+      if (hasAllWUsers(MI, ST, MRI)) {
         MI.setDesc(TII.get(Opc));
         MadeChange = true;
       }
@@ -428,8 +655,8 @@ bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
     return false;
 
   bool MadeChange = false;
-  MadeChange |= removeSExtWInstrs(MF, TII, MRI);
-  MadeChange |= stripWSuffixes(MF, TII, MRI);
+  MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
+  MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
 
   return MadeChange;
 }