TypeID getTypeID() const { return passID; }
/// Returns the pass info for the specified pass class or null if unknown.
- static const PassInfo *lookupPassInfo(TypeID passID);
- template <typename PassT> static const PassInfo *lookupPassInfo() {
- return lookupPassInfo(TypeID::get<PassT>());
- }
+ static const PassInfo *lookupPassInfo(StringRef passArg);
- /// Returns the pass info for this pass.
- const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); }
+ /// Returns the pass info for this pass, or null if unknown.
+ const PassInfo *lookupPassInfo() const {
+ return lookupPassInfo(getArgument());
+ }
/// Returns the derived pass name.
virtual StringRef getName() const = 0;
/// Returns the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
- virtual StringRef getArgument() const {
- if (const PassInfo *passInfo = lookupPassInfo())
- return passInfo->getPassArgument();
- return "";
- }
+ virtual StringRef getArgument() const { return ""; }
/// Returns the name of the operation that this pass operates on, or None if
/// this is a generic OperationPass.
using namespace detail;
/// Static mapping of all of the registered passes.
-static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;
+static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
+
+/// A mapping of the above pass registry entries to the corresponding TypeID
+/// of the pass that they generate.
+static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
/// Static mapping of all of the registered pass pipelines.
static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
// PassInfo
//===----------------------------------------------------------------------===//
-PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
+PassInfo::PassInfo(StringRef arg, StringRef description,
const PassAllocatorFunction &allocator)
: PassRegistryEntry(
arg, description, buildDefaultRegistryFn(allocator),
void mlir::registerPass(StringRef arg, StringRef description,
const PassAllocatorFunction &function) {
- // TODO: We should use the 'arg' as the lookup key instead of the pass id.
- TypeID passID = function()->getTypeID();
- PassInfo passInfo(arg, description, passID, function);
- passRegistry->try_emplace(passID, passInfo);
+ PassInfo passInfo(arg, description, function);
+ passRegistry->try_emplace(arg, passInfo);
+
+ // Verify that the registered pass has the same ID as any registered to this
+ // arg before it.
+ TypeID entryTypeID = function()->getTypeID();
+ auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
+ if (it->second != entryTypeID) {
+ llvm_unreachable("pass allocator creates a different pass than previously "
+ "registered");
+ }
}
-/// Returns the pass info for the specified pass class or null if unknown.
-const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
- auto it = passRegistry->find(passID);
- if (it == passRegistry->end())
- return nullptr;
- return &it->getSecond();
+/// Returns the pass info for the specified pass argument or null if unknown.
+const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
+ auto it = passRegistry->find(passArg);
+ return it == passRegistry->end() ? nullptr : &it->second;
}
//===----------------------------------------------------------------------===//
}
// If not, then this must be a specific pass name.
- for (auto &passIt : *passRegistry) {
- if (passIt.second.getPassArgument() == element.name) {
- element.registryEntry = &passIt.second;
- return success();
- }
- }
+ if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
+ return success();
// Emit an error for the unknown pass.
auto *rawLoc = element.name.data();
struct TestModulePass
: public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
void runOnOperation() final {}
+ StringRef getArgument() const final { return "test-module-pass"; }
};
struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
void runOnFunction() final {}
+ StringRef getArgument() const final { return "test-function-pass"; }
};
class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
public:
}
void runOnFunction() final {}
+ StringRef getArgument() const final { return "test-options-pass"; }
ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::desc("Example list option")};
class TestCrashRecoveryPass
: public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
void runOnOperation() final { abort(); }
+ StringRef getArgument() const final { return "test-pass-crash"; }
};
/// A test pass that always fails to enable testing the failure recovery