[AMDGPU] Simplify loops in SIInsertWaitcnts::generateWaitcntInstBefore
authorJay Foad <jay.foad@amd.com>
Tue, 28 Apr 2020 07:54:19 +0000 (08:54 +0100)
committerJay Foad <jay.foad@amd.com>
Thu, 30 Apr 2020 07:53:12 +0000 (08:53 +0100)
The loops over use operands and def operands were mostly identical.
Combine them, and likewise for load memoperands and store memoperands.
NFC.

llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp

index 66e2c40..5fe0a82 100644 (file)
@@ -948,7 +948,19 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(
       // emitted.
       // If the source operand was defined by a load, add the s_waitcnt
       // instruction.
+      //
+      // Two cases are handled for destination operands:
+      // 1) If the destination operand was defined by a load, add the s_waitcnt
+      // instruction to guarantee the right WAW order.
+      // 2) If a destination operand that was used by a recent export/store ins,
+      // add s_waitcnt on exp_cnt to guarantee the WAR order.
       for (const MachineMemOperand *Memop : MI.memoperands()) {
+        const Value *Ptr = Memop->getValue();
+        if (Memop->isStore() && SLoadAddresses.count(Ptr)) {
+          addWait(Wait, LGKM_CNT, 0);
+          if (PDT->dominates(MI.getParent(), SLoadAddresses.find(Ptr)->second))
+            SLoadAddresses.erase(Ptr);
+        }
         unsigned AS = Memop->getAddrSpace();
         if (AS != AMDGPUAS::LOCAL_ADDRESS)
           continue;
@@ -956,70 +968,32 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(
         // VM_CNT is only relevant to vgpr or LDS.
         ScoreBrackets.determineWait(
             VM_CNT, ScoreBrackets.getRegScore(RegNo, VM_CNT), Wait);
-      }
-
-      for (unsigned I = 0, E = MI.getNumOperands(); I != E; ++I) {
-        const MachineOperand &Op = MI.getOperand(I);
-        if (!Op.isReg())
-          continue;
-        RegInterval Interval =
-            ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, I);
-        for (signed RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-          if (TRI->isVGPR(*MRI, Op.getReg())) {
-            // VM_CNT is only relevant to vgpr or LDS.
-            ScoreBrackets.determineWait(
-                VM_CNT, ScoreBrackets.getRegScore(RegNo, VM_CNT), Wait);
-          }
-          ScoreBrackets.determineWait(
-              LGKM_CNT, ScoreBrackets.getRegScore(RegNo, LGKM_CNT), Wait);
-        }
-      }
-      // End of for loop that looks at all source operands to decide vm_wait_cnt
-      // and lgk_wait_cnt.
-
-      // Two cases are handled for destination operands:
-      // 1) If the destination operand was defined by a load, add the s_waitcnt
-      // instruction to guarantee the right WAW order.
-      // 2) If a destination operand that was used by a recent export/store ins,
-      // add s_waitcnt on exp_cnt to guarantee the WAR order.
-      if (MI.mayStore()) {
-        // FIXME: Should not be relying on memoperands.
-        for (const MachineMemOperand *Memop : MI.memoperands()) {
-          const Value *Ptr = Memop->getValue();
-          if (SLoadAddresses.count(Ptr)) {
-            addWait(Wait, LGKM_CNT, 0);
-            if (PDT->dominates(MI.getParent(),
-                               SLoadAddresses.find(Ptr)->second))
-              SLoadAddresses.erase(Ptr);
-          }
-          unsigned AS = Memop->getAddrSpace();
-          if (AS != AMDGPUAS::LOCAL_ADDRESS)
-            continue;
-          unsigned RegNo = SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS;
-          ScoreBrackets.determineWait(
-              VM_CNT, ScoreBrackets.getRegScore(RegNo, VM_CNT), Wait);
+        if (Memop->isStore()) {
           ScoreBrackets.determineWait(
               EXP_CNT, ScoreBrackets.getRegScore(RegNo, EXP_CNT), Wait);
         }
       }
 
+      // Loop over use and def operands.
       for (unsigned I = 0, E = MI.getNumOperands(); I != E; ++I) {
-        MachineOperand &Def = MI.getOperand(I);
-        if (!Def.isReg() || !Def.isDef())
+        MachineOperand &Op = MI.getOperand(I);
+        if (!Op.isReg())
           continue;
         RegInterval Interval =
             ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, I);
         for (signed RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
-          if (TRI->isVGPR(*MRI, Def.getReg())) {
+          if (TRI->isVGPR(*MRI, Op.getReg())) {
             ScoreBrackets.determineWait(
                 VM_CNT, ScoreBrackets.getRegScore(RegNo, VM_CNT), Wait);
-            ScoreBrackets.determineWait(
-                EXP_CNT, ScoreBrackets.getRegScore(RegNo, EXP_CNT), Wait);
+            if (Op.isDef()) {
+              ScoreBrackets.determineWait(
+                  EXP_CNT, ScoreBrackets.getRegScore(RegNo, EXP_CNT), Wait);
+            }
           }
           ScoreBrackets.determineWait(
               LGKM_CNT, ScoreBrackets.getRegScore(RegNo, LGKM_CNT), Wait);
         }
-      } // End of for loop that looks at all dest operands.
+      }
     }
   }