Generalize the analysis manager framework to work on any operation at any nesting.
authorRiver Riddle <riverriddle@google.com>
Wed, 28 Aug 2019 22:10:37 +0000 (15:10 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Aug 2019 22:11:17 +0000 (15:11 -0700)
The pass manager is moving towards being able to run on operations at arbitrary nesting. An operation may have both parent and child operations, and the AnalysisManager must be able to handle this generalization. The AnalysisManager class now contains generic 'getCachedParentAnalysis' and 'getChildAnalysis/getCachedChildAnalysis' functions to query analyses on parent/child operations. This removes the hard coded nesting relationship between Module/Function.

PiperOrigin-RevId: 266003636

mlir/g3doc/WritingAPass.md
mlir/include/mlir/Pass/AnalysisManager.h
mlir/include/mlir/Pass/Pass.h
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/unittests/Pass/AnalysisManagerTest.cpp

index 47e57df..61b2226 100644 (file)
@@ -119,12 +119,12 @@ analyses are not passes but free standing classes that are computed lazily
 on-demand and cached to avoid unnecessary recomputation. An analysis in MLIR
 must adhere to the following:
 
-*   Provide a valid constructor for a given IR unit.
-*   Must not modify the given IR unit.
+*   Provide a valid constructor taking an `Operation*`.
+*   Must not modify the given operation.
 
 Each of the base Pass classes provide utilities for querying and preserving
-analyses for the current IR being processed. Using the example passes defined
-above, let's see some examples:
+analyses for the current operation being processed. Using the example passes
+defined above, let's see some examples:
 
 ### Querying Analyses
 
@@ -140,16 +140,10 @@ above, let's see some examples:
     *   `getCachedFunctionAnalysis<>`
 
 ```c++
-/// An interesting function analysis.
-struct MyFunctionAnalysis {
-  // Compute this analysis with the provided function.
-  MyFunctionAnalysis(FuncOp function);
-};
-
-/// An interesting module analysis.
-struct MyModuleAnalysis {
-  // Compute this analysis with the provided module.
-  MyModuleAnalysis(ModuleOp module);
+/// An interesting analysis.
+struct MyAnalysis {
+  // Compute this analysis with the provided operation.
+  MyFunctionAnalysis(Operation *op);
 };
 
 void MyFunctionPass::runOnFunction() {
index ae98831..163ecf6 100644 (file)
@@ -18,7 +18,6 @@
 #ifndef MLIR_PASS_ANALYSISMANAGER_H
 #define MLIR_PASS_ANALYSISMANAGER_H
 
-#include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Pass/PassInstrumentation.h"
 #include "mlir/Support/LLVM.h"
@@ -91,9 +90,9 @@ template <typename AnalysisT> struct AnalysisModel : public AnalysisConcept {
   AnalysisT analysis;
 };
 
-/// This class represents a cache of analyses for a single IR unit. All
+/// This class represents a cache of analyses for a single operation. All
 /// computation, caching, and invalidation of analyses takes place here.
-template <typename IRUnitT> class AnalysisMap {
+class AnalysisMap {
   /// A mapping between an analysis id and an existing analysis instance.
   using ConceptMap =
       llvm::DenseMap<const AnalysisID *, std::unique_ptr<AnalysisConcept>>;
@@ -107,7 +106,7 @@ template <typename IRUnitT> class AnalysisMap {
   }
 
 public:
-  explicit AnalysisMap(IRUnitT ir) : ir(ir) {}
+  explicit AnalysisMap(Operation *ir) : ir(ir) {}
 
   /// Get an analysis for the current IR unit, computing it if necessary.
   template <typename AnalysisT> AnalysisT &getAnalysis(PassInstrumentor *pi) {
@@ -140,9 +139,8 @@ public:
     return {static_cast<AnalysisModel<AnalysisT> &>(*res->second).analysis};
   }
 
-  /// Returns the IR unit that this analysis map represents.
-  IRUnitT getIRUnit() { return ir; }
-  const IRUnitT getIRUnit() const { return ir; }
+  /// Returns the operation that this analysis map represents.
+  Operation *getOperation() const { return ir; }
 
   /// Clear any held analyses.
   void clear() { analyses.clear(); }
@@ -159,10 +157,27 @@ public:
   }
 
 private:
-  IRUnitT ir;
+  Operation *ir;
   ConceptMap analyses;
 };
 
+/// An analysis map that contains a map for the current operation, and a set of
+/// maps for any child operations.
+struct NestedAnalysisMap {
+  NestedAnalysisMap(Operation *op) : analyses(op) {}
+
+  /// Get the operation for this analysis map.
+  Operation *getOperation() const { return analyses.getOperation(); }
+
+  /// Invalidate any non preserved analyses.
+  void invalidate(const detail::PreservedAnalyses &pa);
+
+  /// The cached analyses for nested operations.
+  llvm::DenseMap<Operation *, std::unique_ptr<NestedAnalysisMap>> childAnalyses;
+
+  /// The analyses for the owning module.
+  detail::AnalysisMap analyses;
+};
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -170,124 +185,119 @@ private:
 //===----------------------------------------------------------------------===//
 class ModuleAnalysisManager;
 
-/// An analysis manager for a specific function instance. This class can only be
-/// constructed from a ModuleAnalysisManager instance.
-class FunctionAnalysisManager {
+/// This class represents an analysis manager for a particular operation
+/// instance. It is used to manage and cache analyses on the operation as well
+/// as those for child operations, via nested AnalysisManager instances
+/// accessible via 'slice'. This class is intended to be passed around by value,
+/// and cannot be constructed directly.
+class AnalysisManager {
+  using ParentPointerT = llvm::PointerUnion<const ModuleAnalysisManager *,
+                                            const AnalysisManager *>;
+
 public:
-  // Query for a cached analysis on the parent Module. The analysis may not
-  // exist and if it does it may be stale.
+  // Query for a cached analysis on the given parent operation. The analysis may
+  // not exist and if it does it may be out-of-date.
   template <typename AnalysisT>
   llvm::Optional<std::reference_wrapper<AnalysisT>>
-  getCachedModuleAnalysis() const;
+  getCachedParentAnalysis(Operation *parentOp) const {
+    ParentPointerT curParent = parent;
+    while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>()) {
+      if (parentAM->impl->getOperation() == parentOp)
+        return parentAM->getCachedAnalysis<AnalysisT>();
+      curParent = parentAM->parent;
+    }
+    return None;
+  }
 
-  // Query for the given analysis for the current function.
+  // Query for the given analysis for the current operation.
   template <typename AnalysisT> AnalysisT &getAnalysis() {
-    return impl->getAnalysis<AnalysisT>(getPassInstrumentor());
+    return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor());
   }
 
-  // Query for a cached entry of the given analysis on the current function.
+  // Query for a cached entry of the given analysis on the current operation.
   template <typename AnalysisT>
   llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
-    return impl->getCachedAnalysis<AnalysisT>();
+    return impl->analyses.getCachedAnalysis<AnalysisT>();
   }
 
-  /// Invalidate any non preserved analyses,
-  void invalidate(const detail::PreservedAnalyses &pa) {
-    // If all analyses were preserved, then there is nothing to do here.
-    if (pa.isAll())
-      return;
-    impl->invalidate(pa);
+  /// Query for a analysis of a child operation, constructing it if necessary.
+  template <typename AnalysisT> AnalysisT &getChildAnalysis(Operation *op) {
+    return slice(op).template getAnalysis<AnalysisT>();
+  }
+
+  /// Query for a cached analysis of a child operation, or return null.
+  template <typename AnalysisT>
+  llvm::Optional<std::reference_wrapper<AnalysisT>>
+  getCachedChildAnalysis(Operation *op) const {
+    assert(op->getParentOp() == impl->getOperation());
+    auto it = impl->childAnalyses.find(op);
+    if (it == impl->childAnalyses.end())
+      return llvm::None;
+    return it->second->analyses.getCachedAnalysis<AnalysisT>();
   }
 
+  /// Get an analysis manager for the given child operation.
+  AnalysisManager slice(Operation *op);
+
+  /// Invalidate any non preserved analyses,
+  void invalidate(const detail::PreservedAnalyses &pa) { impl->invalidate(pa); }
+
   /// Clear any held analyses.
-  void clear() { impl->clear(); }
+  void clear() {
+    impl->analyses.clear();
+    impl->childAnalyses.clear();
+  }
 
-  /// Returns a pass instrumentation object for the current function. This value
-  /// may be null.
+  /// Returns a pass instrumentation object for the current operation. This
+  /// value may be null.
   PassInstrumentor *getPassInstrumentor() const;
 
 private:
-  FunctionAnalysisManager(const ModuleAnalysisManager *parent,
-                          detail::AnalysisMap<FuncOp> *impl)
+  AnalysisManager(const AnalysisManager *parent,
+                  detail::NestedAnalysisMap *impl)
+      : parent(parent), impl(impl) {}
+  AnalysisManager(const ModuleAnalysisManager *parent,
+                  detail::NestedAnalysisMap *impl)
       : parent(parent), impl(impl) {}
 
-  /// A reference to the parent analysis manager.
-  const ModuleAnalysisManager *parent;
+  /// A reference to the parent analysis manager, or the top-level module
+  /// analysis manager.
+  llvm::PointerUnion<const ModuleAnalysisManager *, const AnalysisManager *>
+      parent;
 
-  /// A reference to the impl analysis map within the owning analysis manager.
-  detail::AnalysisMap<FuncOp> *impl;
+  /// A reference to the impl analysis map within the parent analysis manager.
+  detail::NestedAnalysisMap *impl;
 
   /// Allow access to the constructor.
   friend class ModuleAnalysisManager;
 };
 
-/// An analysis manager for a specific module instance.
+/// An analysis manager class specifically for the top-level module operation.
+/// This class contains the memory allocations for all nested analysis managers,
+/// and provides an anchor point. This is necessary because AnalysisManager is
+/// designed to be a thin wrapper around an existing analysis map instance.
 class ModuleAnalysisManager {
 public:
   ModuleAnalysisManager(ModuleOp module, PassInstrumentor *passInstrumentor)
-      : moduleAnalyses(module), passInstrumentor(passInstrumentor) {}
+      : analyses(module), passInstrumentor(passInstrumentor) {}
   ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
   ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
 
-  /// Query for the analysis of a function. The analysis is computed if it does
-  /// not exist.
-  template <typename AnalysisT>
-  AnalysisT &getFunctionAnalysis(FuncOp function) {
-    return slice(function).getAnalysis<AnalysisT>();
-  }
-
-  /// Query for a cached analysis of a child function, or return null.
-  template <typename AnalysisT>
-  llvm::Optional<std::reference_wrapper<AnalysisT>>
-  getCachedFunctionAnalysis(FuncOp function) const {
-    auto it = functionAnalyses.find(function);
-    if (it == functionAnalyses.end())
-      return llvm::None;
-    return it->second->getCachedAnalysis<AnalysisT>();
-  }
-
-  /// Query for the analysis for the module. The analysis is computed if it does
-  /// not exist.
-  template <typename AnalysisT> AnalysisT &getAnalysis() {
-    return moduleAnalyses.getAnalysis<AnalysisT>(getPassInstrumentor());
-  }
-
-  /// Query for a cached analysis for the module, or return null.
-  template <typename AnalysisT>
-  llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedAnalysis() const {
-    return moduleAnalyses.getCachedAnalysis<AnalysisT>();
-  }
-
-  /// Create an analysis slice for the given child function.
-  FunctionAnalysisManager slice(FuncOp function);
-
-  /// Invalidate any non preserved analyses.
-  void invalidate(const detail::PreservedAnalyses &pa);
-
   /// Returns a pass instrumentation object for the current module. This value
   /// may be null.
   PassInstrumentor *getPassInstrumentor() const { return passInstrumentor; }
 
-private:
-  /// The cached analyses for functions within the current module.
-  llvm::DenseMap<FuncOp, std::unique_ptr<detail::AnalysisMap<FuncOp>>>
-      functionAnalyses;
+  /// Returns an analysis manager for the current top-level module.
+  operator AnalysisManager() { return AnalysisManager(this, &analyses); }
 
+private:
   /// The analyses for the owning module.
-  detail::AnalysisMap<ModuleOp> moduleAnalyses;
+  detail::NestedAnalysisMap analyses;
 
   /// An optional instrumentation object.
   PassInstrumentor *passInstrumentor;
 };
 
-// Query for a cached analysis on the parent Module. The analysis may not exist
-// and if it does it may be stale.
-template <typename AnalysisT>
-llvm::Optional<std::reference_wrapper<AnalysisT>>
-FunctionAnalysisManager::getCachedModuleAnalysis() const {
-  return parent->getCachedAnalysis<AnalysisT>();
-}
-
 } // end namespace mlir
 
 #endif // MLIR_PASS_ANALYSISMANAGER_H
index 3a3444a..360eaaf 100644 (file)
@@ -18,6 +18,7 @@
 #ifndef MLIR_PASS_PASS_H
 #define MLIR_PASS_PASS_H
 
+#include "mlir/IR/Function.h"
 #include "mlir/Pass/AnalysisManager.h"
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Support/LogicalResult.h"
@@ -68,9 +69,8 @@ class ModulePassExecutor;
 
 /// The state for a single execution of a pass. This provides a unified
 /// interface for accessing and initializing necessary state for pass execution.
-template <typename IRUnitT, typename AnalysisManagerT>
-struct PassExecutionState {
-  PassExecutionState(IRUnitT ir, AnalysisManagerT &analysisManager)
+template <typename IRUnitT> struct PassExecutionState {
+  PassExecutionState(IRUnitT ir, AnalysisManager analysisManager)
       : irAndPassFailed(ir, false), analysisManager(analysisManager) {}
 
   /// The current IR unit being transformed and a bool for if the pass signaled
@@ -78,7 +78,7 @@ struct PassExecutionState {
   llvm::PointerIntPair<IRUnitT, 1, bool> irAndPassFailed;
 
   /// The analysis manager for the IR unit.
-  AnalysisManagerT &analysisManager;
+  AnalysisManager analysisManager;
 
   /// The set of preserved analyses for the current execution.
   detail::PreservedAnalyses preservedAnalyses;
@@ -89,8 +89,7 @@ struct PassExecutionState {
 /// not inherit from this class directly, and instead should use the CRTP
 /// FunctionPass class.
 class FunctionPassBase : public Pass {
-  using PassStateT =
-      detail::PassExecutionState<FuncOp, FunctionAnalysisManager>;
+  using PassStateT = detail::PassExecutionState<FuncOp>;
 
 public:
   static bool classof(const Pass *pass) {
@@ -119,14 +118,14 @@ protected:
   }
 
   /// Returns the current analysis manager.
-  FunctionAnalysisManager &getAnalysisManager() {
+  AnalysisManager getAnalysisManager() {
     return getPassState().analysisManager;
   }
 
 private:
   /// Forwarding function to execute this pass.
   LLVM_NODISCARD
-  LogicalResult run(FuncOp fn, FunctionAnalysisManager &fam);
+  LogicalResult run(FuncOp fn, AnalysisManager am);
 
   /// The current execution state for the pass.
   llvm::Optional<PassStateT> passState;
@@ -138,8 +137,7 @@ private:
 /// Pass to transform a module. Derived passes should not inherit from this
 /// class directly, and instead should use the CRTP ModulePass class.
 class ModulePassBase : public Pass {
-  using PassStateT =
-      detail::PassExecutionState<ModuleOp, ModuleAnalysisManager>;
+  using PassStateT = detail::PassExecutionState<ModuleOp>;
 
 public:
   static bool classof(const Pass *pass) {
@@ -165,14 +163,14 @@ protected:
   }
 
   /// Returns the current analysis manager.
-  ModuleAnalysisManager &getAnalysisManager() {
+  AnalysisManager getAnalysisManager() {
     return getPassState().analysisManager;
   }
 
 private:
   /// Forwarding function to execute this pass.
   LLVM_NODISCARD
-  LogicalResult run(ModuleOp module, ModuleAnalysisManager &mam);
+  LogicalResult run(ModuleOp module, AnalysisManager am);
 
   /// The current execution state for the pass.
   llvm::Optional<PassStateT> passState;
@@ -255,7 +253,8 @@ struct FunctionPass : public detail::PassModel<FuncOp, T, FunctionPassBase> {
   template <typename AnalysisT>
   llvm::Optional<std::reference_wrapper<AnalysisT>> getCachedModuleAnalysis() {
     return this->getAnalysisManager()
-        .template getCachedModuleAnalysis<AnalysisT>();
+        .template getCachedParentAnalysis<AnalysisT>(
+            this->getFunction().getParentOp());
   }
 
   /// A clone method to create a copy of this pass.
@@ -272,8 +271,7 @@ template <typename T>
 struct ModulePass : public detail::PassModel<ModuleOp, T, ModulePassBase> {
   /// Returns the analysis for a child function.
   template <typename AnalysisT> AnalysisT &getFunctionAnalysis(FuncOp f) {
-    return this->getAnalysisManager().template getFunctionAnalysis<AnalysisT>(
-        f);
+    return this->getAnalysisManager().template getChildAnalysis<AnalysisT>(f);
   }
 
   /// Returns an existing analysis for a child function if it exists.
@@ -281,7 +279,7 @@ struct ModulePass : public detail::PassModel<ModuleOp, T, ModulePassBase> {
   llvm::Optional<std::reference_wrapper<AnalysisT>>
   getCachedFunctionAnalysis(FuncOp f) {
     return this->getAnalysisManager()
-        .template getCachedFunctionAnalysis<AnalysisT>(f);
+        .template getCachedChildAnalysis<AnalysisT>(f);
   }
 };
 } // end namespace mlir
index 13f2738..0892aa0 100644 (file)
@@ -42,12 +42,12 @@ using namespace mlir::detail;
 void Pass::anchor() {}
 
 /// Forwarding function to execute this pass.
-LogicalResult FunctionPassBase::run(FuncOp fn, FunctionAnalysisManager &fam) {
+LogicalResult FunctionPassBase::run(FuncOp fn, AnalysisManager am) {
   // Initialize the pass state.
-  passState.emplace(fn, fam);
+  passState.emplace(fn, am);
 
   // Instrument before the pass has run.
-  auto pi = fam.getPassInstrumentor();
+  auto pi = am.getPassInstrumentor();
   if (pi)
     pi->runBeforePass(this, fn);
 
@@ -55,7 +55,7 @@ LogicalResult FunctionPassBase::run(FuncOp fn, FunctionAnalysisManager &fam) {
   runOnFunction();
 
   // Invalidate any non preserved analyses.
-  fam.invalidate(passState->preservedAnalyses);
+  am.invalidate(passState->preservedAnalyses);
 
   // Instrument after the pass has run.
   bool passFailed = passState->irAndPassFailed.getInt();
@@ -71,12 +71,12 @@ LogicalResult FunctionPassBase::run(FuncOp fn, FunctionAnalysisManager &fam) {
 }
 
 /// Forwarding function to execute this pass.
-LogicalResult ModulePassBase::run(ModuleOp module, ModuleAnalysisManager &mam) {
+LogicalResult ModulePassBase::run(ModuleOp module, AnalysisManager am) {
   // Initialize the pass state.
-  passState.emplace(module, mam);
+  passState.emplace(module, am);
 
   // Instrument before the pass has run.
-  auto pi = mam.getPassInstrumentor();
+  auto pi = am.getPassInstrumentor();
   if (pi)
     pi->runBeforePass(this, module);
 
@@ -84,7 +84,7 @@ LogicalResult ModulePassBase::run(ModuleOp module, ModuleAnalysisManager &mam) {
   runOnModule();
 
   // Invalidate any non preserved analyses.
-  mam.invalidate(passState->preservedAnalyses);
+  am.invalidate(passState->preservedAnalyses);
 
   // Instrument after the pass has run.
   bool passFailed = passState->irAndPassFailed.getInt();
@@ -111,20 +111,20 @@ FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
 
 /// Run all of the passes in this manager over the current function.
 LogicalResult detail::FunctionPassExecutor::run(FuncOp function,
-                                                FunctionAnalysisManager &fam) {
+                                                AnalysisManager am) {
   // Run each of the held passes.
   for (auto &pass : passes)
-    if (failed(pass->run(function, fam)))
+    if (failed(pass->run(function, am)))
       return failure();
   return success();
 }
 
 /// Run all of the passes in this manager over the current module.
 LogicalResult detail::ModulePassExecutor::run(ModuleOp module,
-                                              ModuleAnalysisManager &mam) {
+                                              AnalysisManager am) {
   // Run each of the held passes.
   for (auto &pass : passes)
-    if (failed(pass->run(module, mam)))
+    if (failed(pass->run(module, am)))
       return failure();
   return success();
 }
@@ -136,44 +136,44 @@ LogicalResult detail::ModulePassExecutor::run(ModuleOp module,
 /// Utility to run the given function and analysis manager on a provided
 /// function pass executor.
 static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, FuncOp func,
-                                         FunctionAnalysisManager &fam) {
+                                         AnalysisManager am) {
   // Run the function pipeline over the provided function.
-  auto result = fpe.run(func, fam);
+  auto result = fpe.run(func, am);
 
   // Clear out any computed function analyses. These analyses won't be used
   // any more in this pipeline, and this helps reduce the current working set
   // of memory. If preserving these analyses becomes important in the future
   // we can re-evalutate this.
-  fam.clear();
+  am.clear();
   return result;
 }
 
 /// Run the held function pipeline over all non-external functions within the
 /// module.
 void ModuleToFunctionPassAdaptor::runOnModule() {
-  ModuleAnalysisManager &mam = getAnalysisManager();
+  AnalysisManager am = getAnalysisManager();
   for (auto func : getModule().getOps<FuncOp>()) {
     // Skip external functions.
     if (func.isExternal())
       continue;
 
     // Run the held function pipeline over the current function.
-    auto fam = mam.slice(func);
-    if (failed(runFunctionPipeline(fpe, func, fam)))
+    auto childAM = am.slice(func);
+    if (failed(runFunctionPipeline(fpe, func, childAM)))
       return signalPassFailure();
 
-    // Clear out any computed function analyses. These analyses won't be used
+    // Clear out any computed child analyses. These analyses won't be used
     // any more in this pipeline, and this helps reduce the current working set
     // of memory. If preserving these analyses becomes important in the future
     // we can re-evalutate this.
-    fam.clear();
+    am.clear();
   }
 }
 
 // Run the held function pipeline synchronously across the functions within
 // the module.
 void ModuleToFunctionPassAdaptorParallel::runOnModule() {
-  ModuleAnalysisManager &mam = getAnalysisManager();
+  AnalysisManager am = getAnalysisManager();
 
   // Create the async executors if they haven't been created, or if the main
   // function pipeline has changed.
@@ -183,10 +183,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
   // Run a prepass over the module to collect the functions to execute a over.
   // This ensures that an analysis manager exists for each function, as well as
   // providing a queue of functions to execute over.
-  std::vector<std::pair<FuncOp, FunctionAnalysisManager>> funcAMPairs;
+  std::vector<std::pair<FuncOp, AnalysisManager>> funcAMPairs;
   for (auto func : getModule().getOps<FuncOp>())
     if (!func.isExternal())
-      funcAMPairs.emplace_back(func, mam.slice(func));
+      funcAMPairs.emplace_back(func, am.slice(func));
 
   // A parallel diagnostic handler that provides deterministic diagnostic
   // ordering.
@@ -253,8 +253,8 @@ PassManager::~PassManager() {}
 
 /// Run the passes within this manager on the provided module.
 LogicalResult PassManager::run(ModuleOp module) {
-  ModuleAnalysisManager mam(module, instrumentor.get());
-  return mpe->run(module, mam);
+  ModuleAnalysisManager am(module, instrumentor.get());
+  return mpe->run(module, am);
 }
 
 /// Disable support for multi-threading within the pass manager.
@@ -329,42 +329,53 @@ void PassManager::addInstrumentation(PassInstrumentation *pi) {
 // AnalysisManager
 //===----------------------------------------------------------------------===//
 
-/// Returns a pass instrumentation object for the current function.
-PassInstrumentor *FunctionAnalysisManager::getPassInstrumentor() const {
-  return parent->getPassInstrumentor();
+/// Returns a pass instrumentation object for the current operation.
+PassInstrumentor *AnalysisManager::getPassInstrumentor() const {
+  ParentPointerT curParent = parent;
+  while (auto *parentAM = curParent.dyn_cast<const AnalysisManager *>())
+    curParent = parentAM->parent;
+  return curParent.get<const ModuleAnalysisManager *>()->getPassInstrumentor();
 }
 
-/// Create an analysis slice for the given child function.
-FunctionAnalysisManager ModuleAnalysisManager::slice(FuncOp func) {
-  assert(func.getOperation()->getParentOp() == moduleAnalyses.getIRUnit() &&
-         "function has a different parent module");
-  auto it = functionAnalyses.find(func);
-  if (it == functionAnalyses.end()) {
-    it =
-        functionAnalyses.try_emplace(func, new AnalysisMap<FuncOp>(func)).first;
-  }
+/// Get an analysis manager for the given child operation.
+AnalysisManager AnalysisManager::slice(Operation *op) {
+  assert(op->getParentOp() == impl->getOperation() &&
+         "'op' has a different parent operation");
+  auto it = impl->childAnalyses.find(op);
+  if (it == impl->childAnalyses.end())
+    it = impl->childAnalyses
+             .try_emplace(op, std::make_unique<NestedAnalysisMap>(op))
+             .first;
   return {this, it->second.get()};
 }
 
 /// Invalidate any non preserved analyses.
-void ModuleAnalysisManager::invalidate(const detail::PreservedAnalyses &pa) {
+void detail::NestedAnalysisMap::invalidate(
+    const detail::PreservedAnalyses &pa) {
   // If all analyses were preserved, then there is nothing to do here.
   if (pa.isAll())
     return;
 
-  // Invalidate the module analyses directly.
-  moduleAnalyses.invalidate(pa);
+  // Invalidate the analyses for the current operation directly.
+  analyses.invalidate(pa);
 
-  // If no analyses were preserved, then just simply clear out the function
+  // If no analyses were preserved, then just simply clear out the child
   // analysis results.
   if (pa.isNone()) {
-    functionAnalyses.clear();
+    childAnalyses.clear();
     return;
   }
 
-  // Otherwise, invalidate each function analyses.
-  for (auto &analysisPair : functionAnalyses)
-    analysisPair.second->invalidate(pa);
+  // Otherwise, invalidate each child analysis map.
+  SmallVector<NestedAnalysisMap *, 8> mapsToInvalidate(1, this);
+  while (!mapsToInvalidate.empty()) {
+    auto *map = mapsToInvalidate.pop_back_val();
+    for (auto &analysisPair : map->childAnalyses) {
+      analysisPair.second->invalidate(pa);
+      if (!analysisPair.second->childAnalyses.empty())
+        mapsToInvalidate.push_back(analysisPair.second.get());
+    }
+  }
 }
 
 //===----------------------------------------------------------------------===//
index bb482a2..aa60cfb 100644 (file)
@@ -62,7 +62,7 @@ public:
   FunctionPassExecutor(const FunctionPassExecutor &rhs);
 
   /// Run the executor on the given function.
-  LogicalResult run(FuncOp function, FunctionAnalysisManager &fam);
+  LogicalResult run(FuncOp function, AnalysisManager am);
 
   /// Add a pass to the current executor. This takes ownership over the provided
   /// pass pointer.
@@ -92,7 +92,7 @@ public:
   ModulePassExecutor &operator=(const ModulePassExecutor &) = delete;
 
   /// Run the executor on the given module.
-  LogicalResult run(ModuleOp module, ModuleAnalysisManager &mam);
+  LogicalResult run(ModuleOp module, AnalysisManager am);
 
   /// Add a pass to the current executor. This takes ownership over the provided
   /// pass pointer.
index 4c7da8e..d55c47d 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "mlir/Pass/AnalysisManager.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -25,12 +26,10 @@ using namespace mlir::detail;
 namespace {
 /// Minimal class definitions for two analyses.
 struct MyAnalysis {
-  MyAnalysis(FuncOp) {}
-  MyAnalysis(ModuleOp) {}
+  MyAnalysis(Operation *) {}
 };
 struct OtherAnalysis {
-  OtherAnalysis(FuncOp) {}
-  OtherAnalysis(ModuleOp) {}
+  OtherAnalysis(Operation *) {}
 };
 
 TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
@@ -39,18 +38,19 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
   // Test fine grain invalidation of the module analysis manager.
   OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
   ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
 
   // Query two different analyses, but only preserve one before invalidating.
-  mam.getAnalysis<MyAnalysis>();
-  mam.getAnalysis<OtherAnalysis>();
+  am.getAnalysis<MyAnalysis>();
+  am.getAnalysis<OtherAnalysis>();
 
   detail::PreservedAnalyses pa;
   pa.preserve<MyAnalysis>();
-  mam.invalidate(pa);
+  am.invalidate(pa);
 
   // Check that only MyAnalysis is preserved.
-  EXPECT_TRUE(mam.getCachedAnalysis<MyAnalysis>().hasValue());
-  EXPECT_FALSE(mam.getCachedAnalysis<OtherAnalysis>().hasValue());
+  EXPECT_TRUE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+  EXPECT_FALSE(am.getCachedAnalysis<OtherAnalysis>().hasValue());
 }
 
 TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
@@ -66,7 +66,8 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
 
   // Test fine grain invalidation of the function analysis manager.
   ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
-  FunctionAnalysisManager fam = mam.slice(func1);
+  AnalysisManager am = mam;
+  AnalysisManager fam = am.slice(func1);
 
   // Query two different analyses, but only preserve one before invalidating.
   fam.getAnalysis<MyAnalysis>();
@@ -95,18 +96,22 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
   // Test fine grain invalidation of a function analysis from within a module
   // analysis manager.
   ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  // Check that the analysis cache is initially empty.
+  EXPECT_FALSE(am.getCachedChildAnalysis<MyAnalysis>(func1).hasValue());
 
   // Query two different analyses, but only preserve one before invalidating.
-  mam.getFunctionAnalysis<MyAnalysis>(func1);
-  mam.getFunctionAnalysis<OtherAnalysis>(func1);
+  am.getChildAnalysis<MyAnalysis>(func1);
+  am.getChildAnalysis<OtherAnalysis>(func1);
 
   detail::PreservedAnalyses pa;
   pa.preserve<MyAnalysis>();
-  mam.invalidate(pa);
+  am.invalidate(pa);
 
   // Check that only MyAnalysis is preserved.
-  EXPECT_TRUE(mam.getCachedFunctionAnalysis<MyAnalysis>(func1).hasValue());
-  EXPECT_FALSE(mam.getCachedFunctionAnalysis<OtherAnalysis>(func1).hasValue());
+  EXPECT_TRUE(am.getCachedChildAnalysis<MyAnalysis>(func1).hasValue());
+  EXPECT_FALSE(am.getCachedChildAnalysis<OtherAnalysis>(func1).hasValue());
 }
 
 } // end namespace