Add support for providing an output stream to the SourceMgrDiagnosticHandlers.
authorRiver Riddle <riverriddle@google.com>
Fri, 31 May 2019 23:03:26 +0000 (16:03 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:13:39 +0000 (20:13 -0700)
--

PiperOrigin-RevId: 250974331

mlir/include/mlir/IR/Diagnostics.h
mlir/lib/IR/Diagnostics.cpp

index b2205f0..6038a23 100644 (file)
@@ -484,6 +484,8 @@ struct SourceMgrDiagnosticHandlerImpl;
 /// This class is a utility diagnostic handler for use with llvm::SourceMgr.
 class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler {
 public:
+  SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
+                             llvm::raw_ostream &os);
   SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx);
   ~SourceMgrDiagnosticHandler();
 
@@ -501,6 +503,9 @@ protected:
   /// The source manager that we are wrapping.
   llvm::SourceMgr &mgr;
 
+  /// The output stream to use when printing diagnostics.
+  llvm::raw_ostream &os;
+
 private:
   /// Convert a location into the given memory buffer into an SMLoc.
   llvm::SMLoc convertLocToSMLoc(FileLineColLoc loc);
@@ -525,6 +530,8 @@ struct SourceMgrDiagnosticVerifierHandlerImpl;
 /// corresponding line of the source file.
 class SourceMgrDiagnosticVerifierHandler : public SourceMgrDiagnosticHandler {
 public:
+  SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx,
+                                     llvm::raw_ostream &out);
   SourceMgrDiagnosticVerifierHandler(llvm::SourceMgr &srcMgr, MLIRContext *ctx);
   ~SourceMgrDiagnosticVerifierHandler();
 
index 67e1405..9fbf9c5 100644 (file)
@@ -342,13 +342,19 @@ static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
 }
 
 SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
-                                                       MLIRContext *ctx)
-    : ScopedDiagnosticHandler(ctx), mgr(mgr),
+                                                       MLIRContext *ctx,
+                                                       llvm::raw_ostream &os)
+    : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
       impl(new SourceMgrDiagnosticHandlerImpl()) {
   // Register a simple diagnostic handler.
   ctx->getDiagEngine().setHandler(
       [this](Diagnostic diag) { emitDiagnostic(diag); });
 }
+
+SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
+                                                       MLIRContext *ctx)
+    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs()) {}
+
 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() {}
 
 void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
@@ -359,16 +365,15 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
   // If one doesn't exist, then print the location as part of the message.
   if (!fileLoc) {
     std::string str;
-    llvm::raw_string_ostream os(str);
-    os << loc << ": ";
-    return mgr.PrintMessage(llvm::SMLoc(), getDiagKind(kind),
-                            os.str() + message);
+    llvm::raw_string_ostream strOS(str);
+    strOS << loc << ": " << message;
+    return mgr.PrintMessage(os, llvm::SMLoc(), getDiagKind(kind), strOS.str());
   }
 
   // Otherwise, try to convert the file location to an SMLoc.
   auto smloc = convertLocToSMLoc(*fileLoc);
   if (smloc.isValid())
-    return mgr.PrintMessage(smloc, getDiagKind(kind), message);
+    return mgr.PrintMessage(os, smloc, getDiagKind(kind), message);
 
   // If the conversion was unsuccessful, create a diagnostic with the source
   // location information directly.
@@ -376,7 +381,7 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
                           fileLoc->getLine(), fileLoc->getColumn(),
                           getDiagKind(kind), message.str(), /*LineStr=*/"",
                           /*Ranges=*/llvm::None);
-  diag.print(nullptr, llvm::errs());
+  diag.print(nullptr, os);
 }
 
 /// Emit the given diagnostic with the held source manager.
@@ -584,8 +589,8 @@ SourceMgrDiagnosticVerifierHandlerImpl::computeExpectedDiags(
 }
 
 SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
-    llvm::SourceMgr &srcMgr, MLIRContext *ctx)
-    : SourceMgrDiagnosticHandler(srcMgr, ctx),
+    llvm::SourceMgr &srcMgr, MLIRContext *ctx, llvm::raw_ostream &out)
+    : SourceMgrDiagnosticHandler(srcMgr, ctx, out),
       impl(new SourceMgrDiagnosticVerifierHandlerImpl()) {
   // Compute the expected diagnostics for each of the current files in the
   // source manager.
@@ -603,6 +608,10 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
   });
 }
 
+SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
+    llvm::SourceMgr &srcMgr, MLIRContext *ctx)
+    : SourceMgrDiagnosticVerifierHandler(srcMgr, ctx, llvm::errs()) {}
+
 SourceMgrDiagnosticVerifierHandler::~SourceMgrDiagnosticVerifierHandler() {
   // Ensure that all expected diagnosics were handled.
   (void)verify();
@@ -620,7 +629,7 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
       llvm::SMRange range(err.fileLoc,
                           llvm::SMLoc::getFromPointer(err.fileLoc.getPointer() +
                                                       err.substring.size()));
-      mgr.PrintMessage(err.fileLoc, llvm::SourceMgr::DK_Error,
+      mgr.PrintMessage(os, err.fileLoc, llvm::SourceMgr::DK_Error,
                        "expected " + getDiagKindStr(err.kind) + " \"" +
                            err.substring + "\" was not produced",
                        range);
@@ -675,7 +684,7 @@ void SourceMgrDiagnosticVerifierHandler::process(FileLineColLoc loc,
 
   // Otherwise, emit an error for the near miss.
   if (nearMiss)
-    mgr.PrintMessage(nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
+    mgr.PrintMessage(os, nearMiss->fileLoc, llvm::SourceMgr::DK_Error,
                      "'" + getDiagKindStr(kind) +
                          "' diagnostic emitted when expecting a '" +
                          getDiagKindStr(nearMiss->kind) + "'");