Add an option to print an operation if a diagnostic is emitted on it
authorRiver Riddle <riverriddle@google.com>
Tue, 12 Nov 2019 19:57:47 +0000 (11:57 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 12 Nov 2019 19:59:19 +0000 (11:59 -0800)
It is often helpful to inspect the operation that the error/warning/remark/etc. originated from, especially in the context of debugging or in the case of a verifier failure. This change adds an option 'mlir-print-op-on-diagnostic' that attaches the operation as a note to any diagnostic that is emitted on it via Operation::emit(Error|Warning|Remark). In the case of an error, the operation is printed in the generic form.

PiperOrigin-RevId: 280021438

mlir/lib/IR/Operation.cpp
mlir/test/IR/print-op-on-diagnostic.mlir [new file with mode: 0644]
mlir/test/lib/TestDialect/TestDialect.cpp
mlir/test/lib/TestDialect/TestDialect.h

index 96c488c..aa13a71 100644 (file)
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/Support/CommandLine.h"
 #include <numeric>
 
 using namespace mlir;
 
+static llvm::cl::opt<bool> printOpOnDiagnostic(
+    "mlir-print-op-on-diagnostic",
+    llvm::cl::desc("When a diagnostic is emitted on an operation, also print "
+                   "the operation as an attached note"));
+
+OpAsmParser::~OpAsmParser() {}
+
+//===----------------------------------------------------------------------===//
+// OperationName
+//===----------------------------------------------------------------------===//
+
 /// Form the OperationName for an op with the specified string.  This either is
 /// a reference to an AbstractOperation if one is known, or a uniqued Identifier
 /// if not.
@@ -60,8 +72,6 @@ OperationName OperationName::getFromOpaquePointer(void *pointer) {
   return OperationName(RepresentationUnion::getFromOpaqueValue(pointer));
 }
 
-OpAsmParser::~OpAsmParser() {}
-
 //===----------------------------------------------------------------------===//
 // OpResult
 //===----------------------------------------------------------------------===//
@@ -301,27 +311,51 @@ void Operation::replaceUsesOfWith(Value *from, Value *to) {
 }
 
 //===----------------------------------------------------------------------===//
-// Other
+// Diagnostics
 //===----------------------------------------------------------------------===//
 
 /// Emit an error about fatal conditions with this operation, reporting up to
 /// any diagnostic handlers that may be listening.
 InFlightDiagnostic Operation::emitError(const Twine &message) {
-  return mlir::emitError(getLoc(), message);
+  InFlightDiagnostic diag = mlir::emitError(getLoc(), message);
+  if (printOpOnDiagnostic) {
+    // Print out the operation explicitly here so that we can print the generic
+    // form.
+    // TODO(riverriddle) It would be nice if we could instead provide the
+    // specific printing flags when adding the operation as an argument to the
+    // diagnostic.
+    std::string printedOp;
+    {
+      llvm::raw_string_ostream os(printedOp);
+      print(os, OpPrintingFlags().printGenericOpForm().useLocalScope());
+    }
+    diag.attachNote(getLoc()) << "see current operation: " << printedOp;
+  }
+  return diag;
 }
 
 /// Emit a warning about this operation, reporting up to any diagnostic
 /// handlers that may be listening.
 InFlightDiagnostic Operation::emitWarning(const Twine &message) {
-  return mlir::emitWarning(getLoc(), message);
+  InFlightDiagnostic diag = mlir::emitWarning(getLoc(), message);
+  if (printOpOnDiagnostic)
+    diag.attachNote(getLoc()) << "see current operation: " << *this;
+  return diag;
 }
 
 /// Emit a remark about this operation, reporting up to any diagnostic
 /// handlers that may be listening.
 InFlightDiagnostic Operation::emitRemark(const Twine &message) {
-  return mlir::emitRemark(getLoc(), message);
+  InFlightDiagnostic diag = mlir::emitRemark(getLoc(), message);
+  if (printOpOnDiagnostic)
+    diag.attachNote(getLoc()) << "see current operation: " << *this;
+  return diag;
 }
 
+//===----------------------------------------------------------------------===//
+// Other
+//===----------------------------------------------------------------------===//
+
 /// Given an operation 'other' that is within the same parent block, return
 /// whether the current operation is before 'other' in the operation list
 /// of the parent block.
diff --git a/mlir/test/IR/print-op-on-diagnostic.mlir b/mlir/test/IR/print-op-on-diagnostic.mlir
new file mode 100644 (file)
index 0000000..439aa40
--- /dev/null
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s -verify-diagnostics -mlir-print-op-on-diagnostic
+
+// This file tests the functionality of 'mlir-print-op-on-diagnostic'.
+
+// expected-error@below {{invalid to use 'test.invalid_attr'}}
+// expected-note@below {{see current operation: "module"()}}
+module attributes {test.invalid_attr} {}
index a6ec6ad..0178043 100644 (file)
@@ -116,6 +116,13 @@ TestDialect::TestDialect(MLIRContext *context)
   allowUnknownOperations();
 }
 
+LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
+                                                    NamedAttribute namedAttr) {
+  if (namedAttr.first == "test.invalid_attr")
+    return op->emitError() << "invalid to use 'test.invalid_attr'";
+  return success();
+}
+
 LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                     unsigned regionIndex,
                                                     unsigned argIndex,
index 5be1c10..f10b984 100644 (file)
@@ -42,13 +42,14 @@ public:
   /// Get the canonical string name of the dialect.
   static StringRef getDialectName() { return "test"; }
 
-  LogicalResult verifyRegionArgAttribute(Operation *, unsigned regionIndex,
+  LogicalResult verifyOperationAttribute(Operation *op,
+                                         NamedAttribute namedAttr) override;
+  LogicalResult verifyRegionArgAttribute(Operation *op, unsigned regionIndex,
                                          unsigned argIndex,
-                                         NamedAttribute) override;
-
-  LogicalResult verifyRegionResultAttribute(Operation *, unsigned regionIndex,
+                                         NamedAttribute namedAttr) override;
+  LogicalResult verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                             unsigned resultIndex,
-                                            NamedAttribute) override;
+                                            NamedAttribute namedAttr) override;
 };
 
 #define GET_OP_CLASSES