From cb9ae0025c4ed966a3a9b5539a9ff6b6e865516f Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sat, 2 May 2020 12:28:57 -0700 Subject: [PATCH] [mlir] Add a new context flag for disabling/enabling multi-threading This is useful for several reasons: * In some situations the user can guarantee that thread-safety isn't necessary and doesn't want to pay the cost of synchronization, e.g., when parsing a very large module. * For things like logging, threading is not desirable, as the output is not guaranteed to be in a stable order. This flag also subsumes the pass manager flag for multi-threading. Differential Revision: https://reviews.llvm.org/D79266 --- mlir/docs/PassManagement.md | 8 +- mlir/include/mlir/IR/MLIRContext.h | 6 ++ mlir/include/mlir/Pass/PassManager.h | 9 +- mlir/include/mlir/Support/StorageUniquer.h | 3 + mlir/lib/IR/MLIRContext.cpp | 159 ++++++++++++++++++++--------- mlir/lib/Pass/IRPrinting.cpp | 3 +- mlir/lib/Pass/Pass.cpp | 42 +++----- mlir/lib/Pass/PassManagerOptions.cpp | 12 --- mlir/lib/Support/StorageUniquer.cpp | 39 +++++-- mlir/lib/Transforms/Inliner.cpp | 38 ++++--- mlir/test/Dialect/SPIRV/availability.mlir | 2 +- mlir/test/Dialect/SPIRV/target-env.mlir | 2 +- mlir/test/IR/test-matchers.mlir | 2 +- mlir/test/Pass/ir-printing.mlir | 12 +-- mlir/test/Pass/pass-timing.mlir | 10 +- 15 files changed, 206 insertions(+), 141 deletions(-) diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md index 90d30de..04a4ca0 100644 --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -801,7 +801,7 @@ pipeline. This display mode is available in mlir-opt via `-pass-timing-display=list`. 
```shell -$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list +$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing -pass-timing-display=list ===-------------------------------------------------------------------------=== ... Pass execution timing report ... @@ -826,7 +826,7 @@ the most time, and can also be used to identify when analyses are being invalidated and recomputed. This is the default display mode. ```shell -$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing +$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse,canonicalize)' -convert-std-to-llvm -pass-timing ===-------------------------------------------------------------------------=== ... Pass execution timing report ... @@ -943,10 +943,10 @@ func @simple_constant() -> (i32, i32) { * Always print the top-level module operation, regardless of pass type or operation nesting level. * Note: Printing at module scope should only be used when multi-threading - is disabled(`-disable-pass-threading`) + is disabled(`-mlir-disable-threading`) ```shell -$ mlir-opt foo.mlir -disable-pass-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope +$ mlir-opt foo.mlir -mlir-disable-threading -pass-pipeline='func(cse)' -print-ir-after=cse -print-ir-module-scope *** IR Dump After CSE *** ('func' operation: @bar) func @bar(%arg0: f32, %arg1: f32) -> f32 { diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h index e121388..40b3326 100644 --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -55,6 +55,12 @@ public: /// Enables creating operations in unregistered dialects. void allowUnregisteredDialects(bool allow = true); + /// Return true if multi-threading is enabled by the context. 
+ bool isMultithreadingEnabled(); + + /// Set the flag specifying if multi-threading is disabled by the context. + void disableMultithreading(bool disable = true); + /// Return true if we should attach the operation to diagnostics emitted via /// Operation::emit. bool shouldPrintOpOnDiagnostic(); diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index 15c8812..be5c5d5 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -99,7 +99,7 @@ public: void mergeStatisticsInto(OpPassManager &other); private: - OpPassManager(OperationName name, bool disableThreads, bool verifyPasses); + OpPassManager(OperationName name, bool verifyPasses); /// A pointer to an internal implementation instance. std::unique_ptr impl; @@ -139,13 +139,6 @@ public: LLVM_NODISCARD LogicalResult run(ModuleOp module); - /// Disable support for multi-threading within the pass manager. - void disableMultithreading(bool disable = true); - - /// Return true if the pass manager is configured with multi-threading - /// enabled. - bool isMultithreadingEnabled(); - /// Enable support for the pass manager to generate a reproducer on the event /// of a crash or a pass failure. `outputFile` is a .mlir filename used to /// write the generated reproducer. If `genLocalReproducer` is true, the pass diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h index 62a43ff..f13a2fe 100644 --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -65,6 +65,9 @@ public: StorageUniquer(); ~StorageUniquer(); + /// Set the flag specifying if multi-threading is disabled within the uniquer. + void disableMultithreading(bool disable = true); + /// This class acts as the base storage that all storage classes must derived /// from. 
class BaseStorage { diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index b25c511..c59c535 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -50,6 +50,10 @@ namespace { /// various bits of an MLIRContext. This uses a struct wrapper to avoid the need /// for global command line options. struct MLIRContextOptions { + llvm::cl::opt disableThreading{ + "mlir-disable-threading", + llvm::cl::desc("Disabling multi-threading within MLIR")}; + llvm::cl::opt printOpOnDiagnostic{ "mlir-print-op-on-diagnostic", llvm::cl::desc("When a diagnostic is emitted on an operation, also print " @@ -102,6 +106,41 @@ struct BuiltinDialect : public Dialect { } // end anonymous namespace. //===----------------------------------------------------------------------===// +// Locking Utilities +//===----------------------------------------------------------------------===// + +namespace { +/// Utility reader lock that takes a runtime flag that specifies if we really +/// need to lock. +struct ScopedReaderLock { + ScopedReaderLock(llvm::sys::SmartRWMutex &mutexParam, bool shouldLock) + : mutex(shouldLock ? &mutexParam : nullptr) { + if (mutex) + mutex->lock_shared(); + } + ~ScopedReaderLock() { + if (mutex) + mutex->unlock_shared(); + } + llvm::sys::SmartRWMutex *mutex; +}; +/// Utility writer lock that takes a runtime flag that specifies if we really +/// need to lock. +struct ScopedWriterLock { + ScopedWriterLock(llvm::sys::SmartRWMutex &mutexParam, bool shouldLock) + : mutex(shouldLock ? &mutexParam : nullptr) { + if (mutex) + mutex->lock(); + } + ~ScopedWriterLock() { + if (mutex) + mutex->unlock(); + } + llvm::sys::SmartRWMutex *mutex; +}; +} // end anonymous namespace. 
+ +//===----------------------------------------------------------------------===// // AffineMap and IntegerSet hashing //===----------------------------------------------------------------------===// @@ -111,8 +150,10 @@ template static ValueT safeGetOrCreate(DenseSet &container, KeyT &&key, llvm::sys::SmartRWMutex &mutex, + bool threadingIsEnabled, ConstructorFn &&constructorFn) { - { // Check for an existing instance in read-only mode. + // Check for an existing instance in read-only mode. + if (threadingIsEnabled) { llvm::sys::SmartScopedReader instanceLock(mutex); auto it = container.find_as(key); if (it != container.end()) @@ -120,16 +161,14 @@ static ValueT safeGetOrCreate(DenseSet &container, } // Acquire a writer-lock so that we can safely create the new instance. - llvm::sys::SmartScopedWriter instanceLock(mutex); + ScopedWriterLock instanceLock(mutex, threadingIsEnabled); // Check for an existing instance again here, because another writer thread - // may have already created one. + // may have already created one. Otherwise, construct a new instance. auto existing = container.insert_as(ValueT(), key); - if (!existing.second) - return *existing.first; - - // Otherwise, construct a new instance of the value. - return *existing.first = constructorFn(); + if (existing.second) + return *existing.first = constructorFn(); + return *existing.first; } namespace { @@ -217,6 +256,9 @@ public: /// detect such use cases bool allowUnregisteredDialects = false; + /// Enable support for multi-threading within MLIR. + bool threadingIsEnabled = true; + /// If the operation should be attached to diagnostics printed via the /// Operation::emit methods. bool printOpOnDiagnostic = true; @@ -288,17 +330,19 @@ public: UnknownLoc unknownLocAttr; public: - MLIRContextImpl() : identifiers(identifierAllocator) { - // Initialize values based on the command line flags if they were provided. 
- if (clOptions.isConstructed()) { - printOpOnDiagnostic = clOptions->printOpOnDiagnostic; - printStackTraceOnDiagnostic = clOptions->printStackTraceOnDiagnostic; - } - } + MLIRContextImpl() : identifiers(identifierAllocator) {} }; } // end namespace mlir MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) { + // Initialize values based on the command line flags if they were provided. + if (clOptions.isConstructed()) { + disableMultithreading(clOptions->disableThreading); + printOpOnDiagnostic(clOptions->printOpOnDiagnostic); + printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic); + } + + // Register dialects with this context. new BuiltinDialect(this); registerAllDialects(this); @@ -372,11 +416,10 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } /// Return information about all registered IR dialects. std::vector MLIRContext::getRegisteredDialects() { // Lock access to the context registry. - llvm::sys::SmartScopedReader registryLock(getImpl().contextMutex); - + ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled); std::vector result; - result.reserve(getImpl().dialects.size()); - for (auto &dialect : getImpl().dialects) + result.reserve(impl->dialects.size()); + for (auto &dialect : impl->dialects) result.push_back(dialect.get()); return result; } @@ -385,11 +428,15 @@ std::vector MLIRContext::getRegisteredDialects() { /// then return nullptr. Dialect *MLIRContext::getRegisteredDialect(StringRef name) { // Lock access to the context registry. - llvm::sys::SmartScopedReader registryLock(getImpl().contextMutex); - for (auto &dialect : getImpl().dialects) - if (name == dialect->getNamespace()) - return dialect.get(); - return nullptr; + ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled); + + // Dialects are sorted by name, so we can use binary search for lookup. 
+ auto it = llvm::lower_bound( + impl->dialects, name, + [](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; }); + return (it != impl->dialects.end() && (*it)->getNamespace() == name) + ? (*it).get() + : nullptr; } /// Register this dialect object with the specified context. The context @@ -399,15 +446,13 @@ void Dialect::registerDialect(MLIRContext *context) { std::unique_ptr dialect(this); // Lock access to the context registry. - llvm::sys::SmartScopedWriter registryLock(impl.contextMutex); + ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled); // Get the correct insertion position sorted by namespace. - auto insertPt = - llvm::lower_bound(impl.dialects, dialect, - [](const std::unique_ptr &lhs, - const std::unique_ptr &rhs) { - return lhs->getNamespace() < rhs->getNamespace(); - }); + auto insertPt = llvm::lower_bound( + impl.dialects, dialect, [](const auto &lhs, const auto &rhs) { + return lhs->getNamespace() < rhs->getNamespace(); + }); // Abort if dialect with namespace has already been registered. if (insertPt != impl.dialects.end() && @@ -426,6 +471,21 @@ void MLIRContext::allowUnregisteredDialects(bool allowing) { impl->allowUnregisteredDialects = allowing; } +/// Return true if multi-threading is enabled by the context. +bool MLIRContext::isMultithreadingEnabled() { + return impl->threadingIsEnabled && llvm::llvm_is_multithreaded(); +} + +/// Set the flag specifying if multi-threading is disabled by the context. +void MLIRContext::disableMultithreading(bool disable) { + impl->threadingIsEnabled = !disable; + + // Update the threading mode for each of the uniquers. + impl->affineUniquer.disableMultithreading(disable); + impl->attributeUniquer.disableMultithreading(disable); + impl->typeUniquer.disableMultithreading(disable); +} + /// Return true if we should attach the operation to diagnostics emitted via /// Operation::emit. 
bool MLIRContext::shouldPrintOpOnDiagnostic() { @@ -457,13 +517,13 @@ std::vector MLIRContext::getRegisteredOperations() { std::vector> opsToSort; { // Lock access to the context registry. - llvm::sys::SmartScopedReader registryLock(getImpl().contextMutex); + ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled); // We just have the operations in a non-deterministic hash table order. Dump // into a temporary array, then sort it by operation name to get a stable // ordering. llvm::StringMap ®isteredOps = - getImpl().registeredOperations; + impl->registeredOperations; opsToSort.reserve(registeredOps.size()); for (auto &elt : registeredOps) @@ -487,7 +547,7 @@ void Dialect::addOperation(AbstractOperation opInfo) { auto &impl = context->getImpl(); // Lock access to the context registry. - llvm::sys::SmartScopedWriter registryLock(impl.contextMutex); + ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled); if (!impl.registeredOperations.insert({opInfo.name, opInfo}).second) { llvm::errs() << "error: operation named '" << opInfo.name << "' is already registered.\n"; @@ -500,7 +560,7 @@ void Dialect::addSymbol(TypeID typeID) { auto &impl = context->getImpl(); // Lock access to the context registry. - llvm::sys::SmartScopedWriter registryLock(impl.contextMutex); + ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled); if (!impl.registeredDialectSymbols.insert({typeID, this}).second) { llvm::errs() << "error: dialect symbol already registered.\n"; abort(); @@ -514,7 +574,7 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName, auto &impl = context->getImpl(); // Lock access to the context registry. 
- llvm::sys::SmartScopedReader registryLock(impl.contextMutex); + ScopedReaderLock registryLock(impl.contextMutex, impl.threadingIsEnabled); auto it = impl.registeredOperations.find(opName); if (it != impl.registeredOperations.end()) return &it->second; @@ -529,7 +589,8 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName, Identifier Identifier::get(StringRef str, MLIRContext *context) { auto &impl = context->getImpl(); - { // Check for an existing identifier in read-only mode. + // Check for an existing identifier in read-only mode. + if (context->isMultithreadingEnabled()) { llvm::sys::SmartScopedReader contextLock(impl.identifierMutex); auto it = impl.identifiers.find(str); if (it != impl.identifiers.end()) @@ -544,7 +605,7 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) { "Cannot create an identifier with a nul character"); // Acquire a writer-lock so that we can safely create the new instance. - llvm::sys::SmartScopedWriter contextLock(impl.identifierMutex); + ScopedWriterLock contextLock(impl.identifierMutex, impl.threadingIsEnabled); auto it = impl.identifiers.insert(str).first; return Identifier(&*it); } @@ -696,16 +757,18 @@ AffineMap AffineMap::getImpl(unsigned dimCount, unsigned symbolCount, auto key = std::make_tuple(dimCount, symbolCount, results); // Safely get or create an AffineMap instance. - return safeGetOrCreate(impl.affineMaps, key, impl.affineMutex, [&] { - auto *res = impl.affineAllocator.Allocate(); + return safeGetOrCreate( + impl.affineMaps, key, impl.affineMutex, impl.threadingIsEnabled, [&] { + auto *res = impl.affineAllocator.Allocate(); - // Copy the results into the bump pointer. - results = copyArrayRefInto(impl.affineAllocator, results); + // Copy the results into the bump pointer. + results = copyArrayRefInto(impl.affineAllocator, results); - // Initialize the memory using placement new. 
- new (res) detail::AffineMapStorage{dimCount, symbolCount, results, context}; - return AffineMap(res); - }); + // Initialize the memory using placement new. + new (res) + detail::AffineMapStorage{dimCount, symbolCount, results, context}; + return AffineMap(res); + }); } AffineMap AffineMap::get(MLIRContext *context) { @@ -760,12 +823,12 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount, if (constraints.size() < IntegerSet::kUniquingThreshold) { auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags); return safeGetOrCreate(impl.integerSets, key, impl.affineMutex, - constructorFn); + impl.threadingIsEnabled, constructorFn); } // Otherwise, acquire a writer-lock so that we can safely create the new // instance. - llvm::sys::SmartScopedWriter affineLock(impl.affineMutex); + ScopedWriterLock affineLock(impl.affineMutex, impl.threadingIsEnabled); return constructorFn(); } diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index ba9ff98..842b83c 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -257,7 +257,8 @@ struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { /// Add an instrumentation to print the IR before and after pass execution, /// using the provided configuration. 
void PassManager::enableIRPrinting(std::unique_ptr config) { - if (config->shouldPrintAtModuleScope() && isMultithreadingEnabled()) + if (config->shouldPrintAtModuleScope() && + getContext()->isMultithreadingEnabled()) llvm::report_fatal_error("IR printing can't be setup on a pass-manager " "without disabling multi-threading first."); addInstrumentation( diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index ef4ed76..83855fa 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -118,9 +118,8 @@ void VerifierPass::runOnOperation() { namespace mlir { namespace detail { struct OpPassManagerImpl { - OpPassManagerImpl(OperationName name, bool disableThreads, bool verifyPasses) - : name(name), disableThreads(disableThreads), verifyPasses(verifyPasses) { - } + OpPassManagerImpl(OperationName name, bool verifyPasses) + : name(name), verifyPasses(verifyPasses) {} /// Merge the passes of this pass manager into the one provided. void mergeInto(OpPassManagerImpl &rhs); @@ -152,9 +151,6 @@ struct OpPassManagerImpl { /// The name of the operation that passes of this pass manager operate on. OperationName name; - /// Flag to disable multi-threading of passes. - bool disableThreads : 1; - /// Flag that specifies if the IR should be verified after each pass has run. 
bool verifyPasses : 1; @@ -172,7 +168,7 @@ void OpPassManagerImpl::mergeInto(OpPassManagerImpl &rhs) { } OpPassManager &OpPassManagerImpl::nest(const OperationName &nestedName) { - OpPassManager nested(nestedName, disableThreads, verifyPasses); + OpPassManager nested(nestedName, verifyPasses); auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); @@ -269,9 +265,8 @@ void OpPassManagerImpl::splitAdaptorPasses() { // OpPassManager //===----------------------------------------------------------------------===// -OpPassManager::OpPassManager(OperationName name, bool disableThreads, - bool verifyPasses) - : impl(new OpPassManagerImpl(name, disableThreads, verifyPasses)) { +OpPassManager::OpPassManager(OperationName name, bool verifyPasses) + : impl(new OpPassManagerImpl(name, verifyPasses)) { assert(name.getAbstractOperation() && "OpPassManager can only operate on registered operations"); assert(name.getAbstractOperation()->hasProperty( @@ -282,8 +277,7 @@ OpPassManager::OpPassManager(OperationName name, bool disableThreads, OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {} OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; } OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) { - impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->disableThreads, - rhs.impl->verifyPasses)); + impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses)); for (auto &pass : rhs.impl->passes) impl->passes.emplace_back(pass->clone()); return *this; @@ -419,10 +413,10 @@ std::string OpToOpPassAdaptor::getAdaptorName() { /// Run the held pipeline over all nested operations. 
void OpToOpPassAdaptor::runOnOperation() { - if (mgrs.front().getImpl().disableThreads || !llvm::llvm_is_multithreaded()) - runOnOperationImpl(); - else + if (getContext().isMultithreadingEnabled()) runOnOperationAsyncImpl(); + else + runOnOperationImpl(); } /// Run this pass adaptor synchronously. @@ -576,7 +570,7 @@ private: /// The filename to use when generating the reproducer. StringRef filename; - /// Various pass manager flags. + /// Various pass manager and context flags. bool disableThreads; bool verifyPasses; @@ -628,7 +622,7 @@ LogicalResult RecoveryReproducerContext::generate(std::string &error) { // Output the current pass manager configuration. outputOS << "// configuration: -pass-pipeline='" << pipeline << "'"; if (disableThreads) - outputOS << " -disable-pass-threading"; + outputOS << " -mlir-disable-threading"; // TODO: Should this also be configured with a pass manager flag? outputOS << "\n// note: verifyPasses=" << (verifyPasses ? "true" : "false") @@ -684,7 +678,8 @@ LogicalResult PassManager::runWithCrashRecovery(MutableArrayRef> passes, ModuleOp module, AnalysisManager am) { RecoveryReproducerContext context(passes, module, *crashReproducerFileName, - impl->disableThreads, impl->verifyPasses); + !getContext()->isMultithreadingEnabled(), + impl->verifyPasses); // Safely invoke the passes within a recovery context. llvm::CrashRecoveryContext::Enable(); @@ -715,7 +710,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef> passes, PassManager::PassManager(MLIRContext *ctx, bool verifyPasses) : OpPassManager(OperationName(ModuleOp::getOperationName(), ctx), - /*disableThreads=*/false, verifyPasses), + verifyPasses), passTiming(false), localReproducer(false) {} PassManager::~PassManager() {} @@ -741,15 +736,6 @@ LogicalResult PassManager::run(ModuleOp module) { return result; } -/// Disable support for multi-threading within the pass manager. 
-void PassManager::disableMultithreading(bool disable) { - getImpl().disableThreads = disable; -} - -bool PassManager::isMultithreadingEnabled() { - return !getImpl().disableThreads; -} - /// Enable support for the pass manager to generate a reproducer on the event /// of a crash or a pass failure. `outputFile` is a .mlir filename used to write /// the generated reproducer. If `genLocalReproducer` is true, the pass manager diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp index c1b1eee..b00f992 100644 --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -30,14 +30,6 @@ struct PassManagerOptions { llvm::cl::init(false)}; //===--------------------------------------------------------------------===// - // Multi-threading - //===--------------------------------------------------------------------===// - llvm::cl::opt disableThreads{ - "disable-pass-threading", - llvm::cl::desc("Disable multithreading in the pass manager"), - llvm::cl::init(false)}; - - //===--------------------------------------------------------------------===// // IR Printing //===--------------------------------------------------------------------===// PassPipelineCLParser printBefore{"print-ir-before", @@ -164,10 +156,6 @@ void mlir::applyPassManagerCLOptions(PassManager &pm) { pm.enableCrashReproducerGeneration(options->reproducerFile, options->localReproducer); - // Disable multi-threading. - if (options->disableThreads) - pm.disableMultithreading(); - // Enable statistics dumping. 
if (options->passStatistics) pm.enableStatistics(options->passStatisticsDisplayMode); diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp index d50c599..40304a5 100644 --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -46,6 +46,8 @@ struct StorageUniquerImpl { function_ref isEqual, function_ref ctorFn) { LookupKey lookupKey{kind, hashValue, isEqual}; + if (!threadingIsEnabled) + return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn); // Check for an existing instance in read-only mode. { @@ -57,9 +59,12 @@ struct StorageUniquerImpl { // Acquire a writer-lock so that we can safely create the new type instance. llvm::sys::SmartScopedWriter typeLock(mutex); - - // Check for an existing instance again here, because another writer thread - // may have already created one. + return getOrCreateUnsafe(kind, hashValue, lookupKey, ctorFn); + } + /// Get or create an instance of a complex derived type in an unsafe fashion. + BaseStorage * + getOrCreateUnsafe(unsigned kind, unsigned hashValue, LookupKey &lookupKey, + function_ref ctorFn) { auto existing = storageTypes.insert_as({}, lookupKey); if (!existing.second) return existing.first->storage; @@ -75,6 +80,9 @@ struct StorageUniquerImpl { BaseStorage * getOrCreate(unsigned kind, function_ref ctorFn) { + if (!threadingIsEnabled) + return getOrCreateUnsafe(kind, ctorFn); + // Check for an existing instance in read-only mode. { llvm::sys::SmartScopedReader typeLock(mutex); @@ -85,9 +93,12 @@ struct StorageUniquerImpl { // Acquire a writer-lock so that we can safely create the new type instance. llvm::sys::SmartScopedWriter typeLock(mutex); - - // Check for an existing instance again here, because another writer thread - // may have already created one. + return getOrCreateUnsafe(kind, ctorFn); + } + /// Get or create an instance of a simple derived type in an unsafe fashion. 
+ BaseStorage * + getOrCreateUnsafe(unsigned kind, + function_ref ctorFn) { auto &result = simpleTypes[kind]; if (result) return result; @@ -152,18 +163,21 @@ struct StorageUniquerImpl { } }; - // Unique types with specific hashing or storage constraints. + /// Unique types with specific hashing or storage constraints. using StorageTypeSet = DenseSet; StorageTypeSet storageTypes; - // Unique types with just the kind. + /// Unique types with just the kind. DenseMap simpleTypes; - // Allocator to use when constructing derived type instances. + /// Allocator to use when constructing derived type instances. StorageUniquer::StorageAllocator allocator; - // A mutex to keep type uniquing thread-safe. + /// A mutex to keep type uniquing thread-safe. llvm::sys::SmartRWMutex mutex; + + /// Flag specifying if multi-threading is enabled within the uniquer. + bool threadingIsEnabled = true; }; } // end namespace detail } // namespace mlir @@ -171,6 +185,11 @@ struct StorageUniquerImpl { StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {} StorageUniquer::~StorageUniquer() {} +/// Set the flag specifying if multi-threading is disabled within the uniquer. +void StorageUniquer::disableMultithreading(bool disable) { + impl->threadingIsEnabled = !disable; +} + /// Implementation for getting/creating an instance of a derived type with /// complex storage. auto StorageUniquer::getImpl( diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp index c0f89da..ee645cb 100644 --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -496,22 +496,28 @@ static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, // NOTE: This is simple now, because we don't enable canonicalizing nodes // within children. When we remove this restriction, this logic will need to // be reworked. 
- ParallelDiagnosticHandler canonicalizationHandler(context); - llvm::parallel::for_each_n( - llvm::parallel::par, /*Begin=*/size_t(0), - /*End=*/nodesToCanonicalize.size(), [&](size_t index) { - // Set the order for this thread so that diagnostics will be properly - // ordered. - canonicalizationHandler.setOrderIDForThread(index); - - // Apply the canonicalization patterns to this region. - auto *node = nodesToCanonicalize[index]; - applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns); - - // Make sure to reset the order ID for the diagnostic handler, as this - // thread may be used in a different context. - canonicalizationHandler.eraseOrderIDForThread(); - }); + if (context->isMultithreadingEnabled()) { + ParallelDiagnosticHandler canonicalizationHandler(context); + llvm::parallel::for_each_n( + llvm::parallel::par, /*Begin=*/size_t(0), + /*End=*/nodesToCanonicalize.size(), [&](size_t index) { + // Set the order for this thread so that diagnostics will be properly + // ordered. + canonicalizationHandler.setOrderIDForThread(index); + + // Apply the canonicalization patterns to this region. + auto *node = nodesToCanonicalize[index]; + applyPatternsAndFoldGreedily(*node->getCallableRegion(), + canonPatterns); + + // Make sure to reset the order ID for the diagnostic handler, as this + // thread may be used in a different context. + canonicalizationHandler.eraseOrderIDForThread(); + }); + } else { + for (CallGraphNode *node : nodesToCanonicalize) + applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns); + } // Recompute the uses held by each of the nodes. 
for (CallGraphNode *node : nodesToCanonicalize) diff --git a/mlir/test/Dialect/SPIRV/availability.mlir b/mlir/test/Dialect/SPIRV/availability.mlir index e31c1bd..322cc53 100644 --- a/mlir/test/Dialect/SPIRV/availability.mlir +++ b/mlir/test/Dialect/SPIRV/availability.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -disable-pass-threading -test-spirv-op-availability %s | FileCheck %s +// RUN: mlir-opt -mlir-disable-threading -test-spirv-op-availability %s | FileCheck %s // CHECK-LABEL: iadd func @iadd(%arg: i32) -> i32 { diff --git a/mlir/test/Dialect/SPIRV/target-env.mlir b/mlir/test/Dialect/SPIRV/target-env.mlir index 9b42314..27c4e8d 100644 --- a/mlir/test/Dialect/SPIRV/target-env.mlir +++ b/mlir/test/Dialect/SPIRV/target-env.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -disable-pass-threading -test-spirv-target-env %s | FileCheck %s +// RUN: mlir-opt -mlir-disable-threading -test-spirv-target-env %s | FileCheck %s // Note: The following tests check that a spv.target_env can properly control // the conversion target and filter unavailable ops during the conversion. 
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir index 60d5bcf..925b01b 100644 --- a/mlir/test/IR/test-matchers.mlir +++ b/mlir/test/IR/test-matchers.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -disable-pass-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s +// RUN: mlir-opt %s -mlir-disable-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s func @test1(%a: f32, %b: f32, %c: f32) { %0 = addf %a, %b: f32 diff --git a/mlir/test/Pass/ir-printing.mlir b/mlir/test/Pass/ir-printing.mlir index 892dc40..8bb86b3 100644 --- a/mlir/test/Pass/ir-printing.mlir +++ b/mlir/test/Pass/ir-printing.mlir @@ -1,9 +1,9 @@ -// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE %s -// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before-all -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_ALL %s -// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after=cse -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER %s -// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s -// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s -// RUN: mlir-opt %s -disable-pass-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s +// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE %s +// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before-all -o /dev/null 
2>&1 | FileCheck -check-prefix=BEFORE_ALL %s +// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after=cse -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER %s +// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-after-all -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL %s +// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,canonicalize)' -print-ir-before=cse -print-ir-module-scope -o /dev/null 2>&1 | FileCheck -check-prefix=BEFORE_MODULE %s +// RUN: mlir-opt %s -mlir-disable-threading=true -pass-pipeline='func(cse,cse)' -print-ir-after-all -print-ir-after-change -o /dev/null 2>&1 | FileCheck -check-prefix=AFTER_ALL_CHANGE %s func @foo() { %0 = constant 0 : i32 diff --git a/mlir/test/Pass/pass-timing.mlir b/mlir/test/Pass/pass-timing.mlir index db39ad6..6cd8a29 100644 --- a/mlir/test/Pass/pass-timing.mlir +++ b/mlir/test/Pass/pass-timing.mlir @@ -1,8 +1,8 @@ -// RUN: mlir-opt %s -disable-pass-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=LIST %s -// RUN: mlir-opt %s -disable-pass-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=PIPELINE %s -// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=MT_LIST %s -// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=MT_PIPELINE %s -// RUN: mlir-opt %s -disable-pass-threading=false -verify-each=false -test-pm-nested-pipeline -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=NESTED_MT_PIPELINE %s +// RUN: mlir-opt %s 
-mlir-disable-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=LIST %s +// RUN: mlir-opt %s -mlir-disable-threading=true -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=PIPELINE %s +// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=list 2>&1 | FileCheck -check-prefix=MT_LIST %s +// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=true -pass-pipeline='func(cse,canonicalize,cse)' -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=MT_PIPELINE %s +// RUN: mlir-opt %s -mlir-disable-threading=false -verify-each=false -test-pm-nested-pipeline -pass-timing -pass-timing-display=pipeline 2>&1 | FileCheck -check-prefix=NESTED_MT_PIPELINE %s // LIST: Pass execution timing report // LIST: Total Execution Time: -- 2.7.4