[mlir][Pass] Make PassManager default to op-agnostic
authorrkayaith <rkayaith@gmail.com>
Wed, 9 Nov 2022 04:23:28 +0000 (23:23 -0500)
committerRahul Kayaith <rkayaith@gmail.com>
Wed, 25 Jan 2023 20:38:19 +0000 (15:38 -0500)
Currently `PassManager` defaults to being anchored on `builtin.module`.
Switching the default makes `PassManager` consistent with
`OpPassManager` and avoids the implicit dependency on `builtin.module`.

Specifying the anchor op type isn't strictly necessary when using
explicit nesting (existing pipelines will continue to work), but I've
updated most call sites to specify the anchor since it allows for better
error-checking during pipeline construction.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D137731

17 files changed:
flang/lib/Frontend/FrontendActions.cpp
flang/tools/bbc/bbc.cpp
flang/tools/tco/tco.cpp
mlir/docs/PassManagement.md
mlir/docs/Tutorials/Toy/Ch-3.md
mlir/examples/toy/Ch3/toyc.cpp
mlir/examples/toy/Ch4/toyc.cpp
mlir/examples/toy/Ch5/toyc.cpp
mlir/examples/toy/Ch6/toyc.cpp
mlir/examples/toy/Ch7/toyc.cpp
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/Pass/PassManager.h
mlir/lib/Pass/Pass.cpp
mlir/lib/Rewrite/FrozenRewritePatternSet.cpp
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
mlir/unittests/ExecutionEngine/Invoke.cpp
mlir/unittests/Pass/PassManagerTest.cpp

index 64641b7..927591c 100644 (file)
@@ -184,7 +184,8 @@ bool CodeGenAction::beginSourceFileAction() {
   lb.lower(parseTree, ci.getInvocation().getSemanticsContext());
 
   // run the default passes.
-  mlir::PassManager pm(mlirCtx.get(), mlir::OpPassManager::Nesting::Implicit);
+  mlir::PassManager pm((*mlirModule)->getName(),
+                       mlir::OpPassManager::Nesting::Implicit);
   pm.enableVerifier(/*verifyPasses=*/true);
   pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
 
@@ -535,7 +536,8 @@ void CodeGenAction::generateLLVMIR() {
   fir::support::registerLLVMTranslation(*mlirCtx);
 
   // Set-up the MLIR pass manager
-  mlir::PassManager pm(mlirCtx.get(), mlir::OpPassManager::Nesting::Implicit);
+  mlir::PassManager pm((*mlirModule)->getName(),
+                       mlir::OpPassManager::Nesting::Implicit);
 
   pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
   pm.enableVerifier(/*verifyPasses=*/true);
index 796e7fa..40cea36 100644 (file)
@@ -249,7 +249,8 @@ static mlir::LogicalResult convertFortranSourceToMLIR(
            << outputName;
 
   // Otherwise run the default passes.
-  mlir::PassManager pm(&ctx, mlir::OpPassManager::Nesting::Implicit);
+  mlir::PassManager pm(mlirModule->getName(),
+                       mlir::OpPassManager::Nesting::Implicit);
   pm.enableVerifier(/*verifyPasses=*/true);
   mlir::applyPassManagerCLOptions(pm);
   if (passPipeline.hasAnyOccurrences()) {
index a1b60aa..c68e545 100644 (file)
@@ -103,7 +103,8 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
   fir::KindMapping kindMap{&context};
   fir::setTargetTriple(*owningRef, targetTriple);
   fir::setKindMapping(*owningRef, kindMap);
-  mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit);
+  mlir::PassManager pm((*owningRef)->getName(),
+                       mlir::OpPassManager::Nesting::Implicit);
   pm.enableVerifier(/*verifyPasses=*/true);
   mlir::applyPassManagerCLOptions(pm);
   if (emitFir) {
index 498f887..661592b 100644 (file)
@@ -399,11 +399,8 @@ Below is an example of constructing a pipeline that operates on the above
 structure:
 
 ```c++
-// Create a top-level `PassManager` class. If an operation type is not
-// explicitly specific, the default is the builtin `module` operation.
-PassManager pm(ctx);
-// Note: We could also create the above `PassManager` this way.
-PassManager pm(ctx, /*operationName=*/"builtin.module");
+// Create a top-level `PassManager` class.
+auto pm = PassManager::on<ModuleOp>(ctx);
 
 // Add a pass on the top-level module operation.
 pm.addPass(std::make_unique<MyModulePass>());
index 8203082..811a1d2 100644 (file)
@@ -124,7 +124,7 @@ pipeline. In MLIR, the optimizations are run through a `PassManager` in a
 similar way to LLVM:
 
 ```c++
-  mlir::PassManager pm(module.getContext());
+  mlir::PassManager pm(module->getName());
   pm.addNestedPass<mlir::toy::FuncOp>(mlir::createCanonicalizerPass());
 ```
 
index ce33e50..ef362e4 100644 (file)
@@ -113,7 +113,7 @@ int dumpMLIR() {
     return error;
 
   if (enableOpt) {
-    mlir::PassManager pm(&context);
+    mlir::PassManager pm(module.get()->getName());
     // Apply any generic pass manager command line options and run the pipeline.
     applyPassManagerCLOptions(pm);
 
index cdd1577..bf8e694 100644 (file)
@@ -114,7 +114,7 @@ int dumpMLIR() {
     return error;
 
   if (enableOpt) {
-    mlir::PassManager pm(&context);
+    mlir::PassManager pm(module.get()->getName());
     // Apply any generic pass manager command line options and run the pipeline.
     applyPassManagerCLOptions(pm);
 
index 6331318..5a23c49 100644 (file)
@@ -117,7 +117,7 @@ int dumpMLIR() {
   if (int error = loadMLIR(sourceMgr, context, module))
     return error;
 
-  mlir::PassManager pm(&context);
+  mlir::PassManager pm(module.get()->getName());
   // Apply any generic pass manager command line options and run the pipeline.
   applyPassManagerCLOptions(pm);
 
index 32261ec..3983efd 100644 (file)
@@ -132,7 +132,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
   if (int error = loadMLIR(context, module))
     return error;
 
-  mlir::PassManager pm(&context);
+  mlir::PassManager pm(module.get()->getName());
   // Apply any generic pass manager command line options and run the pipeline.
   applyPassManagerCLOptions(pm);
 
index 2b8dc76..ddf6e46 100644 (file)
@@ -132,7 +132,7 @@ int loadAndProcessMLIR(mlir::MLIRContext &context,
   if (int error = loadMLIR(context, module))
     return error;
 
-  mlir::PassManager pm(&context);
+  mlir::PassManager pm(module.get()->getName());
   // Apply any generic pass manager command line options and run the pipeline.
   applyPassManagerCLOptions(pm);
 
index 5d526f9..99e63f2 100644 (file)
@@ -322,6 +322,9 @@ public:
   /// Return the operation name with dialect name stripped, if it has one.
   StringRef stripDialect() const { return getStringRef().split('.').second; }
 
+  /// Return the context this operation is associated with.
+  MLIRContext *getContext() { return getIdentifier().getContext(); }
+
   /// Return the name of this operation. This always succeeds.
   StringRef getStringRef() const { return getIdentifier(); }
 
index a2b1d97..71982c3 100644 (file)
@@ -213,14 +213,20 @@ public:
   /// Create a new pass manager under the given context with a specific nesting
   /// style. The created pass manager can schedule operations that match
   /// `operationName`.
-  /// FIXME: We should make the specification of `builtin.module` explicit here,
-  /// so that we can have top-level op-agnostic pass managers.
-  PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit,
-              StringRef operationName = "builtin.module");
-  PassManager(MLIRContext *ctx, StringRef operationName)
-      : PassManager(ctx, Nesting::Explicit, operationName) {}
+  PassManager(MLIRContext *ctx,
+              StringRef operationName = PassManager::getAnyOpAnchorName(),
+              Nesting nesting = Nesting::Explicit);
+  PassManager(OperationName operationName, Nesting nesting = Nesting::Explicit);
   ~PassManager();
 
+  /// Create a new pass manager under the given context with a specific nesting
+  /// style. The created pass manager can schedule operations that match
+  /// `OperationTy`.
+  template <typename OperationTy>
+  static PassManager on(MLIRContext *ctx, Nesting nesting = Nesting::Explicit) {
+    return PassManager(ctx, OperationTy::getOperationName(), nesting);
+  }
+
   /// Run the passes within this manager on the provided operation. The
   /// specified operation must have the same name as the one provided the pass
   /// manager on construction.
@@ -438,7 +444,8 @@ private:
   std::unique_ptr<detail::PassCrashReproducerGenerator> crashReproGenerator;
 
   /// A hash key used to detect when reinitialization is necessary.
-  llvm::hash_code initializationKey;
+  llvm::hash_code initializationKey =
+      DenseMapInfo<llvm::hash_code>::getTombstoneKey();
 
   /// Flag that specifies if pass timing is enabled.
   bool passTiming : 1;
index 8b5bd38..194ddac 100644 (file)
@@ -769,11 +769,15 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
 // PassManager
 //===----------------------------------------------------------------------===//
 
-PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
-                         StringRef operationName)
-    : OpPassManager(OperationName(operationName, ctx), nesting), context(ctx),
-      initializationKey(DenseMapInfo<llvm::hash_code>::getTombstoneKey()),
-      passTiming(false), verifyPasses(true) {}
+PassManager::PassManager(MLIRContext *ctx, StringRef operationName,
+                         Nesting nesting)
+    : OpPassManager(operationName, nesting), context(ctx), passTiming(false),
+      verifyPasses(true) {}
+
+PassManager::PassManager(OperationName operationName, Nesting nesting)
+    : OpPassManager(operationName, nesting),
+      context(operationName.getContext()), passTiming(false),
+      verifyPasses(true) {}
 
 PassManager::~PassManager() = default;
 
index 0fa1668..43840d1 100644 (file)
@@ -34,7 +34,7 @@ convertPDLToPDLInterp(ModuleOp pdlModule,
   pdlModule.getBody()->walk(simplifyFn);
 
   /// Lower the PDL pattern module to the interpreter dialect.
-  PassManager pdlPipeline(pdlModule.getContext());
+  PassManager pdlPipeline(pdlModule->getName());
 #ifdef NDEBUG
   // We don't want to incur the hit of running the verifier when in release
   // mode.
index 62c84e2..94d9a24 100644 (file)
@@ -79,8 +79,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
   parserTiming.stop();
 
   // Prepare the pass manager, applying command-line and reproducer options.
-  PassManager pm(context, OpPassManager::Nesting::Implicit,
-                 op.get()->getName().getStringRef());
+  PassManager pm(op.get()->getName(), PassManager::Nesting::Implicit);
   pm.enableVerifier(verifyPasses);
   applyPassManagerCLOptions(pm);
   pm.enableTiming(timing);
index 57a260b..d0e2374 100644 (file)
@@ -52,7 +52,7 @@ static struct LLVMInitializer {
 /// Simple conversion pipeline for the purpose of testing sources written in
 /// dialects lowering to LLVM Dialect.
 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
-  PassManager pm(module.getContext());
+  PassManager pm(module->getName());
   pm.addPass(mlir::createMemRefToLLVMConversionPass());
   pm.addNestedPass<func::FuncOp>(mlir::createArithToLLVMConversionPass());
   pm.addPass(mlir::createConvertFuncToLLVMPass());
index 9ed49c8..24e8702 100644 (file)
@@ -68,7 +68,7 @@ TEST(PassManagerTest, OpSpecificAnalysis) {
   }
 
   // Instantiate and run our pass.
-  PassManager pm(&context);
+  auto pm = PassManager::on<ModuleOp>(&context);
   pm.addNestedPass<func::FuncOp>(std::make_unique<AnnotateFunctionPass>());
   LogicalResult result = pm.run(module.get());
   EXPECT_TRUE(succeeded(result));
@@ -123,7 +123,7 @@ TEST(PassManagerTest, InvalidPass) {
   });
 
   // Instantiate and run our pass.
-  PassManager pm(&context);
+  auto pm = PassManager::on<ModuleOp>(&context);
   pm.nest("invalid_op").addPass(std::make_unique<InvalidPass>());
   LogicalResult result = pm.run(module.get());
   EXPECT_TRUE(failed(result));
@@ -138,7 +138,10 @@ TEST(PassManagerTest, InvalidPass) {
   EXPECT_TRUE(succeeded(result));
 
   // Check that adding the pass at the top-level triggers a fatal error.
-  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "");
+  ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()),
+               "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "
+               "PassManager intended to run on 'builtin.module', did you "
+               "intend to nest?");
 }
 
 } // namespace