[mlir][dataflow] Unify dependency management in AnalysisState.
authorZhixun Tan <phisiart@gmail.com>
Mon, 3 Jul 2023 19:20:18 +0000 (12:20 -0700)
committerJeff Niu <jeff@modular.com>
Mon, 3 Jul 2023 19:20:52 +0000 (12:20 -0700)
In the MLIR dataflow analysis framework, when an `AnalysisState` is updated, it's dependents are enqueued to be visited.

Currently, there are two ways dependents are managed:

* `AnalysisState::dependents` stores a list of dependents. `DataFlowSolver::propagateIfChanged()` reads this list and enqueues them to the worklist.

* `AnalysisState::onUpdate()` allows custom logic to enqueue more to the worklist. This is called by `DataFlowSolver::propagateIfChanged()`.

This cleanup diff consolidates the two into `AnalysisState::onUpdate()`. This way, `DataFlowSolver` does not need to know the detail about `AnalysisState::dependents`, and the logic of dependency management is entirely handled by `AnalysisState`.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D154170

mlir/include/mlir/Analysis/DataFlowFramework.h
mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
mlir/lib/Analysis/DataFlowFramework.cpp

index 9649f91..7b97ea4 100644 (file)
@@ -235,12 +235,6 @@ public:
   /// dependent work items to the back of the queue.
   void propagateIfChanged(AnalysisState *state, ChangeResult changed);
 
-  /// Add a dependency to an analysis state on a child analysis and program
-  /// point. If the state is updated, the child analysis must be invoked on the
-  /// given program point again.
-  void addDependency(AnalysisState *state, DataFlowAnalysis *analysis,
-                     ProgramPoint point);
-
 private:
   /// The solver's work queue. Work items can be inserted to the front of the
   /// queue to be processed greedily, speeding up computations that otherwise
@@ -294,13 +288,30 @@ public:
   /// Print the contents of the analysis state.
   virtual void print(raw_ostream &os) const = 0;
 
+  /// Add a dependency to this analysis state on a program point and an
+  /// analysis. If this state is updated, the analysis will be invoked on the
+  /// given program point again (in onUpdate()).
+  void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis);
+
 protected:
   /// This function is called by the solver when the analysis state is updated
-  /// to optionally enqueue more work items. For example, if a state tracks
-  /// dependents through the IR (e.g. use-def chains), this function can be
-  /// implemented to push those dependents on the worklist.
-  virtual void onUpdate(DataFlowSolver *solver) const {}
+  /// to enqueue more work items. For example, if a state tracks dependents
+  /// through the IR (e.g. use-def chains), this function can be implemented to
+  /// push those dependents on the worklist.
+  virtual void onUpdate(DataFlowSolver *solver) const {
+    for (const DataFlowSolver::WorkItem &item : dependents)
+      solver->enqueue(item);
+  }
+
+  /// The program point to which the state belongs.
+  ProgramPoint point;
+
+#if LLVM_ENABLE_ABI_BREAKING_CHECKS
+  /// When compiling with debugging, keep a name for the analysis state.
+  StringRef debugName;
+#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 
+private:
   /// The dependency relations originating from this analysis state. An entry
   /// `state -> (analysis, point)` is created when `analysis` queries `state`
   /// when updating `point`.
@@ -312,14 +323,6 @@ protected:
   /// Store the dependents on the analysis state for efficiency.
   SetVector<DataFlowSolver::WorkItem> dependents;
 
-  /// The program point to which the state belongs.
-  ProgramPoint point;
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
-  /// When compiling with debugging, keep a name for the analysis state.
-  StringRef debugName;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
   /// Allow the framework to access the dependents.
   friend class DataFlowSolver;
 };
index d681604..30a2850 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include <optional>
@@ -31,6 +32,8 @@ void Executable::print(raw_ostream &os) const {
 }
 
 void Executable::onUpdate(DataFlowSolver *solver) const {
+  AnalysisState::onUpdate(solver);
+
   if (auto *block = llvm::dyn_cast_if_present<Block *>(point)) {
     // Re-invoke the analyses on the block itself.
     for (DataFlowAnalysis *analysis : subscribers)
index f5cf866..3f2a69e 100644 (file)
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 
 using namespace mlir;
@@ -18,6 +19,8 @@ using namespace mlir::dataflow;
 //===----------------------------------------------------------------------===//
 
 void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const {
+  AnalysisState::onUpdate(solver);
+
   // Push all users of the value to the queue.
   for (Operation *user : point.get<Value>().getUsers())
     for (DataFlowAnalysis *analysis : useDefSubscribers)
index 47caf26..6f9168c 100644 (file)
@@ -30,6 +30,19 @@ GenericProgramPoint::~GenericProgramPoint() = default;
 
 AnalysisState::~AnalysisState() = default;
 
+void AnalysisState::addDependency(ProgramPoint dependent,
+                                  DataFlowAnalysis *analysis) {
+  auto inserted = dependents.insert({dependent, analysis});
+  (void)inserted;
+  DATAFLOW_DEBUG({
+    if (inserted) {
+      llvm::dbgs() << "Creating dependency between " << debugName << " of "
+                   << point << "\nand " << debugName << " on " << dependent
+                   << "\n";
+    }
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // ProgramPoint
 //===----------------------------------------------------------------------===//
@@ -97,26 +110,10 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
     DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
                                 << " of " << state->point << "\n"
                                 << "Value: " << *state << "\n");
-    for (const WorkItem &item : state->dependents)
-      enqueue(item);
     state->onUpdate(this);
   }
 }
 
-void DataFlowSolver::addDependency(AnalysisState *state,
-                                   DataFlowAnalysis *analysis,
-                                   ProgramPoint point) {
-  auto inserted = state->dependents.insert({point, analysis});
-  (void)inserted;
-  DATAFLOW_DEBUG({
-    if (inserted) {
-      llvm::dbgs() << "Creating dependency between " << state->debugName
-                   << " of " << state->point << "\nand " << analysis->debugName
-                   << " on " << point << "\n";
-    }
-  });
-}
-
 //===----------------------------------------------------------------------===//
 // DataFlowAnalysis
 //===----------------------------------------------------------------------===//
@@ -126,7 +123,7 @@ DataFlowAnalysis::~DataFlowAnalysis() = default;
 DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {}
 
 void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) {
-  solver.addDependency(state, this, point);
+  state->addDependency(point, this);
 }
 
 void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,