[mlir] Fix dumping invalid ops
authorSergei Grechanik <sergei.grechanik@intel.com>
Mon, 7 Mar 2022 15:49:46 +0000 (07:49 -0800)
committerSergei Grechanik <sergei.grechanik@intel.com>
Mon, 7 Mar 2022 16:32:31 +0000 (08:32 -0800)
This patch fixes the crash when printing some ops (like affine.for and
scf.for) when they are dumped in invalid state, e.g. during pattern
application. Now the AsmState constructor verifies the operation
first and switches to generic operation printing when the verification
fails. Also operations are now printed in generic form when emitting
diagnostics and the severity level is Error.

Reviewed By: rriddle, mehdi_amini

Differential Revision: https://reviews.llvm.org/D117834

mlir/docs/Diagnostics.md
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/IR/Value.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Diagnostics.cpp
mlir/lib/IR/Operation.cpp
mlir/test/IR/print-ir-invalid.mlir [new file with mode: 0644]
mlir/test/lib/IR/CMakeLists.txt
mlir/test/lib/IR/TestPrintInvalid.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/mlir-opt.cpp

index 221f11a..a41b98e 100644 (file)
@@ -107,6 +107,18 @@ op->emitError() << "Compose an interesting error: " << fooAttr << ", " << fooTyp
 "Compose an interesting error: @foo, i32, (0, 1, 2)"
 ```
 
+Operations attached to a diagnostic will be printed in generic form if the
+severity level is `Error`, otherwise custom operation printers will be used.
+```c++
+// `anotherOp` will be printed in generic form,
+// e.g. %3 = "arith.addf"(%arg4, %2) : (f32, f32) -> f32
+op->emitError() << anotherOp;
+
+// `anotherOp` will be printed using the custom printer,
+// e.g. %3 = arith.addf %arg4, %2 : f32
+op->emitRemark() << anotherOp;
+```
+
 ### Attaching notes
 
 Unlike many other compiler frameworks, notes in MLIR cannot be emitted directly.
index 87c33b5..dc25ef0 100644 (file)
@@ -601,6 +601,15 @@ Note that the second phase will be run after the operations in the region are
 verified. Verifiers further down the order can rely on certain invariants being
 verified by a previous verifier and do not need to re-verify them.
 
+#### Emitting diagnostics in custom verifiers
+
+Custom verifiers should avoid printing operations using custom operation
+printers, because they require the printed operation (and sometimes its parent
+operation) to be verified first. In particular, when emitting diagnostics,
+custom verifiers should use the `Error` severity level, which prints operations
+in generic form by default, and avoid using lower severity levels (`Note`,
+`Remark`, `Warning`).
+
 ### Declarative Assembly Format
 
 The custom assembly form of the operation may be specified in a declarative
index f72c244..3707747 100644 (file)
@@ -726,6 +726,9 @@ public:
   /// Always print operations in the generic form.
   OpPrintingFlags &printGenericOpForm();
 
+  /// Do not verify the operation when using custom operation printers.
+  OpPrintingFlags &assumeVerified();
+
   /// Use local scope when printing the operation. This allows for using the
   /// printer in a more localized and thread-safe setting, but may not
   /// necessarily be identical to what the IR will look like when dumping
@@ -747,6 +750,9 @@ public:
   /// Return if operations should be printed in the generic form.
   bool shouldPrintGenericOpForm() const;
 
+  /// Return if operation verification should be skipped.
+  bool shouldAssumeVerified() const;
+
   /// Return if the printer should use local scope when dumping the IR.
   bool shouldUseLocalScope() const;
 
@@ -762,6 +768,9 @@ private:
   /// Print operations in the generic form.
   bool printGenericOpFormFlag : 1;
 
+  /// Skip operation verification.
+  bool assumeVerifiedFlag : 1;
+
   /// Print operations with numberings local to the current operation.
   bool printLocalScope : 1;
 };
index 65f5b8c..2f0524c 100644 (file)
@@ -24,6 +24,7 @@ class Block;
 class BlockArgument;
 class Operation;
 class OpOperand;
+class OpPrintingFlags;
 class OpResult;
 class Region;
 class Value;
@@ -215,6 +216,7 @@ public:
   // Utilities
 
   void print(raw_ostream &os);
+  void print(raw_ostream &os, const OpPrintingFlags &flags);
   void print(raw_ostream &os, AsmState &state);
   void dump();
 
index 3fc07f3..980c886 100644 (file)
@@ -25,6 +25,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SubElementInterfaces.h"
+#include "mlir/IR/Verifier.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
@@ -40,6 +41,7 @@
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/SaveAndRestore.h"
+#include "llvm/Support/Threading.h"
 
 #include <tuple>
 
@@ -141,6 +143,11 @@ struct AsmPrinterOptions {
       "mlir-print-op-generic", llvm::cl::init(false),
       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
 
+  llvm::cl::opt<bool> assumeVerifiedOpt{
+      "mlir-print-assume-verified", llvm::cl::init(false),
+      llvm::cl::desc("Skip op verification when using custom printers"),
+      llvm::cl::Hidden};
+
   llvm::cl::opt<bool> printLocalScopeOpt{
       "mlir-print-local-scope", llvm::cl::init(false),
       llvm::cl::desc("Print with local scope and inline information (eliding "
@@ -160,7 +167,8 @@ void mlir::registerAsmPrinterCLOptions() {
 /// Initialize the printing flags with default supplied by the cl::opts above.
 OpPrintingFlags::OpPrintingFlags()
     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
-      printGenericOpFormFlag(false), printLocalScope(false) {
+      printGenericOpFormFlag(false), assumeVerifiedFlag(false),
+      printLocalScope(false) {
   // Initialize based upon command line options, if they are available.
   if (!clOptions.isConstructed())
     return;
@@ -169,6 +177,7 @@ OpPrintingFlags::OpPrintingFlags()
   printDebugInfoFlag = clOptions->printDebugInfoOpt;
   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
+  assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
   printLocalScope = clOptions->printLocalScopeOpt;
 }
 
@@ -196,6 +205,12 @@ OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
   return *this;
 }
 
+/// Do not verify the operation when using custom operation printers.
+OpPrintingFlags &OpPrintingFlags::assumeVerified() {
+  assumeVerifiedFlag = true;
+  return *this;
+}
+
 /// Use local scope when printing the operation. This allows for using the
 /// printer in a more localized and thread-safe setting, but may not necessarily
 /// be identical of what the IR will look like when dumping the full module.
@@ -231,6 +246,11 @@ bool OpPrintingFlags::shouldPrintGenericOpForm() const {
   return printGenericOpFormFlag;
 }
 
+/// Return if operation verification should be skipped.
+bool OpPrintingFlags::shouldAssumeVerified() const {
+  return assumeVerifiedFlag;
+}
+
 /// Return if the printer should use local scope when dumping the IR.
 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
 
@@ -1245,9 +1265,31 @@ private:
 } // namespace detail
 } // namespace mlir
 
+/// Verifies the operation and switches to generic op printing if verification
+/// fails. We need to do this because custom print functions may fail for
+/// invalid ops.
+static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
+                                              OpPrintingFlags printerFlags) {
+  if (printerFlags.shouldPrintGenericOpForm() ||
+      printerFlags.shouldAssumeVerified())
+    return printerFlags;
+
+  // Ignore errors emitted by the verifier. We check the thread id to avoid
+  // consuming other threads' errors.
+  auto parentThreadId = llvm::get_threadid();
+  ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &) {
+    return success(parentThreadId == llvm::get_threadid());
+  });
+  if (failed(verify(op)))
+    printerFlags.printGenericOpForm();
+
+  return printerFlags;
+}
+
 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                    LocationMap *locationMap)
-    : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
+    : impl(std::make_unique<AsmStateImpl>(
+          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
 AsmState::~AsmState() = default;
 
 const OpPrintingFlags &AsmState::getPrinterFlags() const {
@@ -2853,14 +2895,15 @@ void IntegerSet::print(raw_ostream &os) const {
   AsmPrinter::Impl(os).printIntegerSet(*this);
 }
 
-void Value::print(raw_ostream &os) {
+void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }
+void Value::print(raw_ostream &os, const OpPrintingFlags &flags) {
   if (!impl) {
     os << "<<NULL VALUE>>";
     return;
   }
 
   if (auto *op = getDefiningOp())
-    return op->print(os);
+    return op->print(os, flags);
   // TODO: Improve BlockArgument print'ing.
   BlockArgument arg = this->cast<BlockArgument>();
   os << "<block argument> of type '" << arg.getType()
index ea0ff5a..975f694 100644 (file)
@@ -121,6 +121,17 @@ Diagnostic &Diagnostic::operator<<(OperationName val) {
   return *this;
 }
 
+/// Adjusts operation printing flags used in diagnostics for the given severity
+/// level.
+static OpPrintingFlags adjustPrintingFlags(OpPrintingFlags flags,
+                                           DiagnosticSeverity severity) {
+  flags.useLocalScope();
+  flags.elideLargeElementsAttrs();
+  if (severity == DiagnosticSeverity::Error)
+    flags.printGenericOpForm();
+  return flags;
+}
+
 /// Stream in an Operation.
 Diagnostic &Diagnostic::operator<<(Operation &val) {
   return appendOp(val, OpPrintingFlags());
@@ -128,8 +139,7 @@ Diagnostic &Diagnostic::operator<<(Operation &val) {
 Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
   std::string str;
   llvm::raw_string_ostream os(str);
-  val.print(os,
-            OpPrintingFlags(flags).useLocalScope().elideLargeElementsAttrs());
+  val.print(os, adjustPrintingFlags(flags, severity));
   return *this << os.str();
 }
 
@@ -137,7 +147,7 @@ Diagnostic &Diagnostic::appendOp(Operation &val, const OpPrintingFlags &flags) {
 Diagnostic &Diagnostic::operator<<(Value val) {
   std::string str;
   llvm::raw_string_ostream os(str);
-  val.print(os);
+  val.print(os, adjustPrintingFlags(OpPrintingFlags(), severity));
   return *this << os.str();
 }
 
index 89a6b27..ea68f11 100644 (file)
@@ -1097,6 +1097,8 @@ LogicalResult OpTrait::impl::verifyIsIsolatedFromAbove(Operation *isolatedOp) {
           // Check that any value that is used by an operation is defined in the
           // same region as either an operation result.
           auto *operandRegion = operand.getParentRegion();
+          if (!operandRegion)
+            return op.emitError("operation's operand is unlinked");
           if (!region.isAncestor(operandRegion)) {
             return op.emitOpError("using value defined outside the region")
                        .attachNote(isolatedOp->getLoc())
diff --git a/mlir/test/IR/print-ir-invalid.mlir b/mlir/test/IR/print-ir-invalid.mlir
new file mode 100644 (file)
index 0000000..e833534
--- /dev/null
@@ -0,0 +1,33 @@
+// # RUN: mlir-opt -test-print-invalid %s | FileCheck %s
+// # RUN: mlir-opt -test-print-invalid %s --mlir-print-assume-verified  | FileCheck %s --check-prefix=ASSUME-VERIFIED
+
+// The pass creates some ops and prints them to stdout, the input is just an
+// empty module.
+module {}
+
+// The operation is invalid because the body does not have a terminator, print
+// the generic form.
+// CHECK:      Invalid operation:
+// CHECK-NEXT: "builtin.func"() ({
+// CHECK-NEXT: ^bb0:
+// CHECK-NEXT: })
+// CHECK-SAME: sym_name = "test"
+
+// The operation is valid because the body has a terminator, print the custom
+// form.
+// CHECK:      Valid operation:
+// CHECK-NEXT: func @test() {
+// CHECK-NEXT:   return
+// CHECK-NEXT: }
+
+// With --mlir-print-assume-verified the custom form is printed in both cases.
+// This works in this particular case, but may crash in general.
+
+// ASSUME-VERIFIED:      Invalid operation:
+// ASSUME-VERIFIED-NEXT: func @test() {
+// ASSUME-VERIFIED-NEXT: }
+
+// ASSUME-VERIFIED:      Valid operation:
+// ASSUME-VERIFIED-NEXT: func @test() {
+// ASSUME-VERIFIED-NEXT:   return
+// ASSUME-VERIFIED-NEXT: }
index f656a4e..a195817 100644 (file)
@@ -9,6 +9,7 @@ add_mlir_library(MLIRTestIR
   TestOpaqueLoc.cpp
   TestOperationEquals.cpp
   TestPrintDefUse.cpp
+  TestPrintInvalid.cpp
   TestPrintNesting.cpp
   TestSideEffects.cpp
   TestSlicing.cpp
diff --git a/mlir/test/lib/IR/TestPrintInvalid.cpp b/mlir/test/lib/IR/TestPrintInvalid.cpp
new file mode 100644 (file)
index 0000000..537af8b
--- /dev/null
@@ -0,0 +1,52 @@
+//===- TestPrintInvalid.cpp - Test printing invalid ops -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass creates and prints to the standard output an invalid operation and
+// a valid operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+struct TestPrintInvalidPass
+    : public PassWrapper<TestPrintInvalidPass, OperationPass<ModuleOp>> {
+  StringRef getArgument() const final { return "test-print-invalid"; }
+  StringRef getDescription() const final {
+    return "Test printing invalid ops.";
+  }
+  void getDependentDialects(DialectRegistry &registry) const {
+    registry.insert<func::FuncDialect>();
+  }
+
+  void runOnOperation() override {
+    Location loc = getOperation().getLoc();
+    OpBuilder builder(getOperation().body());
+    auto funcOp = builder.create<FuncOp>(
+        loc, "test", FunctionType::get(getOperation().getContext(), {}, {}));
+    funcOp.addEntryBlock();
+    // The created function is invalid because there is no return op.
+    llvm::outs() << "Invalid operation:\n" << funcOp << "\n";
+    builder.setInsertionPointToEnd(&funcOp.getBody().front());
+    builder.create<func::ReturnOp>(loc);
+    // Now this function is valid.
+    llvm::outs() << "Valid operation:\n" << funcOp << "\n";
+    funcOp.erase();
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerTestPrintInvalidPass() {
+  PassRegistration<TestPrintInvalidPass>{};
+}
+} // namespace mlir
index 7336980..9c64317 100644 (file)
@@ -45,6 +45,7 @@ void registerTestLoopPermutationPass();
 void registerTestMatchers();
 void registerTestOperationEqualPass();
 void registerTestPrintDefUsePass();
+void registerTestPrintInvalidPass();
 void registerTestPrintNestingPass();
 void registerTestReducer();
 void registerTestSpirvEntryPointABIPass();
@@ -132,6 +133,7 @@ void registerTestPasses() {
   registerTestMatchers();
   registerTestOperationEqualPass();
   registerTestPrintDefUsePass();
+  registerTestPrintInvalidPass();
   registerTestPrintNestingPass();
   registerTestReducer();
   registerTestSpirvEntryPointABIPass();