[RISCV] Rework hasAllWUsers in RISCVSExtWRemoval. NFCI
authorCraig Topper <craig.topper@sifive.com>
Wed, 9 Nov 2022 07:40:15 +0000 (23:40 -0800)
committerCraig Topper <craig.topper@sifive.com>
Wed, 9 Nov 2022 19:32:19 +0000 (11:32 -0800)
Instead of storing the uses to check in the worklist, store the
instruction we want to check uses for.

Now we pop and instruction from the worklist, loop over its uses
and check them. If it's something we need to look across, we'll push
it to the worklist.

By doing it this way, we can have access to which operand
of the user is using the instruction. This will allow supporting
store instructions since we'll be able to disambiguate the the value
operand and the pointer operand. We can also improve support for
*add.uw instructions and shift amount uses.

Reviewed By: mohammed-nurulhoque, asb

Differential Revision: https://reviews.llvm.org/D137446

llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp

index 14d7afb..d8791e8 100644 (file)
@@ -58,18 +58,6 @@ FunctionPass *llvm::createRISCVSExtWRemovalPass() {
   return new RISCVSExtWRemoval();
 }
 
-// add uses of MI to the Worklist
-static void addUses(const MachineInstr &MI,
-                    SmallVectorImpl<const MachineInstr *> &Worklist,
-                    MachineRegisterInfo &MRI) {
-  for (auto &UserOp : MRI.reg_operands(MI.getOperand(0).getReg())) {
-    const auto *User = UserOp.getParent();
-    if (User == &MI) // ignore the def, current MI
-      continue;
-    Worklist.push_back(User);
-  }
-}
-
 // returns true if all uses of OrigMI only depend on the lower word of its
 // output, so we can transform OrigMI to the corresponding W-version.
 // TODO: handle multiple interdependent transformations
@@ -78,114 +66,115 @@ static bool hasAllWUsers(const MachineInstr &OrigMI, MachineRegisterInfo &MRI) {
   SmallPtrSet<const MachineInstr *, 4> Visited;
   SmallVector<const MachineInstr *, 4> Worklist;
 
-  Visited.insert(&OrigMI);
-  addUses(OrigMI, Worklist, MRI);
+  Worklist.push_back(&OrigMI);
 
   while (!Worklist.empty()) {
     const MachineInstr *MI = Worklist.pop_back_val();
 
-    if (!Visited.insert(MI).second) {
-      // If we've looped back to OrigMI through a PHI cycle, we can't transform
-      // LD or LWU, because these operations use all 64 bits of input.
-      if (MI == &OrigMI) {
-        unsigned opcode = MI->getOpcode();
-        if (opcode == RISCV::LD || opcode == RISCV::LWU)
-          return false;
-      }
+    if (!Visited.insert(MI).second)
       continue;
-    }
 
-    switch (MI->getOpcode()) {
-    case RISCV::ADDIW:
-    case RISCV::ADDW:
-    case RISCV::DIVUW:
-    case RISCV::DIVW:
-    case RISCV::MULW:
-    case RISCV::REMUW:
-    case RISCV::REMW:
-    case RISCV::SLLIW:
-    case RISCV::SLLW:
-    case RISCV::SRAIW:
-    case RISCV::SRAW:
-    case RISCV::SRLIW:
-    case RISCV::SRLW:
-    case RISCV::SUBW:
-    case RISCV::ROLW:
-    case RISCV::RORW:
-    case RISCV::RORIW:
-    case RISCV::CLZW:
-    case RISCV::CTZW:
-    case RISCV::CPOPW:
-    case RISCV::SLLI_UW:
-    case RISCV::FMV_H_X:
-    case RISCV::FMV_W_X:
-    case RISCV::FCVT_H_W:
-    case RISCV::FCVT_H_WU:
-    case RISCV::FCVT_S_W:
-    case RISCV::FCVT_S_WU:
-    case RISCV::FCVT_D_W:
-    case RISCV::FCVT_D_WU:
-    case RISCV::SEXT_B:
-    case RISCV::SEXT_H:
-    case RISCV::ZEXT_H_RV64:
-      continue;
+    // Only handle instructions with one def.
+    if (MI->getNumExplicitDefs() != 1)
+      return false;
 
-    // these overwrite higher input bits, otherwise the lower word of output
-    // depends only on the lower word of input. So check their uses read W.
-    case RISCV::SLLI:
-      if (MI->getOperand(2).getImm() >= 32)
-        continue;
-      addUses(*MI, Worklist, MRI);
-      continue;
-    case RISCV::ANDI:
-      if (isUInt<11>(MI->getOperand(2).getImm()))
-        continue;
-      addUses(*MI, Worklist, MRI);
-      continue;
-    case RISCV::ORI:
-      if (!isUInt<11>(MI->getOperand(2).getImm()))
-        continue;
-      addUses(*MI, Worklist, MRI);
-      continue;
+    for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) {
+      const MachineInstr *UserMI = UserOp.getParent();
 
-    case RISCV::BEXTI:
-      if (MI->getOperand(2).getImm() >= 32)
+      switch (UserMI->getOpcode()) {
+      default:
         return false;
-      continue;
 
-    // For these, lower word of output in these operations, depends only on
-    // the lower word of input. So, we check all uses only read lower word.
-    case RISCV::COPY:
-    case RISCV::PHI:
-
-    case RISCV::ADD:
-    case RISCV::ADDI:
-    case RISCV::AND:
-    case RISCV::MUL:
-    case RISCV::OR:
-    case RISCV::SLL:
-    case RISCV::SUB:
-    case RISCV::XOR:
-    case RISCV::XORI:
-
-    case RISCV::ADD_UW:
-    case RISCV::ANDN:
-    case RISCV::CLMUL:
-    case RISCV::ORC_B:
-    case RISCV::ORN:
-    case RISCV::SH1ADD:
-    case RISCV::SH1ADD_UW:
-    case RISCV::SH2ADD:
-    case RISCV::SH2ADD_UW:
-    case RISCV::SH3ADD:
-    case RISCV::SH3ADD_UW:
-    case RISCV::XNOR:
-      addUses(*MI, Worklist, MRI);
-      continue;
-    default:
-      return false;
+      case RISCV::ADDIW:
+      case RISCV::ADDW:
+      case RISCV::DIVUW:
+      case RISCV::DIVW:
+      case RISCV::MULW:
+      case RISCV::REMUW:
+      case RISCV::REMW:
+      case RISCV::SLLIW:
+      case RISCV::SLLW:
+      case RISCV::SRAIW:
+      case RISCV::SRAW:
+      case RISCV::SRLIW:
+      case RISCV::SRLW:
+      case RISCV::SUBW:
+      case RISCV::ROLW:
+      case RISCV::RORW:
+      case RISCV::RORIW:
+      case RISCV::CLZW:
+      case RISCV::CTZW:
+      case RISCV::CPOPW:
+      case RISCV::SLLI_UW:
+      case RISCV::FMV_H_X:
+      case RISCV::FMV_W_X:
+      case RISCV::FCVT_H_W:
+      case RISCV::FCVT_H_WU:
+      case RISCV::FCVT_S_W:
+      case RISCV::FCVT_S_WU:
+      case RISCV::FCVT_D_W:
+      case RISCV::FCVT_D_WU:
+      case RISCV::SEXT_B:
+      case RISCV::SEXT_H:
+      case RISCV::ZEXT_H_RV64:
+        break;
+
+      // these overwrite higher input bits, otherwise the lower word of output
+      // depends only on the lower word of input. So check their uses read W.
+      case RISCV::SLLI:
+        if (UserMI->getOperand(2).getImm() >= 32)
+          break;
+        Worklist.push_back(UserMI);
+        break;
+      case RISCV::ANDI:
+        if (isUInt<11>(UserMI->getOperand(2).getImm()))
+          break;
+        Worklist.push_back(UserMI);
+        break;
+      case RISCV::ORI:
+        if (!isUInt<11>(UserMI->getOperand(2).getImm()))
+          break;
+        Worklist.push_back(UserMI);
+        break;
+
+      case RISCV::BEXTI:
+        if (UserMI->getOperand(2).getImm() >= 32)
+          return false;
+        break;
+
+      // For these, lower word of output in these operations, depends only on
+      // the lower word of input. So, we check all uses only read lower word.
+      case RISCV::COPY:
+      case RISCV::PHI:
+
+      case RISCV::ADD:
+      case RISCV::ADDI:
+      case RISCV::AND:
+      case RISCV::MUL:
+      case RISCV::OR:
+      case RISCV::SLL:
+      case RISCV::SUB:
+      case RISCV::XOR:
+      case RISCV::XORI:
+
+      case RISCV::ADD_UW:
+      case RISCV::ANDN:
+      case RISCV::CLMUL:
+      case RISCV::ORC_B:
+      case RISCV::ORN:
+      case RISCV::SH1ADD:
+      case RISCV::SH1ADD_UW:
+      case RISCV::SH2ADD:
+      case RISCV::SH2ADD_UW:
+      case RISCV::SH3ADD:
+      case RISCV::SH3ADD_UW:
+      case RISCV::XNOR:
+        Worklist.push_back(UserMI);
+        break;
+      }
     }
   }
+
   return true;
 }