Refactor the pass manager to support operations other than FuncOp/ModuleOp.
authorRiver Riddle <riverriddle@google.com>
Tue, 3 Sep 2019 02:24:47 +0000 (19:24 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 3 Sep 2019 02:25:26 +0000 (19:25 -0700)
This change generalizes the structure of the pass manager to allow arbitrary nesting pass managers for other operations, at any level. The only user visible change to existing code is the fact that a PassManager must now provide an MLIRContext on construction. A new class `OpPassManager` has been added that represents a pass manager on a specific operation type. `PassManager` will remain the top-level entry point into the pipeline, with OpPassManagers being nested underneath. OpPassManagers will still be implicitly nested if the operation type on the pass differs from the pass manager. To explicitly build a pipeline, the 'nest' methods on OpPassManager may be used:

// Pass manager for the top-level module.
PassManager pm(ctx);

// Nest a pipeline operating on FuncOp.
OpPassManager &fpm = pm.nest<FuncOp>();
fpm.addPass(...);

// Nest a pipeline under the FuncOp pipeline that operates on spirv::ModuleOp
OpPassManager &spvModulePM = pm.nest<spirv::ModuleOp>();

// Nest a pipeline on FuncOps inside of the spirv::ModuleOp.
OpPassManager &spvFuncPM = spvModulePM.nest<FuncOp>();

To help accomplish this a new general OperationPass is added that operates on opaque Operations. This pass can be inserted in a pass manager of any type to operate on any operation opaquely. An example of this opaque OperationPass is a VerifierPass, that simply runs the verifier opaquely on the current operation.

/// Pass to verify an operation and signal failure if necessary.
class VerifierPass : public OperationPass<VerifierPass> {
  void runOnOperation() override {
    Operation *op = getOperation();
    if (failed(verify(op)))
      signalPassFailure();
    markAllAnalysesPreserved();
  }
};

PiperOrigin-RevId: 266840344

17 files changed:
mlir/bindings/python/pybind.cpp
mlir/examples/Linalg/Linalg1/include/linalg1/Common.h
mlir/examples/Linalg/Linalg3/Example.cpp
mlir/examples/Linalg/Linalg4/Example.cpp
mlir/examples/toy/Ch4/toyc.cpp
mlir/examples/toy/Ch5/toyc.cpp
mlir/g3doc/Tutorials/Toy/Ch-4.md
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassManager.h
mlir/lib/Pass/IRPrinting.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/Pass/PassTiming.cpp
mlir/lib/Support/JitRunner.cpp
mlir/lib/Support/MlirOptMain.cpp
mlir/test/Pass/pass-timing.mlir
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp

index 4003718..3e4f5db 100644 (file)
@@ -193,7 +193,7 @@ struct PythonMLIRModule {
   PythonAttribute boolAttr(bool value);
 
   void compile() {
-    PassManager manager;
+    PassManager manager(module->getContext());
     manager.addPass(mlir::createCanonicalizerPass());
     manager.addPass(mlir::createCSEPass());
     manager.addPass(mlir::createLowerAffinePass());
index 8bedf51..b24a2af 100644 (file)
@@ -71,8 +71,9 @@ inline mlir::FuncOp makeFunction(mlir::ModuleOp module, llvm::StringRef name,
 }
 
 /// A basic pass manager pre-populated with cleanup passes.
-inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
-  std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager());
+inline std::unique_ptr<mlir::PassManager>
+cleanupPassManager(mlir::MLIRContext *ctx) {
+  std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager(ctx));
   pm->addPass(mlir::createCanonicalizerPass());
   pm->addPass(mlir::createSimplifyAffineStructuresPass());
   pm->addPass(mlir::createCSEPass());
@@ -92,7 +93,7 @@ inline void cleanupAndPrintFunction(mlir::FuncOp f) {
       printToOuts = false;
     }
   };
-  auto pm = cleanupPassManager();
+  auto pm = cleanupPassManager(f.getContext());
   check(mlir::verify(f.getParentOfType<mlir::ModuleOp>()));
   check(pm->run(f.getParentOfType<mlir::ModuleOp>()));
   if (printToOuts)
index e596dda..97db6d3 100644 (file)
@@ -171,7 +171,7 @@ TEST_FUNC(matmul_as_matvec_as_affine) {
   lowerToFinerGrainedTensorContraction(f);
   composeSliceOps(f);
   lowerToLoops(f);
-  PassManager pm;
+  PassManager pm(&context);
   pm.addPass(createLowerLinalgLoadStorePass());
   if (succeeded(pm.run(f.getParentOfType<mlir::ModuleOp>())))
     cleanupAndPrintFunction(f);
index 405eb68..6f2a23a 100644 (file)
@@ -67,7 +67,7 @@ TEST_FUNC(matmul_tiled_loops) {
   OwningModuleRef module = ModuleOp::create(UnknownLoc::get(&context));
   mlir::FuncOp f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops");
   lowerToTiledLoops(f, {8, 9});
-  PassManager pm;
+  PassManager pm(&context);
   pm.addPass(createLowerLinalgLoadStorePass());
   if (succeeded(pm.run(f.getParentOfType<mlir::ModuleOp>())))
     cleanupAndPrintFunction(f);
index 9e7a8a3..1c084e0 100644 (file)
@@ -81,7 +81,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
 }
 
 mlir::LogicalResult optimize(mlir::ModuleOp module) {
-  mlir::PassManager pm;
+  mlir::PassManager pm(module.getContext());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(createShapeInferencePass());
   pm.addPass(mlir::createCanonicalizerPass());
index a21eda7..6600ff6 100644 (file)
@@ -104,7 +104,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
 }
 
 mlir::LogicalResult optimize(mlir::ModuleOp module) {
-  mlir::PassManager pm;
+  mlir::PassManager pm(module.getContext());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(createShapeInferencePass());
   pm.addPass(mlir::createCanonicalizerPass());
@@ -117,7 +117,7 @@ mlir::LogicalResult optimize(mlir::ModuleOp module) {
 }
 
 mlir::LogicalResult lowerDialect(mlir::ModuleOp module, bool OnlyLinalg) {
-  mlir::PassManager pm;
+  mlir::PassManager pm(module.getContext());
   pm.addPass(createEarlyLoweringPass());
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(mlir::createCSEPass());
index 1551e12..c064942 100644 (file)
@@ -121,7 +121,7 @@ update our main file, `toyc.cpp`, to add an optimization pipeline. In MLIR, the
 optimizations are ran through a `PassManager` in a similar way to LLVM:
 
 ```c++
-mlir::PassManager pm;
+mlir::PassManager pm(ctx);
 pm.addPass(mlir::createCanonicalizerPass());
 pm.run(&module);
 ```
index 79ede0e..fc440d5 100644 (file)
@@ -26,9 +26,6 @@
 
 namespace mlir {
 namespace detail {
-class FunctionPassExecutor;
-class ModulePassExecutor;
-
 /// The state for a single execution of a pass. This provides a unified
 /// interface for accessing and initializing necessary state for pass execution.
 struct PassExecutionState {
@@ -68,11 +65,13 @@ public:
   /// Returns the derived pass name.
   virtual StringRef getName() = 0;
 
-  /// Returns the name of the operation that this pass operates on.
-  StringRef getOpName() const { return opName; }
+  /// Returns the name of the operation that this pass operates on, or None if
+  /// this is a generic OperationPass.
+  llvm::Optional<StringRef> getOpName() const { return opName; }
 
 protected:
-  Pass(const PassID *passID, StringRef opName)
+  explicit Pass(const PassID *passID,
+                llvm::Optional<StringRef> opName = llvm::None)
       : passID(passID), opName(opName) {}
 
   /// Returns the current pass state.
@@ -112,15 +111,15 @@ private:
   /// Represents a unique identifier for the pass.
   const PassID *passID;
 
-  /// The name of the operation that this pass operates on.
-  StringRef opName;
+  /// The name of the operation that this pass operates on, or None if this is a
+  /// generic OperationPass.
+  llvm::Optional<StringRef> opName;
 
   /// The current execution state for the pass.
   llvm::Optional<detail::PassExecutionState> passState;
 
   /// Allow access to 'clone' and 'run'.
-  friend detail::FunctionPassExecutor;
-  friend detail::ModulePassExecutor;
+  friend class OpPassManager;
 };
 
 //===----------------------------------------------------------------------===//
@@ -138,7 +137,8 @@ public:
   }
 
 protected:
-  PassModel(StringRef opName) : BasePassT(PassID::getID<PassT>(), opName) {}
+  explicit PassModel(llvm::Optional<StringRef> opName = llvm::None)
+      : BasePassT(PassID::getID<PassT>(), opName) {}
 
   /// Signal that some invariant was broken when running. The IR is allowed to
   /// be in an invalid state.
@@ -227,11 +227,24 @@ public:
   }
 };
 
+/// Pass to transform an operation.
+///
+/// Operation passes must not:
+///   - modify any other operations within the parent region, as other threads
+///     may be manipulating them concurrently.
+///   - modify any state within the parent operation, this includes adding
+///     additional operations.
+///
+/// Derived function passes are expected to provide the following:
+///   - A 'void runOnOperation()' method.
+template <typename T>
+struct OperationPass : public detail::PassModel<T, Pass> {};
+
 /// Pass to transform an operation of a specific type.
 ///
 /// Operation passes must not:
-///   - read or modify any other operations within the parent region, as
-///     other threads may be manipulating them concurrently.
+///   - modify any other operations within the parent region, as other threads
+///     may be manipulating them concurrently.
 ///   - modify any state within the parent operation, this includes adding
 ///     additional operations.
 ///
index 888d903..b240e5b 100644 (file)
@@ -26,21 +26,74 @@ class Any;
 } // end namespace llvm
 
 namespace mlir {
-class FuncOp;
+class AnalysisManager;
+class MLIRContext;
 class ModuleOp;
-template <typename OpT> class OpPassBase;
-using FunctionPassBase = OpPassBase<FuncOp>;
-using ModulePassBase = OpPassBase<ModuleOp>;
-
+class OperationName;
+class Operation;
 class Pass;
 class PassInstrumentation;
 class PassInstrumentor;
 
 namespace detail {
-class PassExecutor;
-class ModulePassExecutor;
+struct OpPassManagerImpl;
 } // end namespace detail
 
+//===----------------------------------------------------------------------===//
+// OpPassManager
+//===----------------------------------------------------------------------===//
+
+/// This class represents a pass manager that runs passes on a specific
+/// operation type. This class is not constructed directly, but nested within
+/// other OpPassManagers or the top-level PassManager.
+class OpPassManager {
+public:
+  OpPassManager(OpPassManager &&) = default;
+  OpPassManager(const OpPassManager &rhs);
+  ~OpPassManager();
+
+  /// Run the held passes over the given operation.
+  LogicalResult run(Operation *op, AnalysisManager am);
+
+  /// Nest a new operation pass manager for the given operation kind under this
+  /// pass manager.
+  OpPassManager &nest(const OperationName &nestedName);
+  OpPassManager &nest(StringRef nestedName);
+  template <typename OpT> OpPassManager &nest() {
+    return nest(OpT::getOperationName());
+  }
+
+  /// Add the given pass to this pass manager. The pass must either be an opaque
+  /// `OperationPass`, or an `OpPass` that operates on operations of the same
+  /// type as this pass manager.
+  void addPass(std::unique_ptr<Pass> pass);
+
+  /// Returns the number of passes held by this manager.
+  size_t size() const;
+
+  /// Return an instance of the context.
+  MLIRContext *getContext() const;
+
+  /// Return the operation name that this pass manager operates on.
+  const OperationName &getOpName() const;
+
+private:
+  OpPassManager(OperationName name, bool disableThreads, bool verifyPasses);
+
+  /// Returns the internal implementation instance.
+  detail::OpPassManagerImpl &getImpl();
+
+  /// A pointer to an internal implementation instance.
+  std::unique_ptr<detail::OpPassManagerImpl> impl;
+
+  /// Allow access to the constructor.
+  friend class PassManager;
+};
+
+//===----------------------------------------------------------------------===//
+// PassManager
+//===----------------------------------------------------------------------===//
+
 /// An enum describing the different display modes for the pass timing
 /// information within the pass manager.
 enum class PassTimingDisplayMode {
@@ -58,7 +111,7 @@ enum class PassTimingDisplayMode {
 class PassManager {
 public:
   // If verifyPasses is true, the verifier is run after each pass.
-  PassManager(bool verifyPasses = true);
+  PassManager(MLIRContext *ctx, bool verifyPasses = true);
   ~PassManager();
 
   /// Run the passes within this manager on the provided module.
@@ -72,18 +125,20 @@ public:
   // Pipeline Building
   //===--------------------------------------------------------------------===//
 
+  /// Allow converting to the impl OpPassManager.
+  operator OpPassManager &() { return opPassManager; }
+
   /// Add an opaque pass pointer to the current manager. This takes ownership
   /// over the provided pass pointer.
   void addPass(std::unique_ptr<Pass> pass);
 
-  /// Add a module pass to the current manager. This takes ownership over the
-  /// provided pass pointer.
-  void addPass(std::unique_ptr<ModulePassBase> pass);
-
-  /// Add a function pass to the current manager. This takes ownership over the
-  /// provided pass pointer. This will automatically create a function pass
-  /// executor if necessary.
-  void addPass(std::unique_ptr<FunctionPassBase> pass);
+  /// Allow nesting other operation pass managers.
+  OpPassManager &nest(const OperationName &nestedName) {
+    return opPassManager.nest(nestedName);
+  }
+  template <typename OpT> OpPassManager &nest() {
+    return opPassManager.nest<OpT>();
+  }
 
   //===--------------------------------------------------------------------===//
   // Instrumentations
@@ -113,21 +168,12 @@ public:
       PassTimingDisplayMode displayMode = PassTimingDisplayMode::Pipeline);
 
 private:
-  /// A stack of nested pass executors on sub-module IR units, e.g. function.
-  llvm::SmallVector<detail::PassExecutor *, 1> nestedExecutorStack;
-
-  /// The top level module pass executor.
-  std::unique_ptr<detail::ModulePassExecutor> mpe;
-
-  /// Flag that specifies if the IR should be verified after each pass has run.
-  bool verifyPasses : 1;
+  /// The top level pass manager instance.
+  OpPassManager opPassManager;
 
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;
 
-  /// Flag that specifies if multi-threading is disabled.
-  bool disableThreads : 1;
-
   /// A manager for pass instrumentations.
   std::unique_ptr<PassInstrumentor> instrumentor;
 };
index bc66197..5a83965 100644 (file)
@@ -60,7 +60,7 @@ private:
 
 /// Returns true if the given pass is hidden from IR printing.
 static bool isHiddenPass(Pass *pass) {
-  return isAdaptorPass(pass) || isVerifierPass(pass);
+  return isAdaptorPass(pass) || isa<VerifierPass>(pass);
 }
 
 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out) {
index e208e20..18c4de4 100644 (file)
@@ -23,6 +23,7 @@
 #include "PassDetail.h"
 #include "mlir/Analysis/Verifier.h"
 #include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Pass/PassManager.h"
 #include "llvm/Support/CommandLine.h"
@@ -70,47 +71,155 @@ LogicalResult Pass::run(Operation *op, AnalysisManager am) {
 }
 
 //===----------------------------------------------------------------------===//
-// PassExecutor
+// Verifier Passes
 //===----------------------------------------------------------------------===//
 
-FunctionPassExecutor::FunctionPassExecutor(const FunctionPassExecutor &rhs)
-    : PassExecutor(Kind::FunctionExecutor) {
-  for (auto &pass : rhs.passes)
-    addPass(cast<FunctionPassBase>(pass->clone()));
+void VerifierPass::runOnOperation() {
+  if (failed(verify(getOperation())))
+    signalPassFailure();
+  markAllAnalysesPreserved();
 }
 
-/// Run all of the passes in this manager over the current function.
-LogicalResult detail::FunctionPassExecutor::run(FuncOp function,
-                                                AnalysisManager am) {
-  // Run each of the held passes.
-  for (auto &pass : passes)
-    if (failed(pass->run(function, am)))
-      return failure();
-  return success();
+//===----------------------------------------------------------------------===//
+// OpPassManager
+//===----------------------------------------------------------------------===//
+
+namespace mlir {
+namespace detail {
+struct OpPassManagerImpl {
+  OpPassManagerImpl(OperationName name, bool disableThreads, bool verifyPasses)
+      : name(name), disableThreads(disableThreads), verifyPasses(verifyPasses) {
+  }
+
+  /// Returns the pass manager instance corresponding to the last pass added
+  /// if that pass was a PassAdaptor.
+  OpPassManager *getLastNestedPM() {
+    if (passes.empty())
+      return nullptr;
+    auto lastPassIt = passes.rbegin();
+
+    // If this pass was a verifier, skip it as it is opaque to ordering for
+    // pipeline construction.
+    if (isa<VerifierPass>(*lastPassIt))
+      ++lastPassIt;
+
+    // Get the internal pass manager if this pass is an adaptor.
+    if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(lastPassIt->get()))
+      return &adaptor->getPassManager();
+    if (auto *adaptor = dyn_cast<OpToOpPassAdaptorParallel>(lastPassIt->get()))
+      return &adaptor->getPassManager();
+    return nullptr;
+  }
+
+  /// The name of the operation that passes of this pass manager operate on.
+  OperationName name;
+
+  /// Flag to disable multi-threading of passes.
+  bool disableThreads : 1;
+
+  /// Flag that specifies if the IR should be verified after each pass has run.
+  bool verifyPasses : 1;
+
+  /// The set of passes to run as part of this pass manager.
+  std::vector<std::unique_ptr<Pass>> passes;
+};
+} // end namespace detail
+} // end namespace mlir
+
+OpPassManager::OpPassManager(OperationName name, bool disableThreads,
+                             bool verifyPasses)
+    : impl(new OpPassManagerImpl(name, disableThreads, verifyPasses)) {
+  assert(name.getAbstractOperation() &&
+         "OpPassManager can only operate on registered operations");
+  assert(name.getAbstractOperation()->hasProperty(
+             OperationProperty::IsolatedFromAbove) &&
+         "OpPassManager only supports operating on operations marked as "
+         "'IsolatedFromAbove'");
+}
+OpPassManager::OpPassManager(const OpPassManager &rhs)
+    : impl(new OpPassManagerImpl(rhs.impl->name, rhs.impl->disableThreads,
+                                 rhs.impl->verifyPasses)) {
+  for (auto &pass : rhs.impl->passes)
+    impl->passes.emplace_back(pass->clone());
 }
 
-/// Run all of the passes in this manager over the current module.
-LogicalResult detail::ModulePassExecutor::run(ModuleOp module,
-                                              AnalysisManager am) {
+OpPassManager::~OpPassManager() {}
+
+/// Run all of the passes in this manager over the current operation.
+LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
   // Run each of the held passes.
-  for (auto &pass : passes)
-    if (failed(pass->run(module, am)))
+  for (auto &pass : impl->passes)
+    if (failed(pass->run(op, am)))
       return failure();
   return success();
 }
 
+/// Nest a new operation pass manager for the given operation kind under this
+/// pass manager.
+OpPassManager &OpPassManager::nest(const OperationName &nestedName) {
+  // Check to see if an existing nested pass manager already exists.
+  if (auto *nestedPM = impl->getLastNestedPM()) {
+    if (nestedPM->getOpName() == nestedName)
+      return *nestedPM;
+  }
+
+  std::unique_ptr<OpPassManager> nested(
+      new OpPassManager(nestedName, impl->disableThreads, impl->verifyPasses));
+  auto &nestedRef = *nested;
+
+  /// Create an executor adaptor for this pass. If multi-threading is disabled,
+  /// then create a synchronous adaptor.
+  if (impl->disableThreads || !llvm::llvm_is_multithreaded())
+    addPass(std::make_unique<OpToOpPassAdaptor>(std::move(nested)));
+  else
+    addPass(std::make_unique<OpToOpPassAdaptorParallel>(std::move(nested)));
+  return nestedRef;
+}
+OpPassManager &OpPassManager::nest(StringRef nestedName) {
+  return nest(OperationName(nestedName, getContext()));
+}
+
+/// Add the given pass to this pass manager. The pass must either be an opaque
+/// `OperationPass`, or an `OpPass` that operates on operations of the same
+/// type as this pass manager.
+void OpPassManager::addPass(std::unique_ptr<Pass> pass) {
+  // If this pass runs on a different operation than this pass manager, then
+  // implicitly nest a pass manager for this operation.
+  auto passOpName = pass->getOpName();
+  if (passOpName && passOpName != impl->name.getStringRef())
+    return nest(*passOpName).addPass(std::move(pass));
+
+  impl->passes.emplace_back(std::move(pass));
+  if (impl->verifyPasses)
+    impl->passes.emplace_back(std::make_unique<VerifierPass>());
+}
+
+/// Returns the number of passes held by this manager.
+size_t OpPassManager::size() const { return impl->passes.size(); }
+
+/// Returns the internal implementation instance.
+OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
+
+/// Return an instance of the context.
+MLIRContext *OpPassManager::getContext() const {
+  return impl->name.getAbstractOperation()->dialect.getContext();
+}
+
+/// Return the operation name that this pass manager operates on.
+const OperationName &OpPassManager::getOpName() const { return impl->name; }
+
 //===----------------------------------------------------------------------===//
-// ModuleToFunctionPassAdaptor
+// OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
 
-/// Utility to run the given function and analysis manager on a provided
-/// function pass executor.
-static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, FuncOp func,
-                                         AnalysisManager am) {
-  // Run the function pipeline over the provided function.
-  auto result = fpe.run(func, am);
+/// Utility to run the given operation and analysis manager on a provided op
+/// pass manager.
+static LogicalResult runPipeline(OpPassManager &pm, Operation *op,
+                                 AnalysisManager am) {
+  // Run the pipeline over the provided operation.
+  auto result = pm.run(op, am);
 
-  // Clear out any computed function analyses. These analyses won't be used
+  // Clear out any computed operation analyses. These analyses won't be used
   // any more in this pipeline, and this helps reduce the current working set
   // of memory. If preserving these analyses becomes important in the future
   // we can re-evalutate this.
@@ -118,72 +227,82 @@ static LogicalResult runFunctionPipeline(FunctionPassExecutor &fpe, FuncOp func,
   return result;
 }
 
-/// Run the held function pipeline over all non-external functions within the
-/// module.
-void ModuleToFunctionPassAdaptor::runOnModule() {
-  AnalysisManager am = getAnalysisManager();
-  for (auto func : getModule().getOps<FuncOp>()) {
-    // Skip external functions.
-    if (func.isExternal())
-      continue;
-
-    // Run the held function pipeline over the current function.
-    auto childAM = am.slice(func);
-    if (failed(runFunctionPipeline(fpe, func, childAM)))
-      return signalPassFailure();
-
-    // Clear out any computed child analyses. These analyses won't be used
-    // any more in this pipeline, and this helps reduce the current working set
-    // of memory. If preserving these analyses becomes important in the future
-    // we can re-evalutate this.
-    am.clear();
+OpToOpPassAdaptor::OpToOpPassAdaptor(std::unique_ptr<OpPassManager> mgr)
+    : mgr(std::move(mgr)) {}
+OpToOpPassAdaptor::OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs)
+    : mgr(new OpPassManager(*rhs.mgr)) {}
+
+/// Run the held pipeline over all nested operations.
+void OpToOpPassAdaptor::runOnOperation() {
+  auto am = getAnalysisManager();
+  for (auto &region : getOperation()->getRegions()) {
+    for (auto &block : region) {
+      for (auto &op : block) {
+        // Run the held pipeline over the current operation.
+        if (op.getName() == mgr->getOpName() &&
+            failed(runPipeline(*mgr, &op, am.slice(&op))))
+          return signalPassFailure();
+      }
+    }
   }
 }
 
-// Run the held function pipeline synchronously across the functions within
-// the module.
-void ModuleToFunctionPassAdaptorParallel::runOnModule() {
+OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(
+    std::unique_ptr<OpPassManager> mgr)
+    : mgr(std::move(mgr)) {}
+OpToOpPassAdaptorParallel::OpToOpPassAdaptorParallel(
+    const OpToOpPassAdaptorParallel &rhs)
+    : mgr(std::make_unique<OpPassManager>(*rhs.mgr)) {}
+
+// Run the held pipeline asynchronously across the functions within the module.
+void OpToOpPassAdaptorParallel::runOnOperation() {
   AnalysisManager am = getAnalysisManager();
 
   // Create the async executors if they haven't been created, or if the main
-  // function pipeline has changed.
-  if (asyncExecutors.empty() || asyncExecutors.front().size() != fpe.size())
-    asyncExecutors = {llvm::hardware_concurrency(), fpe};
-
-  // Run a prepass over the module to collect the functions to execute a over.
-  // This ensures that an analysis manager exists for each function, as well as
-  // providing a queue of functions to execute over.
-  std::vector<std::pair<FuncOp, AnalysisManager>> funcAMPairs;
-  for (auto func : getModule().getOps<FuncOp>())
-    if (!func.isExternal())
-      funcAMPairs.emplace_back(func, am.slice(func));
+  // pipeline has changed.
+  if (asyncExecutors.empty() || asyncExecutors.front().size() != mgr->size())
+    asyncExecutors = {llvm::hardware_concurrency(), *mgr};
+
+  // Run a prepass over the module to collect the operations to execute over.
+  // This ensures that an analysis manager exists for each operation, as well as
+  // providing a queue of operations to execute over.
+  std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
+  for (auto &region : getOperation()->getRegions()) {
+    for (auto &block : region) {
+      for (auto &op : block) {
+        // Add this operation iff the name matches the current pass manager.
+        if (op.getName() == mgr->getOpName())
+          opAMPairs.emplace_back(&op, am.slice(&op));
+      }
+    }
+  }
 
   // A parallel diagnostic handler that provides deterministic diagnostic
   // ordering.
   ParallelDiagnosticHandler diagHandler(&getContext());
 
-  // An index for the current function/analysis manager pair.
-  std::atomic<unsigned> funcIt(0);
+  // An index for the current operation/analysis manager pair.
+  std::atomic<unsigned> opIt(0);
 
   // An atomic failure variable for the async executors.
   std::atomic<bool> passFailed(false);
   llvm::parallel::for_each(
       llvm::parallel::par, asyncExecutors.begin(),
       std::next(asyncExecutors.begin(),
-                std::min(asyncExecutors.size(), funcAMPairs.size())),
-      [&](FunctionPassExecutor &executor) {
-        for (auto e = funcAMPairs.size(); !passFailed && funcIt < e;) {
-          // Get the next available function index.
-          unsigned nextID = funcIt++;
+                std::min(asyncExecutors.size(), opAMPairs.size())),
+      [&](OpPassManager &pm) {
+        for (auto e = opAMPairs.size(); !passFailed && opIt < e;) {
+          // Get the next available operation index.
+          unsigned nextID = opIt++;
           if (nextID >= e)
             break;
 
-          // Set the function id for this thread in the diagnostic handler.
+          // Set the order id for this thread in the diagnostic handler.
           diagHandler.setOrderIDForThread(nextID);
 
-          // Run the executor over the current function.
-          auto &it = funcAMPairs[nextID];
-          if (failed(runFunctionPipeline(executor, it.first, it.second))) {
+          // Run the executor over the current operation.
+          auto &it = opAMPairs[nextID];
+          if (failed(runPipeline(pm, it.first, it.second))) {
             passFailed = true;
             break;
           }
@@ -195,90 +314,42 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
     signalPassFailure();
 }
 
-//===----------------------------------------------------------------------===//
-// Verifier Passes
-//===----------------------------------------------------------------------===//
-
-void FunctionVerifierPass::runOnFunction() {
-  if (failed(verify(getFunction())))
-    signalPassFailure();
-  markAllAnalysesPreserved();
-}
-
-void ModuleVerifierPass::runOnModule() {
-  if (failed(verify(getModule())))
-    signalPassFailure();
-  markAllAnalysesPreserved();
+/// Utility function to return the operation name that the given adaptor pass
+/// operates on. Return None if the given pass is not an adaptor pass.
+Optional<StringRef> mlir::detail::getAdaptorPassOpName(Pass *pass) {
+  if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+    return adaptor->getPassManager().getOpName().getStringRef();
+  if (auto *adaptor = dyn_cast<OpToOpPassAdaptorParallel>(pass))
+    return adaptor->getPassManager().getOpName().getStringRef();
+  return llvm::None;
 }
 
 //===----------------------------------------------------------------------===//
 // PassManager
 //===----------------------------------------------------------------------===//
 
-PassManager::PassManager(bool verifyPasses)
-    : mpe(new ModulePassExecutor()), verifyPasses(verifyPasses),
-      passTiming(false), disableThreads(false) {}
+PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
+    : opPassManager(OperationName(ModuleOp::getOperationName(), ctx),
+                    /*disableThreads=*/false, verifyPasses),
+      passTiming(false) {}
 
 PassManager::~PassManager() {}
 
 /// Run the passes within this manager on the provided module.
 LogicalResult PassManager::run(ModuleOp module) {
   ModuleAnalysisManager am(module, instrumentor.get());
-  return mpe->run(module, am);
+  return opPassManager.run(module, am);
 }
 
 /// Disable support for multi-threading within the pass manager.
 void PassManager::disableMultithreading(bool disable) {
-  disableThreads = disable;
+  opPassManager.getImpl().disableThreads = disable;
 }
 
 /// Add an opaque pass pointer to the current manager. This takes ownership
 /// over the provided pass pointer.
 void PassManager::addPass(std::unique_ptr<Pass> pass) {
-  if (isa<FunctionPassBase>(pass.get()))
-    return addPass(cast<FunctionPassBase>(std::move(pass)));
-  addPass(cast<ModulePassBase>(std::move(pass)));
-}
-
-/// Add a module pass to the current manager. This takes ownership over the
-/// provided pass pointer.
-void PassManager::addPass(std::unique_ptr<ModulePassBase> pass) {
-  nestedExecutorStack.clear();
-  mpe->addPass(std::move(pass));
-
-  // Add a verifier run if requested.
-  if (verifyPasses)
-    mpe->addPass(std::make_unique<ModuleVerifierPass>());
-}
-
-/// Add a function pass to the current manager. This takes ownership over the
-/// provided pass pointer. This will automatically create a function pass
-/// executor if necessary.
-void PassManager::addPass(std::unique_ptr<FunctionPassBase> pass) {
-  detail::FunctionPassExecutor *fpe;
-  if (nestedExecutorStack.empty()) {
-    /// Create an executor adaptor for this pass.
-    if (disableThreads || !llvm::llvm_is_multithreaded()) {
-      // If multi-threading is disabled, then create a synchronous adaptor.
-      auto adaptor = std::make_unique<ModuleToFunctionPassAdaptor>();
-      fpe = &adaptor->getFunctionExecutor();
-      addPass(std::unique_ptr<ModulePassBase>{adaptor.release()});
-    } else {
-      auto adaptor = std::make_unique<ModuleToFunctionPassAdaptorParallel>();
-      fpe = &adaptor->getFunctionExecutor();
-      addPass(std::unique_ptr<ModulePassBase>{adaptor.release()});
-    }
-
-    /// Add the executor to the stack.
-    nestedExecutorStack.push_back(fpe);
-  } else {
-    fpe = cast<detail::FunctionPassExecutor>(nestedExecutorStack.back());
-  }
-  fpe->addPass(std::move(pass));
-
-  // Add a verifier run if requested.
-  if (verifyPasses)
-    fpe->addPass(std::make_unique<FunctionVerifierPass>());
+  opPassManager.addPass(std::move(pass));
 }
 
 /// Add the provided instrumentation to the pass manager. This takes ownership
index 40e75f4..8188cc1 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class OpPassManager;
+
 namespace detail {
 
 //===----------------------------------------------------------------------===//
-// Verifier Passes
+// Verifier Pass
 //===----------------------------------------------------------------------===//
 
-/// Pass to verify a function and signal failure if necessary.
-class FunctionVerifierPass : public FunctionPass<FunctionVerifierPass> {
-  void runOnFunction() override;
-};
-
-/// Pass to verify a module and signal failure if necessary.
-class ModuleVerifierPass : public ModulePass<ModuleVerifierPass> {
-  void runOnModule() override;
+/// Pass to verify an operation and signal failure if necessary.
+class VerifierPass : public OperationPass<VerifierPass> {
+  void runOnOperation() override;
 };
 
 //===----------------------------------------------------------------------===//
-// PassExecutor
+// OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
 
-/// The abstract base pass executor class.
-class PassExecutor {
-public:
-  enum Kind { FunctionExecutor, ModuleExecutor };
-  explicit PassExecutor(Kind kind) : kind(kind) {}
-
-  /// Get the kind of this executor.
-  Kind getKind() const { return kind; }
-
-private:
-  /// The kind of executor this object is.
-  Kind kind;
-};
-
-/// A pass executor that contains a list of passes over a function.
-class FunctionPassExecutor : public PassExecutor {
-public:
-  FunctionPassExecutor() : PassExecutor(Kind::FunctionExecutor) {}
-  FunctionPassExecutor(FunctionPassExecutor &&) = default;
-  FunctionPassExecutor(const FunctionPassExecutor &rhs);
-
-  /// Run the executor on the given function.
-  LogicalResult run(FuncOp function, AnalysisManager am);
-
-  /// Add a pass to the current executor. This takes ownership over the provided
-  /// pass pointer.
-  void addPass(std::unique_ptr<Pass> pass) {
-    passes.push_back(std::move(pass));
-  }
-
-  /// Returns the number of passes held by this executor.
-  size_t size() const { return passes.size(); }
-
-  static bool classof(const PassExecutor *pe) {
-    return pe->getKind() == Kind::FunctionExecutor;
-  }
-
-private:
-  std::vector<std::unique_ptr<Pass>> passes;
-};
-
-/// A pass executor that contains a list of passes over a module unit.
-class ModulePassExecutor : public PassExecutor {
+/// An adaptor pass used to run operation passes over nested operations
+/// synchronously on a single thread.
+class OpToOpPassAdaptor : public OperationPass<OpToOpPassAdaptor> {
 public:
-  ModulePassExecutor() : PassExecutor(Kind::ModuleExecutor) {}
-  ModulePassExecutor(ModulePassExecutor &&) = default;
-
-  // Don't allow copying.
-  ModulePassExecutor(const ModulePassExecutor &) = delete;
-  ModulePassExecutor &operator=(const ModulePassExecutor &) = delete;
+  OpToOpPassAdaptor(std::unique_ptr<OpPassManager> mgr);
+  OpToOpPassAdaptor(const OpToOpPassAdaptor &rhs);
 
-  /// Run the executor on the given module.
-  LogicalResult run(ModuleOp module, AnalysisManager am);
+  /// Run the held pipeline over all operations.
+  void runOnOperation() override;
 
-  /// Add a pass to the current executor. This takes ownership over the provided
-  /// pass pointer.
-  void addPass(std::unique_ptr<ModulePassBase> pass) {
-    passes.push_back(std::move(pass));
-  }
-
-  static bool classof(const PassExecutor *pe) {
-    return pe->getKind() == Kind::ModuleExecutor;
-  }
+  /// Returns the nested pass manager for this adaptor.
+  OpPassManager &getPassManager() { return *mgr; }
 
 private:
-  /// Set of passes to run on the given module.
-  std::vector<std::unique_ptr<ModulePassBase>> passes;
+  std::unique_ptr<OpPassManager> mgr;
 };
 
-//===----------------------------------------------------------------------===//
-// ModuleToFunctionPassAdaptor
-//===----------------------------------------------------------------------===//
-
-/// An adaptor module pass used to run function passes over all of the
-/// non-external functions of a module synchronously on a single thread.
-class ModuleToFunctionPassAdaptor
-    : public ModulePass<ModuleToFunctionPassAdaptor> {
+/// An adaptor pass used to run operation passes over nested operations
+/// asynchronously across multiple threads.
+class OpToOpPassAdaptorParallel
+    : public OperationPass<OpToOpPassAdaptorParallel> {
 public:
-  /// Run the held function pipeline over all non-external functions within the
-  /// module.
-  void runOnModule() override;
+  OpToOpPassAdaptorParallel(std::unique_ptr<OpPassManager> mgr);
+  OpToOpPassAdaptorParallel(const OpToOpPassAdaptorParallel &rhs);
 
-  /// Returns the function pass executor for this adaptor.
-  FunctionPassExecutor &getFunctionExecutor() { return fpe; }
+  /// Run the held pipeline over all operations.
+  void runOnOperation() override;
 
-private:
-  FunctionPassExecutor fpe;
-};
-
-/// An adaptor module pass used to run function passes over all of the
-/// non-external functions of a module asynchronously across multiple threads.
-class ModuleToFunctionPassAdaptorParallel
-    : public ModulePass<ModuleToFunctionPassAdaptorParallel> {
-public:
-  /// Run the held function pipeline over all non-external functions within the
-  /// module.
-  void runOnModule() override;
-
-  /// Returns the function pass executor for this adaptor.
-  FunctionPassExecutor &getFunctionExecutor() { return fpe; }
+  /// Returns the nested pass manager for this adaptor.
+  OpPassManager &getPassManager() { return *mgr; }
 
 private:
-  // The main function pass executor for this adaptor.
-  FunctionPassExecutor fpe;
+  // The main pass executor for this adaptor.
+  std::unique_ptr<OpPassManager> mgr;
 
   // A set of executors, cloned from the main executor, that run asynchronously
   // on different threads.
-  std::vector<FunctionPassExecutor> asyncExecutors;
+  std::vector<OpPassManager> asyncExecutors;
 };
 
-/// Utility function to return if a pass refers to an
-/// ModuleToFunctionPassAdaptor instance.
-inline bool isModuleToFunctionAdaptorPass(Pass *pass) {
-  return isa<ModuleToFunctionPassAdaptorParallel>(pass) ||
-         isa<ModuleToFunctionPassAdaptor>(pass);
+/// Utility function to return if a pass refers to an OpToOpAdaptorPass
+/// instance.
+inline bool isOpToOpAdaptorPass(Pass *pass) {
+  return isa<OpToOpPassAdaptorParallel>(pass) || isa<OpToOpPassAdaptor>(pass);
 }
 
 /// Utility function to return if a pass refers to an adaptor pass. Adaptor
 /// passes are those that internally execute a pipeline, such as the
-/// ModuleToFunctionPassAdaptor.
-inline bool isAdaptorPass(Pass *pass) {
-  return isModuleToFunctionAdaptorPass(pass);
-}
+/// OpToOpPassAdaptor.
+inline bool isAdaptorPass(Pass *pass) { return isOpToOpAdaptorPass(pass); }
 
-/// Utility function to return if a pass refers to a verifier pass.
-inline bool isVerifierPass(Pass *pass) {
-  return isa<FunctionVerifierPass>(pass) || isa<ModuleVerifierPass>(pass);
-}
+/// Utility function to return the operation name that the given adaptor pass
+/// operates on. Return None if the given pass is not an adaptor pass.
+Optional<StringRef> getAdaptorPassOpName(Pass *pass);
 
 } // end namespace detail
 } // end namespace mlir
index 91b838c..b58bd7c 100644 (file)
@@ -221,9 +221,9 @@ struct PassTiming : public PassInstrumentation {
 
 /// Start a new timer for the given pass.
 void PassTiming::startPassTimer(Pass *pass) {
-  Timer *timer = getTimer(pass, [pass] {
-    if (isModuleToFunctionAdaptorPass(pass))
-      return StringRef("Function Pipeline");
+  Timer *timer = getTimer(pass, [pass]() -> std::string {
+    if (auto pipelineName = getAdaptorPassOpName(pass))
+      return ("'" + *pipelineName + "' Pipeline").str();
     return pass->getName();
   });
 
@@ -246,9 +246,9 @@ void PassTiming::runAfterPass(Pass *pass, Operation *) {
   assert(!activeTimers.empty() && "expected active timer");
   Timer *timer = activeTimers.pop_back_val();
 
-  // If this is an ModuleToFunctionPassAdaptorParallel, then we need to merge in
-  // the timing data for the other threads.
-  if (isa<ModuleToFunctionPassAdaptorParallel>(pass)) {
+  // If this is an OpToOpPassAdaptorParallel, then we need to merge in the
+  // timing data for the other threads.
+  if (isa<OpToOpPassAdaptorParallel>(pass)) {
     // The asychronous pipeline timers should exist as children of root timers
     // for other threads.
     for (auto &rootTimer : llvm::make_early_inc_range(rootTimers)) {
index 07d6841..549b1ad 100644 (file)
@@ -169,7 +169,7 @@ static void printMemRefArguments(ArrayRef<Type> argTypes,
 // - affine to standard lowering
 // - standard to llvm lowering
 static LogicalResult convertAffineStandardToLLVMIR(ModuleOp module) {
-  PassManager manager;
+  PassManager manager(module.getContext());
   manager.addPass(mlir::createCanonicalizerPass());
   manager.addPass(mlir::createCSEPass());
   manager.addPass(mlir::createLowerAffinePass());
index 80cba5a..0b234e6 100644 (file)
@@ -54,7 +54,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
     return failure();
 
   // Apply any pass manager command line options.
-  PassManager pm(verifyPasses);
+  PassManager pm(context, verifyPasses);
   applyPassManagerCLOptions(pm);
 
   // Run each of the passes that were selected.
index 15f6147..b7b4101 100644 (file)
@@ -7,50 +7,48 @@
 // LIST: Total Execution Time:
 // LIST: Name
 // LIST-DAG: Canonicalizer
-// LIST-DAG: FunctionVerifier
+// LIST-DAG: Verifier
 // LIST-DAG: CSE
-// LIST-DAG: ModuleVerifier
 // LIST-DAG: DominanceInfo
 // LIST: Total
 
 // PIPELINE: Pass execution timing report
 // PIPELINE: Total Execution Time:
 // PIPELINE: Name
-// PIPELINE-NEXT: Function Pipeline
+// PIPELINE-NEXT: 'func' Pipeline
 // PIPELINE-NEXT:   CSE
 // PIPELINE-NEXT:     (A) DominanceInfo
-// PIPELINE-NEXT:   FunctionVerifier
+// PIPELINE-NEXT:   Verifier
 // PIPELINE-NEXT:   Canonicalizer
-// PIPELINE-NEXT:   FunctionVerifier
+// PIPELINE-NEXT:   Verifier
 // PIPELINE-NEXT:   CSE
 // PIPELINE-NEXT:     (A) DominanceInfo
-// PIPELINE-NEXT:   FunctionVerifier
-// PIPELINE-NEXT: ModuleVerifier
+// PIPELINE-NEXT:   Verifier
+// PIPELINE-NEXT: Verifier
 // PIPELINE-NEXT: Total
 
 // MT_LIST: Pass execution timing report
 // MT_LIST: Total Execution Time:
 // MT_LIST: Name
 // MT_LIST-DAG: Canonicalizer
-// MT_LIST-DAG: FunctionVerifier
+// MT_LIST-DAG: Verifier
 // MT_LIST-DAG: CSE
-// MT_LIST-DAG: ModuleVerifier
 // MT_LIST-DAG: DominanceInfo
 // MT_LIST: Total
 
 // MT_PIPELINE: Pass execution timing report
 // MT_PIPELINE: Total Execution Time:
 // MT_PIPELINE: Name
-// MT_PIPELINE-NEXT: Function Pipeline
+// MT_PIPELINE-NEXT: 'func' Pipeline
 // MT_PIPELINE-NEXT:   CSE
 // MT_PIPELINE-NEXT:     (A) DominanceInfo
-// MT_PIPELINE-NEXT:   FunctionVerifier
+// MT_PIPELINE-NEXT:   Verifier
 // MT_PIPELINE-NEXT:   Canonicalizer
-// MT_PIPELINE-NEXT:   FunctionVerifier
+// MT_PIPELINE-NEXT:   Verifier
 // MT_PIPELINE-NEXT:   CSE
 // MT_PIPELINE-NEXT:     (A) DominanceInfo
-// MT_PIPELINE-NEXT:   FunctionVerifier
-// MT_PIPELINE-NEXT: ModuleVerifier
+// MT_PIPELINE-NEXT:   Verifier
+// MT_PIPELINE-NEXT: Verifier
 // MT_PIPELINE-NEXT: Total
 
 func @foo() {
index 797b7bb..df69407 100644 (file)
@@ -137,7 +137,7 @@ public:
 } // end anonymous namespace
 
 static LogicalResult runMLIRPasses(ModuleOp m) {
-  PassManager pm;
+  PassManager pm(m.getContext());
 
   pm.addPass(createGpuKernelOutliningPass());
   pm.addPass(static_cast<std::unique_ptr<ModulePassBase>>(