[mlir] Add support to SourceMgrDiagnosticHandler for filtering FileLineColLocs
authorRiver Riddle <riddleriver@gmail.com>
Fri, 18 Jun 2021 20:30:16 +0000 (20:30 +0000)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 18 Jun 2021 21:12:28 +0000 (21:12 +0000)
This revision adds support for passing a functor to SourceMgrDiagnosticHandler for filtering out FileLineColLocs when emitting a diagnostic. More specifically, this can be useful in situations where there may be large CallSiteLocs with locations that aren't necessarily important/useful for users.

For now the filtering support is limited to FileLineColLocs, but conceptually we could allow filtering for all locations types if a need arises in the future.

Differential Revision: https://reviews.llvm.org/D103649

mlir/docs/Diagnostics.md
mlir/include/mlir/IR/Diagnostics.h
mlir/lib/IR/Diagnostics.cpp
mlir/test/IR/diagnostic-handler-filter.mlir [new file with mode: 0644]
mlir/test/lib/IR/CMakeLists.txt
mlir/test/lib/IR/TestDiagnostics.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp

index ae18844..285da68 100644 (file)
@@ -243,6 +243,45 @@ MLIRContext context;
 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
 ```
 
+#### Filtering Locations
+
+In some situations, a diagnostic may be emitted with a callsite location in a
+very deep call stack in which many frames are unrelated to the user source code.
+These situations often arise when the user source code is intertwined with that
+of a large framework or library. The context of the diagnostic in these cases is
+often obfuscated by the unrelated framework source locations. To help alleviate
+this obfuscation, the `SourceMgrDiagnosticHandler` provides support for
+filtering which locations are shown to the user. To enable filtering, a user
+must simply provide a filter function to the `SourceMgrDiagnosticHandler` on
+construction that indicates which locations should be shown. A quick example is
+shown below:
+
+```c++
+// Here we define the functor that controls which locations are shown to the
+// user. This functor should return true when a location should be shown, and
+// false otherwise. When filtering a container location, such as a NameLoc, this
+// function should not recurse into the child location. Recursion into nested
+// location is performed as necessary by the caller.
+auto shouldShowFn = [](Location loc) -> bool {
+  FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>();
+
+  // We don't perform any filtering on non-file locations.
+  // Reminder: The caller will recurse into any necessary child locations.
+  if (!fileLoc)
+    return true;
+
+  // Don't show file locations that contain our framework code.
+  return !fileLoc.getFilename().strref().contains("my/framework/source/");
+};
+
+SourceMgr sourceMgr;
+MLIRContext context;
+SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context, shouldShowFn);
+```
+
+Note: In the case where all locations are filtered out, the first location in
+the stack will still be shown.
+
 ### SourceMgr Diagnostic Verifier Handler
 
 This handler is a wrapper around a llvm::SourceMgr that is used to verify that
index 1fee314..89f50ab 100644 (file)
@@ -530,9 +530,20 @@ struct SourceMgrDiagnosticHandlerImpl;
 /// This class is a utility diagnostic handler for use with llvm::SourceMgr.
 class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler {
 public:
+  /// This type represents a functor used to filter out locations when printing
+  /// a diagnostic. It should return true if the provided location is okay to
+  /// display, false otherwise. If all locations in a diagnostic are filtered
+  /// out, the first location is used as the sole location. When deciding
+  /// whether or not to filter a location, this function should not recurse into
+  /// any nested location. This recursion is handled automatically by the
+  /// caller.
+  using ShouldShowLocFn = llvm::unique_function<bool(Location)>;
+
+  SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
+                             raw_ostream &os,
+                             ShouldShowLocFn &&shouldShowLocFn = {});
   SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
-                             raw_ostream &os);
-  SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx);
+                             ShouldShowLocFn &&shouldShowLocFn = {});
   ~SourceMgrDiagnosticHandler();
 
   /// Emit the given diagnostic information with the held source manager.
@@ -553,10 +564,18 @@ protected:
   /// The output stream to use when printing diagnostics.
   raw_ostream &os;
 
+  /// A functor used when determining if a location for a diagnostic should be
+  /// shown. If null, all locations should be shown.
+  ShouldShowLocFn shouldShowLocFn;
+
 private:
   /// Convert a location into the given memory buffer into an SMLoc.
   llvm::SMLoc convertLocToSMLoc(FileLineColLoc loc);
 
+  /// Given a location, returns the first nested location (including 'loc') that
+  /// can be shown to the user.
+  Optional<Location> findLocToShow(Location loc);
+
   /// The maximum depth that a call stack will be printed.
   /// TODO: This should be a tunable flag.
   unsigned callStackLimit = 10;
index cb37994..4b4add2 100644 (file)
@@ -16,6 +16,7 @@
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Mutex.h"
 #include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/Regex.h"
@@ -409,17 +410,19 @@ static llvm::SourceMgr::DiagKind getDiagKind(DiagnosticSeverity kind) {
   llvm_unreachable("Unknown DiagnosticSeverity");
 }
 
-SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
-                                                       MLIRContext *ctx,
-                                                       raw_ostream &os)
+SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
+    llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
+    ShouldShowLocFn &&shouldShowLocFn)
     : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
+      shouldShowLocFn(std::move(shouldShowLocFn)),
       impl(new SourceMgrDiagnosticHandlerImpl()) {
   setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
 }
 
-SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
-                                                       MLIRContext *ctx)
-    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs()) {}
+SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
+    llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
+    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
+                                 std::move(shouldShowLocFn)) {}
 
 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() {}
 
@@ -460,17 +463,23 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Location loc, Twine message,
 
 /// Emit the given diagnostic with the held source manager.
 void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
-  // Emit the diagnostic.
+  SmallVector<std::pair<Location, StringRef>> locationStack;
+  auto addLocToStack = [&](Location loc, StringRef locContext) {
+    if (Optional<Location> showableLoc = findLocToShow(loc))
+      locationStack.emplace_back(loc, locContext);
+  };
+
+  // Add locations to display for this diagnostic.
   Location loc = diag.getLocation();
-  emitDiagnostic(loc, diag.str(), diag.getSeverity());
+  addLocToStack(loc, /*locContext=*/{});
 
-  // If the diagnostic location was a call site location, then print the call
-  // stack as well.
+  // If the diagnostic location was a call site location, add the call stack as
+  // well.
   if (auto callLoc = getCallSiteLoc(loc)) {
     // Print the call stack while valid, or until the limit is reached.
     loc = callLoc->getCaller();
     for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
-      emitDiagnostic(loc, "called from", DiagnosticSeverity::Note);
+      addLocToStack(loc, "called from");
       if ((callLoc = getCallSiteLoc(loc)))
         loc = callLoc->getCaller();
       else
@@ -478,6 +487,17 @@ void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
     }
   }
 
+  // If the location stack is empty, use the initial location.
+  if (locationStack.empty()) {
+    emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());
+
+    // Otherwise, use the location stack.
+  } else {
+    emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity());
+    for (auto &it : llvm::drop_begin(locationStack))
+      emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note);
+  }
+
   // Emit each of the notes. Only display the source code if the location is
   // different from the previous location.
   for (auto &note : diag.getNotes()) {
@@ -495,6 +515,41 @@ SourceMgrDiagnosticHandler::getBufferForFile(StringRef filename) {
   return nullptr;
 }
 
+Optional<Location> SourceMgrDiagnosticHandler::findLocToShow(Location loc) {
+  if (!shouldShowLocFn)
+    return loc;
+  if (!shouldShowLocFn(loc))
+    return llvm::None;
+
+  // Recurse into the child locations of some of location types.
+  return TypeSwitch<LocationAttr, Optional<Location>>(loc)
+      .Case([&](CallSiteLoc callLoc) -> Optional<Location> {
+        // We recurse into the callee of a call site, as the caller will be
+        // emitted in a different note on the main diagnostic.
+        return findLocToShow(callLoc.getCallee());
+      })
+      .Case([&](FileLineColLoc) -> Optional<Location> { return loc; })
+      .Case([&](FusedLoc fusedLoc) -> Optional<Location> {
+        // Fused location is unique in that we try to find a sub-location to
+        // show, rather than the top-level location itself.
+        for (Location childLoc : fusedLoc.getLocations())
+          if (Optional<Location> showableLoc = findLocToShow(childLoc))
+            return showableLoc;
+        return llvm::None;
+      })
+      .Case([&](NameLoc nameLoc) -> Optional<Location> {
+        return findLocToShow(nameLoc.getChildLoc());
+      })
+      .Case([&](OpaqueLoc opaqueLoc) -> Optional<Location> {
+        // OpaqueLoc always falls back to a different source location.
+        return findLocToShow(opaqueLoc.getFallbackLocation());
+      })
+      .Case([](UnknownLoc) -> Optional<Location> {
+        // Prefer not to show unknown locations.
+        return llvm::None;
+      });
+}
+
 /// Get a memory buffer for the given file, or the main file of the source
 /// manager if one doesn't exist. This always returns non-null.
 llvm::SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) {
diff --git a/mlir/test/IR/diagnostic-handler-filter.mlir b/mlir/test/IR/diagnostic-handler-filter.mlir
new file mode 100644 (file)
index 0000000..d193c1a
--- /dev/null
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-diagnostic-filter='filters=mysource1' -o - 2>&1 | FileCheck %s
+// This test verifies that diagnostic handler can emit the call stack successfully.
+
+// CHECK-LABEL: Test 'test1'
+// CHECK-NEXT: mysource2:1:0: error: test diagnostic
+// CHECK-NEXT: mysource3:2:0: note: called from
+func private @test1() attributes {
+  test.loc = loc(callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0)))
+}
+
+// CHECK-LABEL: Test 'test2'
+// CHECK-NEXT: mysource1:0:0: error: test diagnostic
+func private @test2() attributes {
+  test.loc = loc("mysource1":0:0)
+}
index fc47158..e809ea1 100644 (file)
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
+  TestDiagnostics.cpp
   TestDominance.cpp
   TestFunc.cpp
   TestInterfaces.cpp
diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp
new file mode 100644 (file)
index 0000000..0021e0d
--- /dev/null
@@ -0,0 +1,65 @@
+//===- TestDiagnostics.cpp - Test Diagnostic Utilities --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for constructing and resolving dominance
+// information.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+
+namespace {
+struct TestDiagnosticFilterPass
+    : public PassWrapper<TestDiagnosticFilterPass, OperationPass<FuncOp>> {
+  TestDiagnosticFilterPass() {}
+  TestDiagnosticFilterPass(const TestDiagnosticFilterPass &) {}
+
+  void runOnOperation() override {
+    llvm::errs() << "Test '" << getOperation().getName() << "'\n";
+
+    // Build a diagnostic handler that has filtering capabilities.
+    auto filterFn = [&](Location loc) {
+      // Ignore non-file locations.
+      FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>();
+      if (!fileLoc)
+        return true;
+
+      // Don't show file locations if their name contains a filter.
+      return llvm::none_of(filters, [&](StringRef filter) {
+        return fileLoc.getFilename().strref().contains(filter);
+      });
+    };
+    llvm::SourceMgr sourceMgr;
+    SourceMgrDiagnosticHandler handler(sourceMgr, &getContext(), llvm::errs(),
+                                       filterFn);
+
+    // Emit a diagnostic for every operation with a valid loc.
+    getOperation()->walk([&](Operation *op) {
+      if (LocationAttr locAttr = op->getAttrOfType<LocationAttr>("test.loc"))
+        emitError(locAttr, "test diagnostic");
+    });
+  }
+
+  ListOption<std::string> filters{
+      *this, "filters", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("Specifies the diagnostic file name filters.")};
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestDiagnosticsPass() {
+  PassRegistration<TestDiagnosticFilterPass>(
+      "test-diagnostic-filter", "Test diagnostic filtering support.");
+}
+} // namespace test
+} // namespace mlir
index c2966e6..c30575b 100644 (file)
@@ -68,6 +68,7 @@ void registerTestGpuSerializeToCubinPass();
 void registerTestGpuSerializeToHsacoPass();
 void registerTestDataLayoutQuery();
 void registerTestDecomposeCallGraphTypes();
+void registerTestDiagnosticsPass();
 void registerTestDialect(DialectRegistry &);
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
@@ -140,6 +141,7 @@ void registerTestPasses() {
   test::registerTestAliasAnalysisPass();
   test::registerTestCallGraphPass();
   test::registerTestConstantFold();
+  test::registerTestDiagnosticsPass();
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   test::registerTestGpuSerializeToCubinPass();
 #endif