[mlir] Resolve TODO and use the pass argument instead of the TypeID for registration
authorRiver Riddle <riddleriver@gmail.com>
Wed, 2 Jun 2021 19:06:32 +0000 (12:06 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 2 Jun 2021 19:17:36 +0000 (12:17 -0700)
This simplifies various pieces of code that interact with the pass registry, e.g. this removes the need to register passes to get accurate pass pipelines descriptions when generating crash reproducers.

Differential Revision: https://reviews.llvm.org/D101880

mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassRegistry.h
mlir/lib/Pass/PassRegistry.cpp
mlir/test/lib/Pass/TestPassManager.cpp

index 42df2dc..67c6954 100644 (file)
@@ -56,13 +56,12 @@ public:
   TypeID getTypeID() const { return passID; }
 
   /// Returns the pass info for the specified pass class or null if unknown.
-  static const PassInfo *lookupPassInfo(TypeID passID);
-  template <typename PassT> static const PassInfo *lookupPassInfo() {
-    return lookupPassInfo(TypeID::get<PassT>());
-  }
+  static const PassInfo *lookupPassInfo(StringRef passArg);
 
-  /// Returns the pass info for this pass.
-  const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); }
+  /// Returns the pass info for this pass, or null if unknown.
+  const PassInfo *lookupPassInfo() const {
+    return lookupPassInfo(getArgument());
+  }
 
   /// Returns the derived pass name.
   virtual StringRef getName() const = 0;
@@ -76,11 +75,7 @@ public:
 
   /// Returns the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
-  virtual StringRef getArgument() const {
-    if (const PassInfo *passInfo = lookupPassInfo())
-      return passInfo->getPassArgument();
-    return "";
-  }
+  virtual StringRef getArgument() const { return ""; }
 
   /// Returns the name of the operation that this pass operates on, or None if
   /// this is a generic OperationPass.
index 8def0f3..d03aaf8 100644 (file)
@@ -108,7 +108,7 @@ class PassInfo : public PassRegistryEntry {
 public:
   /// PassInfo constructor should not be invoked directly, instead use
   /// PassRegistration or registerPass.
-  PassInfo(StringRef arg, StringRef description, TypeID passID,
+  PassInfo(StringRef arg, StringRef description,
            const PassAllocatorFunction &allocator);
 };
 
index e53113e..2c690a2 100644 (file)
@@ -19,7 +19,11 @@ using namespace mlir;
 using namespace detail;
 
 /// Static mapping of all of the registered passes.
-static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;
+static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
+
+/// A mapping of the above pass registry entries to the corresponding TypeID
+/// of the pass that they generate.
+static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
 
 /// Static mapping of all of the registered pass pipelines.
 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
@@ -94,7 +98,7 @@ void mlir::registerPassPipeline(
 // PassInfo
 //===----------------------------------------------------------------------===//
 
-PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
+PassInfo::PassInfo(StringRef arg, StringRef description,
                    const PassAllocatorFunction &allocator)
     : PassRegistryEntry(
           arg, description, buildDefaultRegistryFn(allocator),
@@ -105,18 +109,23 @@ PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
 
 void mlir::registerPass(StringRef arg, StringRef description,
                         const PassAllocatorFunction &function) {
-  // TODO: We should use the 'arg' as the lookup key instead of the pass id.
-  TypeID passID = function()->getTypeID();
-  PassInfo passInfo(arg, description, passID, function);
-  passRegistry->try_emplace(passID, passInfo);
+  PassInfo passInfo(arg, description, function);
+  passRegistry->try_emplace(arg, passInfo);
+
+  // Verify that the registered pass has the same ID as any registered to this
+  // arg before it.
+  TypeID entryTypeID = function()->getTypeID();
+  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
+  if (it->second != entryTypeID) {
+    llvm_unreachable("pass allocator creates a different pass than previously "
+                     "registered");
+  }
 }
 
-/// Returns the pass info for the specified pass class or null if unknown.
-const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
-  auto it = passRegistry->find(passID);
-  if (it == passRegistry->end())
-    return nullptr;
-  return &it->getSecond();
+/// Returns the pass info for the specified pass argument or null if unknown.
+const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
+  auto it = passRegistry->find(passArg);
+  return it == passRegistry->end() ? nullptr : &it->second;
 }
 
 //===----------------------------------------------------------------------===//
@@ -433,12 +442,8 @@ TextualPipeline::resolvePipelineElement(PipelineElement &element,
   }
 
   // If not, then this must be a specific pass name.
-  for (auto &passIt : *passRegistry) {
-    if (passIt.second.getPassArgument() == element.name) {
-      element.registryEntry = &passIt.second;
-      return success();
-    }
-  }
+  if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
+    return success();
 
   // Emit an error for the unknown pass.
   auto *rawLoc = element.name.data();
index 937a5c2..6e5a5b9 100644 (file)
@@ -16,9 +16,11 @@ namespace {
 struct TestModulePass
     : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
   void runOnOperation() final {}
+  StringRef getArgument() const final { return "test-module-pass"; }
 };
 struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
   void runOnFunction() final {}
+  StringRef getArgument() const final { return "test-function-pass"; }
 };
 class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
 public:
@@ -41,6 +43,7 @@ public:
   }
 
   void runOnFunction() final {}
+  StringRef getArgument() const final { return "test-options-pass"; }
 
   ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
                              llvm::cl::desc("Example list option")};
@@ -56,6 +59,7 @@ public:
 class TestCrashRecoveryPass
     : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
   void runOnOperation() final { abort(); }
+  StringRef getArgument() const final { return "test-pass-crash"; }
 };
 
 /// A test pass that always fails to enable testing the failure recovery