From 0f304ef0170231b860a249f34e07f50686392253 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 18 Jan 2022 16:16:54 -0800 Subject: [PATCH] [mlir] Add asserts when changing various MLIRContext configurations This helps to prevent tsan failures when users inadvertantly mutate the context in a non-safe way. Differential Revision: https://reviews.llvm.org/D112021 --- mlir/include/mlir/IR/DialectRegistry.h | 4 ++++ mlir/lib/IR/Dialect.cpp | 9 +++++++++ mlir/lib/IR/MLIRContext.cpp | 18 ++++++++++++++++++ mlir/lib/Reducer/OptReductionPass.cpp | 10 ++++++++-- .../lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp | 11 +++++------ 5 files changed, 44 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h index fbc81b0..5e55fab 100644 --- a/mlir/include/mlir/IR/DialectRegistry.h +++ b/mlir/include/mlir/IR/DialectRegistry.h @@ -212,6 +212,10 @@ public: addExtension(std::make_unique(std::move(extensionFn))); } + /// Returns true if the current registry is a subset of 'rhs', i.e. if 'rhs' + /// contains all of the components of this registry. + bool isSubsetOf(const DialectRegistry &rhs) const; + private: MapTy registry; std::vector> extensions; diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index 2e983d6..b8f5aa2 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -228,3 +228,12 @@ void DialectRegistry::applyExtensions(MLIRContext *ctx) const { for (const auto &extension : extensions) applyExtension(*extension); } + +bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const { + // Treat any extensions conservatively. + if (!extensions.empty()) + return false; + // Check that the current dialects fully overlap with the dialects in 'rhs'. + return llvm::all_of( + registry, [&](const auto &it) { return rhs.registry.count(it.first); }); +} diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 2c0b3ba..eae6363 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -355,6 +355,12 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } //===----------------------------------------------------------------------===// void MLIRContext::appendDialectRegistry(const DialectRegistry ®istry) { + if (registry.isSubsetOf(impl->dialectsRegistry)) + return; + + assert(impl->multiThreadedExecutionContext == 0 && + "appending to the MLIRContext dialect registry while in a " + "multi-threaded execution context"); registry.appendTo(impl->dialectsRegistry); // For the already loaded dialects, apply any possible extensions immediately. @@ -470,6 +476,9 @@ bool MLIRContext::allowsUnregisteredDialects() { } void MLIRContext::allowUnregisteredDialects(bool allowing) { + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `allow-unregistered-dialects` configuration " + "while in a multi-threaded execution context"); impl->allowUnregisteredDialects = allowing; } @@ -484,6 +493,9 @@ void MLIRContext::disableMultithreading(bool disable) { // --mlir-disable-threading if (isThreadingGloballyDisabled()) return; + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `disable-threading` configuration while " + "in a multi-threaded execution context"); impl->threadingIsEnabled = !disable; @@ -557,6 +569,9 @@ bool MLIRContext::shouldPrintOpOnDiagnostic() { /// Set the flag specifying if we should attach the operation to diagnostics /// emitted via Operation::emit. void MLIRContext::printOpOnDiagnostic(bool enable) { + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `print-op-on-diagnostic` configuration while in " + "a multi-threaded execution context"); impl->printOpOnDiagnostic = enable; } @@ -569,6 +584,9 @@ bool MLIRContext::shouldPrintStackTraceOnDiagnostic() { /// Set the flag specifying if we should attach the current stacktrace when /// emitting diagnostics. void MLIRContext::printStackTraceOnDiagnostic(bool enable) { + assert(impl->multiThreadedExecutionContext == 0 && + "changing MLIRContext `print-stacktrace-on-diagnostic` configuration " + "while in a multi-threaded execution context"); impl->printStackTraceOnDiagnostic = enable; } diff --git a/mlir/lib/Reducer/OptReductionPass.cpp b/mlir/lib/Reducer/OptReductionPass.cpp index 806ce67..a7f09b4 100644 --- a/mlir/lib/Reducer/OptReductionPass.cpp +++ b/mlir/lib/Reducer/OptReductionPass.cpp @@ -42,7 +42,7 @@ void OptReductionPass::runOnOperation() { ModuleOp module = this->getOperation(); ModuleOp moduleVariant = module.clone(); - PassManager passManager(module.getContext()); + OpPassManager passManager("builtin.module"); if (failed(parsePassPipeline(optPass, passManager))) { module.emitError() << "\nfailed to parse pass pipeline"; return signalPassFailure(); @@ -54,7 +54,13 @@ void OptReductionPass::runOnOperation() { return signalPassFailure(); } - if (failed(passManager.run(moduleVariant))) { + // Temporarily push the variant under the main module and execute the pipeline + // on it. + module.getBody()->push_back(moduleVariant); + LogicalResult pipelineResult = runPipeline(passManager, moduleVariant); + moduleVariant->remove(); + + if (failed(pipelineResult)) { module.emitError() << "\nfailed to run pass pipeline"; return signalPassFailure(); } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp index 88ccc77..573197a 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -255,14 +255,13 @@ struct TestLinalgGreedyFusion patterns.add(context); scf::populateSCFForLoopCanonicalizationPatterns(patterns); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + OpPassManager pm(FuncOp::getOperationName()); + pm.addPass(createLoopInvariantCodeMotionPass()); + pm.addPass(createCanonicalizerPass()); + pm.addPass(createCSEPass()); do { (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); - PassManager pm(context); - pm.addPass(createLoopInvariantCodeMotionPass()); - pm.addPass(createCanonicalizerPass()); - pm.addPass(createCSEPass()); - LogicalResult res = pm.run(getOperation()->getParentOfType()); - if (failed(res)) + if (failed(runPipeline(pm, getOperation()))) this->signalPassFailure(); } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); } -- 2.7.4