[mlir] Fix missing verification after running an OpToOpAdaptorPass
authorRiver Riddle <riddleriver@gmail.com>
Wed, 16 Mar 2022 18:45:14 +0000 (11:45 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 16 Mar 2022 21:53:41 +0000 (14:53 -0700)
The current decision of when to run the verifier is running on the
assumption that nested passes can't affect the validity of the parent
operation, which isn't true. Parent operations may attach any number
of constraints on nested operations, which may not necessarily be
captured (or shouldn't be captured) at a smaller granularity.

This commit rectifies this by properly running the verifier after an
OpToOpAdaptor pass. To avoid an explosive increase in compile time,
we only run verification on the parent operation itself. To do this, a
flag to mlir::verify is added to avoid recursive verification if it isn't
desired.

Fixes #54288

Differential Revision: https://reviews.llvm.org/D121836

mlir/include/mlir/IR/Verifier.h
mlir/lib/IR/Verifier.cpp
mlir/lib/Pass/Pass.cpp
mlir/test/Pass/invalid-parent.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Pass/CMakeLists.txt
mlir/test/lib/Pass/TestPassManager.cpp

index 3e1a0f8..1fcc99e 100644 (file)
@@ -15,8 +15,12 @@ class Operation;
 
 /// Perform (potentially expensive) checks of invariants, used to detect
 /// compiler bugs, on this operation and any nested operations. On error, this
-/// reports the error through the MLIRContext and returns failure.
-LogicalResult verify(Operation *op);
+/// reports the error through the MLIRContext and returns failure. If
+/// `verifyRecursively` is false, this assumes that nested operations have
+/// already been properly verified, and does not recursively invoke the verifier
+/// on nested operations.
+LogicalResult verify(Operation *op, bool verifyRecursively = true);
+
 } // namespace mlir
 
 #endif
index 0c8724d..62212db 100644 (file)
@@ -43,6 +43,11 @@ namespace {
 /// This class encapsulates all the state used to verify an operation region.
 class OperationVerifier {
 public:
+  /// If `verifyRecursively` is true, then this will also recursively verify
+  /// nested operations.
+  explicit OperationVerifier(bool verifyRecursively)
+      : verifyRecursively(verifyRecursively) {}
+
   /// Verify the given operation.
   LogicalResult verifyOpAndDominance(Operation &op);
 
@@ -61,6 +66,10 @@ private:
   /// Operation.
   LogicalResult verifyDominanceOfContainedRegions(Operation &op,
                                                   DominanceInfo &domInfo);
+
+  /// A flag indicating if this verifier should recursively verify nested
+  /// operations.
+  bool verifyRecursively;
 };
 } // namespace
 
@@ -81,8 +90,12 @@ LogicalResult OperationVerifier::verifyOpAndDominance(Operation &op) {
       return failure();
   }
 
-  // Check the dominance properties and invariants of any operations in the
-  // regions contained by the 'opsWithIsolatedRegions' operations.
+  // If we aren't verifying nested operations, then we're done.
+  if (!verifyRecursively)
+    return success();
+
+  // Otherwise, check the dominance properties and invariants of any operations
+  // in the regions contained by the 'opsWithIsolatedRegions' operations.
   return failableParallelForEach(
       op.getContext(), opsWithIsolatedRegions,
       [&](Operation *op) { return verifyOpAndDominance(*op); });
@@ -120,21 +133,25 @@ LogicalResult OperationVerifier::verifyBlock(
 
   // Check each operation, and make sure there are no branches out of the
   // middle of this block.
-  for (auto &op : block) {
+  for (Operation &op : block) {
     // Only the last instructions is allowed to have successors.
     if (op.getNumSuccessors() != 0 && &op != &block.back())
       return op.emitError(
           "operation with block successors must terminate its parent block");
 
+    // If we aren't verifying recursievly, there is nothing left to check.
+    if (!verifyRecursively)
+      continue;
+
     // If this operation has regions and is IsolatedFromAbove, we defer
     // checking.  This allows us to parallelize verification better.
     if (op.getNumRegions() != 0 &&
         op.hasTrait<OpTrait::IsIsolatedFromAbove>()) {
       opsWithIsolatedRegions.push_back(&op);
-    } else {
+
       // Otherwise, check the operation inline.
-      if (failed(verifyOperation(op, opsWithIsolatedRegions)))
-        return failure();
+    } else if (failed(verifyOperation(op, opsWithIsolatedRegions))) {
+      return failure();
     }
   }
 
@@ -185,8 +202,9 @@ LogicalResult OperationVerifier::verifyOperation(
     auto kindInterface = dyn_cast<RegionKindInterface>(op);
 
     // Verify that all child regions are ok.
+    MutableArrayRef<Region> regions = op.getRegions();
     for (unsigned i = 0; i < numRegions; ++i) {
-      Region &region = op.getRegion(i);
+      Region &region = regions[i];
       RegionKind kind =
           kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG;
       // Check that Graph Regions only have a single basic block. This is
@@ -210,10 +228,13 @@ LogicalResult OperationVerifier::verifyOperation(
         return emitError(op.getLoc(),
                          "entry block of region may not have predecessors");
 
-      // Verify each of the blocks within the region.
-      for (Block &block : region)
-        if (failed(verifyBlock(block, opsWithIsolatedRegions)))
-          return failure();
+      // Verify each of the blocks within the region if we are verifying
+      // recursively.
+      if (verifyRecursively) {
+        for (Block &block : region)
+          if (failed(verifyBlock(block, opsWithIsolatedRegions)))
+            return failure();
+      }
     }
   }
 
@@ -330,10 +351,10 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
           }
         }
 
-        // Recursively verify dominance within each operation in the
-        // block, even if the block itself is not reachable, or we are in
-        // a region which doesn't respect dominance.
-        if (op.getNumRegions() != 0) {
+        // Recursively verify dominance within each operation in the block, even
+        // if the block itself is not reachable, or we are in a region which
+        // doesn't respect dominance.
+        if (verifyRecursively && op.getNumRegions() != 0) {
           // If this operation is IsolatedFromAbove, then we'll handle it in the
           // outer verification loop.
           if (op.hasTrait<OpTrait::IsIsolatedFromAbove>())
@@ -352,9 +373,7 @@ OperationVerifier::verifyDominanceOfContainedRegions(Operation &op,
 // Entrypoint
 //===----------------------------------------------------------------------===//
 
-/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs.  On error, this reports the error through the MLIRContext and
-/// returns failure.
-LogicalResult mlir::verify(Operation *op) {
-  return OperationVerifier().verifyOpAndDominance(*op);
+LogicalResult mlir::verify(Operation *op, bool verifyRecursively) {
+  OperationVerifier verifier(verifyRecursively);
+  return verifier.verifyOpAndDominance(*op);
 }
index 22a9641..7256f44 100644 (file)
@@ -408,22 +408,24 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
   // failed).
   if (!passFailed && verifyPasses) {
     bool runVerifierNow = true;
+
+    // If the pass is an adaptor pass, we don't run the verifier recursively
+    // because the nested operations should have already been verified after
+    // nested passes had run.
+    bool runVerifierRecursively = !isa<OpToOpPassAdaptor>(pass);
+
     // Reduce compile time by avoiding running the verifier if the pass didn't
     // change the IR since the last time the verifier was run:
     //
     //  1) If the pass said that it preserved all analyses then it can't have
     //     permuted the IR.
-    //  2) If we just ran an OpToOpPassAdaptor (e.g. to run function passes
-    //     within a module) then each sub-unit will have been verified on the
-    //     subunit (and those passes aren't allowed to modify the parent).
     //
     // We run these checks in EXPENSIVE_CHECKS mode out of caution.
 #ifndef EXPENSIVE_CHECKS
-    runVerifierNow = !isa<OpToOpPassAdaptor>(pass) &&
-                     !pass->passState->preservedAnalyses.isAll();
+    runVerifierNow = !pass->passState->preservedAnalyses.isAll();
 #endif
     if (runVerifierNow)
-      passFailed = failed(verify(op));
+      passFailed = failed(verify(op, runVerifierRecursively));
   }
 
   // Instrument after the pass has run.
diff --git a/mlir/test/Pass/invalid-parent.mlir b/mlir/test/Pass/invalid-parent.mlir
new file mode 100644 (file)
index 0000000..2979ba9
--- /dev/null
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.func(test-pass-invalid-parent)' -verify-diagnostics
+
+// Test that we properly report errors when the parent becomes invalid after running a pass
+// on a child operation.
+// expected-error@below {{'some_unknown_func' does not reference a valid function}}
+func @TestCreateInvalidCallInPass() {
+  return
+}
index e0c5eea..c530582 100644 (file)
@@ -359,6 +359,21 @@ void TestDialect::getCanonicalizationPatterns(
 }
 
 //===----------------------------------------------------------------------===//
+// TestCallOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Check that the callee attribute was specified.
+  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr)
+    return emitOpError("requires a 'callee' symbol reference attribute");
+  if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
+    return emitOpError() << "'" << fnAttr.getValue()
+                         << "' does not reference a valid function";
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // TestFoldToCallOp
 //===----------------------------------------------------------------------===//
 
index da1aa2f..3f3f812 100644 (file)
@@ -375,6 +375,14 @@ def VariadicNoTerminatorOp : TEST_Op<"variadic_no_terminator_op",
 // Test Call Interfaces
 //===----------------------------------------------------------------------===//
 
+def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
+  let results = (outs Variadic<AnyType>);
+  let assemblyFormat = [{
+    $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
+  }];
+}
+
 def ConversionCallOp : TEST_Op<"conversion_call_op",
     [CallOpInterface]> {
   let arguments = (ins Variadic<AnyType>:$arg_operands, SymbolRefAttr:$callee);
index 061d40f..dd90c22 100644 (file)
@@ -11,4 +11,11 @@ add_mlir_library(MLIRTestPass
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRPass
+  MLIRTestDialect
+  )
+
+target_include_directories(MLIRTestPass
+  PRIVATE
+  ${CMAKE_CURRENT_SOURCE_DIR}/../Dialect/Test
+  ${CMAKE_CURRENT_BINARY_DIR}/../Dialect/Test
   )
index 5a759f3..85dc7bf 100644 (file)
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -98,6 +99,27 @@ class TestFailurePass : public PassWrapper<TestFailurePass, OperationPass<>> {
   }
 };
 
+/// A test pass that always fails to enable testing the failure recovery
+/// mechanisms of the pass manager.
+class TestInvalidParentPass
+    : public PassWrapper<TestInvalidParentPass,
+                         InterfacePass<FunctionOpInterface>> {
+  StringRef getArgument() const final { return "test-pass-invalid-parent"; }
+  StringRef getDescription() const final {
+    return "Test a pass in the pass manager that makes the parent operation "
+           "invalid";
+  }
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<test::TestDialect>();
+  }
+  void runOnOperation() final {
+    FunctionOpInterface op = getOperation();
+    OpBuilder b(getOperation().getBody());
+    b.create<test::TestCallOp>(op.getLoc(), TypeRange(), "some_unknown_func",
+                               ValueRange());
+  }
+};
+
 /// A test pass that contains a statistic.
 struct TestStatisticPass
     : public PassWrapper<TestStatisticPass, OperationPass<>> {
@@ -144,6 +166,7 @@ void registerPassManagerTestPass() {
 
   PassRegistration<TestCrashRecoveryPass>();
   PassRegistration<TestFailurePass>();
+  PassRegistration<TestInvalidParentPass>();
 
   PassRegistration<TestStatisticPass>();