[Uniformity] Propagate divergence only along divergent outputs.
authorSameer Sahasrabuddhe <sameer.sahasrabuddhe@amd.com>
Wed, 17 May 2023 02:17:43 +0000 (07:47 +0530)
committerSameer Sahasrabuddhe <sameer.sahasrabuddhe@amd.com>
Wed, 17 May 2023 02:17:43 +0000 (07:47 +0530)
When an instruction is determined to be divergent, not all its outputs are
divergent. The users of only divergent outputs should now be examined for
divergence.

Also, replaced a repeating pattern of "if new divergent instruction, then add to
worklist" by combining it into a single function. This does not cause any change
in functionality.

Reviewed By: foad, arsenm

Differential Revision: https://reviews.llvm.org/D150636

llvm/include/llvm/ADT/GenericUniformityImpl.h
llvm/lib/Analysis/UniformityAnalysis.cpp
llvm/lib/CodeGen/MachineUniformityAnalysis.cpp
llvm/test/Analysis/UniformityAnalysis/AMDGPU/MIR/always-uniform-gmir.mir

index 75a33e1..71935d1 100644 (file)
@@ -355,10 +355,15 @@ public:
   /// \brief Mark \p UniVal as a value that is always uniform.
   void addUniformOverride(const InstructionT &Instr);
 
-  /// \brief Mark \p DivVal as a value that is always divergent.
+  /// \brief Examine \p I for divergent outputs and add to the worklist.
+  void markDivergent(const InstructionT &I);
+
+  /// \brief Mark \p DivVal as a divergent value.
   /// \returns Whether the tracked divergence state of \p DivVal changed.
-  bool markDivergent(const InstructionT &I);
   bool markDivergent(ConstValueRefT DivVal);
+
+  /// \brief Mark outputs of \p Instr as divergent.
+  /// \returns Whether the tracked divergence state of any output has changed.
   bool markDefsDivergent(const InstructionT &Instr);
 
   /// \brief Propagate divergence to all instructions in the region.
@@ -774,21 +779,23 @@ auto llvm::GenericSyncDependenceAnalysis<ContextT>::getJoinBlocks(
 }
 
 template <typename ContextT>
-bool GenericUniformityAnalysisImpl<ContextT>::markDivergent(
+void GenericUniformityAnalysisImpl<ContextT>::markDivergent(
     const InstructionT &I) {
+  if (isAlwaysUniform(I))
+    return;
+  bool Marked = false;
   if (I.isTerminator()) {
-    if (DivergentTermBlocks.insert(I.getParent()).second) {
+    Marked = DivergentTermBlocks.insert(I.getParent()).second;
+    if (Marked) {
       LLVM_DEBUG(dbgs() << "marked divergent term block: "
                         << Context.print(I.getParent()) << "\n");
-      return true;
     }
-    return false;
+  } else {
+    Marked = markDefsDivergent(I);
   }
 
-  if (isAlwaysUniform(I))
-    return false;
-
-  return markDefsDivergent(I);
+  if (Marked)
+    Worklist.push_back(&I);
 }
 
 template <typename ContextT>
@@ -828,8 +835,7 @@ void GenericUniformityAnalysisImpl<ContextT>::analyzeCycleExitDivergence(
   for (auto *Exit : Exits) {
     for (auto &Phi : Exit->phis()) {
       if (usesValueFromCycle(Phi, DefCycle)) {
-        if (markDivergent(Phi))
-          Worklist.push_back(&Phi);
+        markDivergent(Phi);
       }
     }
   }
@@ -889,8 +895,7 @@ void GenericUniformityAnalysisImpl<ContextT>::taintAndPushAllDefs(
     if (I.isTerminator())
       break;
 
-    if (markDivergent(I))
-      Worklist.push_back(&I);
+    markDivergent(I);
   }
 }
 
@@ -910,8 +915,7 @@ void GenericUniformityAnalysisImpl<ContextT>::taintAndPushPhiNodes(
     // https://reviews.llvm.org/D19013
     if (ContextT::isConstantOrUndefValuePhi(Phi))
       continue;
-    if (markDivergent(Phi))
-      Worklist.push_back(&Phi);
+    markDivergent(Phi);
   }
 }
 
index fad88bb..60d6bb8 100644 (file)
@@ -27,7 +27,7 @@ bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
 template <>
 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
     const Instruction &Instr) {
-  return markDivergent(&Instr);
+  return markDivergent(cast<Value>(&Instr));
 }
 
 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
@@ -49,9 +49,7 @@ void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
     const Value *V) {
   for (const auto *User : V->users()) {
     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
-      if (markDivergent(*UserInstr)) {
-        Worklist.push_back(UserInstr);
-      }
+      markDivergent(*UserInstr);
     }
   }
 }
@@ -88,8 +86,7 @@ void llvm::GenericUniformityAnalysisImpl<
     auto *UserInstr = cast<Instruction>(User);
     if (DefCycle.contains(UserInstr->getParent()))
       continue;
-    if (markDivergent(*UserInstr))
-      Worklist.push_back(UserInstr);
+    markDivergent(*UserInstr);
   }
 }
 
index 693c64e..cc8cdaf 100644 (file)
@@ -62,8 +62,7 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
       }
 
       if (uniformity == InstructionUniformity::NeverUniform) {
-        if (markDivergent(instr))
-          Worklist.push_back(&instr);
+        markDivergent(instr);
       }
     }
   }
@@ -72,10 +71,10 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
 template <>
 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
     Register Reg) {
+  assert(isDivergent(Reg));
   const auto &RegInfo = F.getRegInfo();
   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
-    if (markDivergent(UserInstr))
-      Worklist.push_back(&UserInstr);
+    markDivergent(UserInstr);
   }
 }
 
@@ -86,8 +85,11 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
   if (Instr.isTerminator())
     return;
   for (const MachineOperand &op : Instr.operands()) {
-    if (op.isReg() && op.isDef() && op.getReg().isVirtual())
-      pushUsers(op.getReg());
+    if (!op.isReg() || !op.isDef())
+      continue;
+    auto Reg = op.getReg();
+    if (isDivergent(Reg))
+      pushUsers(Reg);
   }
 }
 
@@ -128,8 +130,7 @@ void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
       if (DefCycle.contains(UserInstr.getParent()))
         continue;
-      if (markDivergent(UserInstr))
-        Worklist.push_back(&UserInstr);
+      markDivergent(UserInstr);
     }
   }
 }
index bae9717..6a0b5bb 100644 (file)
@@ -97,7 +97,6 @@ body:             |
 
 ...
 
-# FIXME :: BELOW INLINE ASM SHOULD BE DIVERGENT
 ---
 name:            asm_mixed_sgpr_vgpr
 registers:
@@ -116,7 +115,9 @@ body:             |
     ; CHECK-LABEL: MachineUniformityInfo for function: asm_mixed_sgpr_vgpr
     ; CHECK: DIVERGENT: %0:
     ; CHECK: DIVERGENT: %3:
+    ; CHECK-NOT: DIVERGENT: %1:
     ; CHECK: DIVERGENT: %2:
+    ; CHECK-NOT: DIVERGENT: %4:
     ; CHECK: DIVERGENT: %5:
     %0:_(s32) = COPY $vgpr0
     %6:_(p1) = G_IMPLICIT_DEF