[mlir][mlir-opt] Disable multithreading when parsing the input module.
authorRiver Riddle <riddleriver@gmail.com>
Mon, 4 May 2020 18:21:49 +0000 (11:21 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 5 May 2020 00:29:56 +0000 (17:29 -0700)
This removes the unnecessary/costly context synchronization when parsing, as the context is guaranteed to not be used by any other threads.

mlir/include/mlir/IR/MLIRContext.h
mlir/lib/Support/MlirOptMain.cpp
mlir/tools/mlir-opt/mlir-opt.cpp

index 40b3326..da0b0bd 100644 (file)
@@ -60,6 +60,9 @@ public:
 
   /// Set the flag specifying if multi-threading is disabled by the context.
   void disableMultithreading(bool disable = true);
+  void enableMultithreading(bool enable = true) {
+    disableMultithreading(!enable);
+  }
 
   /// Return true if we should attach the operation to diagnostics emitted via
   /// Operation::emit.
index 5c21e19..25e1970 100644 (file)
@@ -40,7 +40,14 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
                                     bool verifyPasses, SourceMgr &sourceMgr,
                                     MLIRContext *context,
                                     const PassPipelineCLParser &passPipeline) {
+  // Disable multi-threading when parsing the input file. This removes the
+  // unnecessary/costly context synchronization when parsing.
+  bool wasThreadingEnabled = context->isMultithreadingEnabled();
+  context->disableMultithreading();
+
+  // Parse the input file and reset the context threading state.
   OwningModuleRef module(parseSourceFile(sourceMgr, context));
+  context->enableMultithreading(wasThreadingEnabled);
   if (!module)
     return failure();
 
index d9c8221..c5cc533 100644 (file)
@@ -150,9 +150,9 @@ int main(int argc, char **argv) {
   // Parse pass names in main to ensure static initialization completed.
   cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
 
-  MLIRContext context;
   if(showDialects) {
     llvm::outs() << "Registered Dialects:\n";
+    MLIRContext context;
     for(Dialect *dialect : context.getRegisteredDialects()) {
       llvm::outs() << dialect->getNamespace() << "\n";
     }