[PassManager] Define a ParallelDiagnosticHandler to ensure that diagnostics are still...
authorRiver Riddle <riverriddle@google.com>
Thu, 28 Mar 2019 18:40:17 +0000 (11:40 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:50:59 +0000 (17:50 -0700)
PiperOrigin-RevId: 240817922

mlir/lib/Pass/Pass.cpp

index 9cb46d4..1ffab6c 100644 (file)
@@ -175,6 +175,79 @@ void ModuleToFunctionPassAdaptor::runOnModule() {
   }
 }
 
+namespace {
+/// A utility class to ensure that diagnostics are emitted in a deterministic
+/// order when executing a pipeline asynchronously.
+struct ParallelDiagnosticHandler {
+  struct ThreadDiagnostic {
+    ThreadDiagnostic(size_t id, Location loc, StringRef msg,
+                     MLIRContext::DiagnosticKind kind)
+        : id(id), loc(loc), msg(msg), kind(kind) {}
+    bool operator<(const ThreadDiagnostic &rhs) const { return id < rhs.id; }
+
+    /// The function id for this diagnostic, this is used for ordering.
+    size_t id;
+
+    /// Information for the diagnostic.
+    Location loc;
+    std::string msg;
+    MLIRContext::DiagnosticKind kind;
+  };
+
+  ParallelDiagnosticHandler(MLIRContext &ctx)
+      : prevHandler(ctx.getDiagnosticHandler()), context(ctx) {
+    ctx.registerDiagnosticHandler([this](Location loc, StringRef message,
+                                         MLIRContext::DiagnosticKind kind) {
+      uint64_t tid = llvm::get_threadid();
+      llvm::sys::SmartScopedLock<true> lock(mutex);
+
+      // Append a new diagnostic.
+      diagnostics.emplace_back(threadToFuncID[tid], loc, message, kind);
+    });
+  }
+
+  ~ParallelDiagnosticHandler() {
+    // Restore the previous diagnostic handler.
+    context.registerDiagnosticHandler(prevHandler);
+
+    // Early exit if there are no diagnostics, this is the common case.
+    if (diagnostics.empty())
+      return;
+
+    // Stable sort all of the diagnostics that were emitted. This creates a
+    // deterministic ordering for the diagnostics based upon which function they
+    // were emitted for.
+    std::stable_sort(diagnostics.begin(), diagnostics.end());
+
+    // Emit each diagnostic to the context again.
+    for (ThreadDiagnostic &diag : diagnostics)
+      context.emitDiagnostic(diag.loc, diag.msg, diag.kind);
+  }
+
+  /// Set the function id for the current thread.
+  void setFuncIDForThread(size_t funcID) {
+    uint64_t tid = llvm::get_threadid();
+    llvm::sys::SmartScopedLock<true> lock(mutex);
+    threadToFuncID[tid] = funcID;
+  }
+
+  /// The previous context diagnostic handler.
+  MLIRContext::DiagnosticHandlerTy prevHandler;
+
+  /// A smart mutex to lock access to the internal state.
+  llvm::sys::SmartMutex<true> mutex;
+
+  /// A mapping between the thread id and the current function id.
+  DenseMap<uint64_t, size_t> threadToFuncID;
+
+  /// An unordered list of diagnostics that were emitted.
+  std::vector<ThreadDiagnostic> diagnostics;
+
+  /// The context to emit the diagnostics to.
+  MLIRContext &context;
+};
+} // end anonymous namespace
+
 // Run the held function pipeline synchronously across the functions within
 // the module.
 void ModuleToFunctionPassAdaptorParallel::runOnModule() {
@@ -193,6 +266,10 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
     if (!func.isExternal())
       funcAMPairs.emplace_back(&func, mam.slice(&func));
 
+  // A parallel diagnostic handler that provides deterministic diagnostic
+  // ordering.
+  ParallelDiagnosticHandler diagHandler(getContext());
+
   // An index for the current function/analysis manager pair.
   std::atomic<unsigned> funcIt(0);
 
@@ -209,6 +286,9 @@ void ModuleToFunctionPassAdaptorParallel::runOnModule() {
           if (nextID >= e)
             break;
 
+          // Set the function id for this thread in the diagnostic handler.
+          diagHandler.setFuncIDForThread(nextID);
+
           // Run the executor over the current function.
           auto &it = funcAMPairs[nextID];
           if (failed(runFunctionPipeline(executor, it.first, it.second))) {