From 02bc4c95f0729cc819776f73ec94a25405579183 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 26 Jan 2021 16:54:25 -0800 Subject: [PATCH] [mlir][PassManager] Only reinitialize the pass manager if the context registry changes This prevents needless reinitialization for clients that want to reuse a pass manager multiple times. A new `getRegisryHash` function is exposed by the context to give a rough indicator of when the context registry has changed. Differential Revision: https://reviews.llvm.org/D95493 --- mlir/include/mlir/IR/MLIRContext.h | 6 ++++++ mlir/include/mlir/Pass/PassManager.h | 3 +++ mlir/lib/IR/MLIRContext.cpp | 10 ++++++++++ mlir/lib/Pass/Pass.cpp | 7 ++++++- 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index 4751f00..eace86f 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -166,6 +166,12 @@ public: Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, function_ref()> ctor); + /// Returns a hash of the registry of the context that may be used to give + /// a rough indicator of if the state of the context registry has changed. The + /// context registry correlates to loaded dialects and their entities + /// (attributes, operations, types, etc.). + llvm::hash_code getRegistryHash(); + private: const std::unique_ptr impl; diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index beb6bc9..e73459f 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -375,6 +375,9 @@ private: /// An optional factory to use when generating a crash reproducer if valid. ReproducerStreamFactory crashReproducerStreamFactory; + /// A hash key used to detect when reinitialization is necessary. + llvm::hash_code initializationKey; + /// Flag that specifies if pass timing is enabled. bool passTiming : 1; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 9307d9c..d782a85 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -492,6 +492,16 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, return dialect.get(); } +llvm::hash_code MLIRContext::getRegistryHash() { + llvm::hash_code hash(0); + // Factor in number of loaded dialects, attributes, operations, types. + hash = llvm::hash_combine(hash, impl->loadedDialects.size()); + hash = llvm::hash_combine(hash, impl->registeredAttributes.size()); + hash = llvm::hash_combine(hash, impl->registeredOperations.size()); + hash = llvm::hash_combine(hash, impl->registeredTypes.size()); + return hash; +} + bool MLIRContext::allowsUnregisteredDialects() { return impl->allowUnregisteredDialects; } diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 0828941..66c8f66 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -846,6 +846,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef> passes, PassManager::PassManager(MLIRContext *ctx, Nesting nesting, StringRef operationName) : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx), + initializationKey(DenseMapInfo::getTombstoneKey()), passTiming(false), localReproducer(false), verifyPasses(true) {} PassManager::~PassManager() {} @@ -868,7 +869,11 @@ LogicalResult PassManager::run(Operation *op) { dependentDialects.loadAll(context); // Initialize all of the passes within the pass manager with a new generation. - initialize(context, impl->initializationGeneration + 1); + llvm::hash_code newInitKey = context->getRegistryHash(); + if (newInitKey != initializationKey) { + initialize(context, impl->initializationGeneration + 1); + initializationKey = newInitKey; + } // Construct a top level analysis manager for the pipeline. ModuleAnalysisManager am(op, instrumentor.get()); -- 2.7.4