NFC: Cleanup FuncVerifier and refactor it into a general OperationVerifier. The funct...
authorRiver Riddle <riverriddle@google.com>
Fri, 7 Jun 2019 16:46:13 +0000 (09:46 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 9 Jun 2019 23:23:23 +0000 (16:23 -0700)
PiperOrigin-RevId: 252065646

mlir/lib/Analysis/Verifier.cpp
mlir/test/IR/invalid.mlir

index 843c499..e1f6f28 100644 (file)
@@ -28,7 +28,7 @@
 // valid form.
 //
 // This should not check for things that are always wrong by construction (e.g.
-// affine maps or other immutable structures that are incorrect), because those
+// attributes or other immutable structures that are incorrect), because those
 // are not mutable and can be checked at time of construction.
 //
 //===----------------------------------------------------------------------===//
 using namespace mlir;
 
 namespace {
-/// This class encapsulates all the state used to verify a function body.  It is
-/// a pervasive truth that this file treats "true" as an error that needs to be
-/// recovered from, and "false" as success.
-///
-class FuncVerifier {
+/// This class encapsulates all the state used to verify an operation region.
+class OperationVerifier {
 public:
-  LogicalResult failure() { return mlir::failure(); }
+  explicit OperationVerifier(MLIRContext *ctx)
+      : ctx(ctx), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
 
-  LogicalResult failure(const Twine &message, Operation &value) {
-    return value.emitError(message);
-  }
-
-  LogicalResult failure(const Twine &message, Function &fn) {
-    return fn.emitError(message);
-  }
-
-  LogicalResult failure(const Twine &message, Block &bb) {
-    // Take the location information for the first operation in the block.
-    if (!bb.empty())
-      return failure(message, bb.front());
-
-    // Worst case, fall back to using the function's location.
-    return failure(message, fn);
-  }
+  /// Verify the body of the given function.
+  LogicalResult verify(Function &fn);
 
   /// Returns the registered dialect for a dialect-specific attribute.
   Dialect *getDialectForAttribute(const NamedAttribute &attr) {
     assert(attr.first.strref().contains('.') && "expected dialect attribute");
     auto dialectNamePair = attr.first.strref().split('.');
-    return fn.getContext()->getRegisteredDialect(dialectNamePair.first);
+    return ctx->getRegisteredDialect(dialectNamePair.first);
   }
 
-  LogicalResult verify();
+  /// Returns if the given string is valid to use as an identifier name.
+  bool isValidName(StringRef name) { return identifierRegex.match(name); }
+
+private:
+  /// Verify the given potentially nested region or block.
+  LogicalResult verifyRegion(Region &region, bool isTopLevel);
   LogicalResult verifyBlock(Block &block, bool isTopLevel);
   LogicalResult verifyOperation(Operation &op);
+
+  /// Verify the dominance within the given IR unit.
+  LogicalResult verifyDominance(Region &region);
   LogicalResult verifyDominance(Block &block);
-  LogicalResult verifyOpDominance(Operation &op);
+  LogicalResult verifyDominance(Operation &op);
 
-  explicit FuncVerifier(Function &fn)
-      : fn(fn), identifierRegex("^[a-zA-Z_][a-zA-Z_0-9\\.\\$]*$") {}
+  /// Emit an error for the given block.
+  InFlightDiagnostic emitError(Block &bb, const Twine &message) {
+    // Take the location information for the first operation in the block.
+    if (!bb.empty())
+      return bb.front().emitError(message);
 
-private:
-  /// The function being checked.
-  Function &fn;
+    // Worst case, fall back to using the parent's location.
+    return ctx->emitError(bb.getParent()->getLoc(), message);
+  }
+
+  /// The current context for the verifier.
+  MLIRContext *ctx;
 
   /// Dominance information for this function, when checking dominance.
   DominanceInfo *domInfo = nullptr;
@@ -103,80 +101,11 @@ private:
 };
 } // end anonymous namespace
 
-LogicalResult FuncVerifier::verify() {
-  llvm::PrettyStackTraceFormat fmt("MLIR Verifier: func @%s",
-                                   fn.getName().c_str());
-
-  // Check that the function name is valid.
-  if (!identifierRegex.match(fn.getName().strref()))
-    return failure("invalid function name '" + fn.getName().strref() + "'", fn);
-
-  /// Verify that all of the attributes are okay.
-  for (auto attr : fn.getAttrs()) {
-    if (!identifierRegex.match(attr.first))
-      return failure("invalid attribute name '" + attr.first.strref() + "'",
-                     fn);
-
-    /// Check that the attribute is a dialect attribute, i.e. contains a '.' for
-    /// the namespace.
-    if (!attr.first.strref().contains('.'))
-      return failure("functions may only have dialect attributes", fn);
-
-    // Verify this attribute with the defining dialect.
-    if (auto *dialect = getDialectForAttribute(attr))
-      if (failed(dialect->verifyFunctionAttribute(&fn, attr)))
-        return failure();
-  }
-
-  /// Verify that all of the argument attributes are okay.
-  for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
-    for (auto attr : fn.getArgAttrs(i)) {
-      if (!identifierRegex.match(attr.first))
-        return failure(
-            llvm::formatv("invalid attribute name '{0}' on argument {1}",
-                          attr.first.strref(), i),
-            fn);
-
-      /// Check that the attribute is a dialect attribute, i.e. contains a '.'
-      /// for the namespace.
-      if (!attr.first.strref().contains('.'))
-        return failure("function arguments may only have dialect attributes",
-                       fn);
-
-      // Verify this attribute with the defining dialect.
-      if (auto *dialect = getDialectForAttribute(attr))
-        if (failed(dialect->verifyFunctionArgAttribute(&fn, i, attr)))
-          return failure();
-    }
-  }
-
-  // External functions have nothing more to check.
-  if (fn.isExternal())
-    return success();
-
-  // Verify the first block has no predecessors.
-  auto *firstBB = &fn.front();
-  if (!firstBB->hasNoPredecessors())
-    return failure("entry block of function may not have predecessors", fn);
-
-  // Verify that the argument list of the function and the arg list of the first
-  // block line up.
-  auto fnInputTypes = fn.getType().getInputs();
-  if (fnInputTypes.size() != firstBB->getNumArguments())
-    return failure("first block of function must have " +
-                       Twine(fnInputTypes.size()) +
-                       " arguments to match function signature",
-                   fn);
-  for (unsigned i = 0, e = firstBB->getNumArguments(); i != e; ++i)
-    if (fnInputTypes[i] != firstBB->getArgument(i)->getType())
-      return failure(
-          "type of argument #" + Twine(i) +
-              " must match corresponding argument in function signature",
-          fn);
-
-  for (auto &block : fn)
-    if (failed(verifyBlock(block, /*isTopLevel=*/true)))
-      return failure();
+/// Verify the body of the given function.
+LogicalResult OperationVerifier::verify(Function &fn) {
+  // Verify the body first.
+  if (failed(verifyRegion(fn.getBody(), /*isTopLevel=*/true)))
+    return failure();
 
   // Since everything looks structurally ok to this point, we do a dominance
   // check.  We do this as a second pass since malformed CFG's can cause
@@ -192,24 +121,38 @@ LogicalResult FuncVerifier::verify() {
   return success();
 }
 
-LogicalResult FuncVerifier::verifyBlock(Block &block, bool isTopLevel) {
-  for (auto *arg : block.getArguments()) {
+LogicalResult OperationVerifier::verifyRegion(Region &region, bool isTopLevel) {
+  if (region.empty())
+    return success();
+
+  // Verify the first block has no predecessors.
+  auto *firstBB = &region.front();
+  if (!firstBB->hasNoPredecessors())
+    return ctx->emitError(region.getLoc(),
+                          "entry block of region may not have predecessors");
+
+  // Verify each of the blocks within the region.
+  for (auto &block : region)
+    if (failed(verifyBlock(block, isTopLevel)))
+      return failure();
+  return success();
+}
+
+LogicalResult OperationVerifier::verifyBlock(Block &block, bool isTopLevel) {
+  for (auto *arg : block.getArguments())
     if (arg->getOwner() != &block)
-      return failure("block argument not owned by block", block);
-  }
+      return emitError(block, "block argument not owned by block");
 
   // Verify that this block has a terminator.
-  if (block.empty()) {
-    return failure("block with no terminator", block);
-  }
+  if (block.empty())
+    return emitError(block, "block with no terminator");
 
   // Verify the non-terminator operations separately so that we can verify
   // they has no successors.
   for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) {
     if (op.getNumSuccessors() != 0)
-      return failure(
-          "operation with block successors must terminate its parent block",
-          op);
+      return op.emitError(
+          "operation with block successors must terminate its parent block");
 
     if (failed(verifyOperation(op)))
       return failure();
@@ -219,36 +162,28 @@ LogicalResult FuncVerifier::verifyBlock(Block &block, bool isTopLevel) {
   if (failed(verifyOperation(block.back())))
     return failure();
   if (block.back().isKnownNonTerminator())
-    return failure("block with no terminator", block);
+    return emitError(block, "block with no terminator");
 
   // Verify that this block is not branching to a block of a different
   // region.
   for (Block *successor : block.getSuccessors())
     if (successor->getParent() != block.getParent())
-      return failure("branching to block of a different region", block.back());
+      return block.back().emitOpError(
+          "branching to block of a different region");
 
   return success();
 }
 
-/// Check the invariants of the specified operation.
-LogicalResult FuncVerifier::verifyOperation(Operation &op) {
-  if (op.getFunction() != &fn)
-    return failure("operation in the wrong function", op);
-
+LogicalResult OperationVerifier::verifyOperation(Operation &op) {
   // Check that operands are non-nil and structurally ok.
-  for (auto *operand : op.getOperands()) {
+  for (auto *operand : op.getOperands())
     if (!operand)
-      return failure("null operand found", op);
-
-    if (operand->getFunction() != &fn)
-      return failure("reference to operand defined in another function", op);
-  }
+      return op.emitError("null operand found");
 
   /// Verify that all of the attributes are okay.
   for (auto attr : op.getAttrs()) {
     if (!identifierRegex.match(attr.first))
-      return failure("invalid attribute name '" + attr.first.strref() + "'",
-                     op);
+      return op.emitError("invalid attribute name '") << attr.first << "'";
 
     // Check for any optional dialect specific attributes.
     if (!attr.first.strref().contains('.'))
@@ -263,11 +198,10 @@ LogicalResult FuncVerifier::verifyOperation(Operation &op) {
   if (opInfo && failed(opInfo->verifyInvariants(&op)))
     return failure();
 
-  // Verify that all child blocks are ok.
+  // Verify that all child regions are ok.
   for (auto &region : op.getRegions())
-    for (auto &b : region)
-      if (failed(verifyBlock(b, /*isTopLevel=*/false)))
-        return failure();
+    if (failed(verifyRegion(region, /*isTopLevel=*/false)))
+      return failure();
 
   // If this is a registered operation, there is nothing left to do.
   if (opInfo)
@@ -280,7 +214,7 @@ LogicalResult FuncVerifier::verifyOperation(Operation &op) {
   auto it = dialectAllowsUnknownOps.find(dialectPrefix);
   if (it == dialectAllowsUnknownOps.end()) {
     // If the operation dialect is registered, query it directly.
-    if (auto *dialect = fn.getContext()->getRegisteredDialect(dialectPrefix))
+    if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
       it = dialectAllowsUnknownOps
                .try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
                .first;
@@ -290,24 +224,23 @@ LogicalResult FuncVerifier::verifyOperation(Operation &op) {
   }
 
   if (!it->second) {
-    return failure("unregistered operation '" + op.getName().getStringRef() +
-                       "' found in dialect ('" + dialectPrefix +
-                       "') that does not allow unknown operations",
-                   op);
+    return op.emitError("unregistered operation '")
+           << op.getName() << "' found in dialect ('" << dialectPrefix
+           << "') that does not allow unknown operations";
   }
 
   return success();
 }
 
-LogicalResult FuncVerifier::verifyDominance(Block &block) {
+LogicalResult OperationVerifier::verifyDominance(Block &block) {
   // Verify the dominance of each of the held operations.
   for (auto &op : block)
-    if (failed(verifyOpDominance(op)))
+    if (failed(verifyDominance(op)))
       return failure();
   return success();
 }
 
-LogicalResult FuncVerifier::verifyOpDominance(Operation &op) {
+LogicalResult OperationVerifier::verifyDominance(Operation &op) {
   // Check that operands properly dominate this use.
   for (unsigned operandNo = 0, e = op.getNumOperands(); operandNo != e;
        ++operandNo) {
@@ -338,7 +271,69 @@ LogicalResult FuncVerifier::verifyOpDominance(Operation &op) {
 /// Perform (potentially expensive) checks of invariants, used to detect
 /// compiler bugs.  On error, this reports the error through the MLIRContext and
 /// returns failure.
-LogicalResult Function::verify() { return FuncVerifier(*this).verify(); }
+LogicalResult Function::verify() {
+  OperationVerifier opVerifier(getContext());
+  llvm::PrettyStackTraceFormat fmt("MLIR Verifier: func @%s",
+                                   getName().c_str());
+
+  // Check that the function name is valid.
+  if (!opVerifier.isValidName(getName().strref()))
+    return emitError("invalid function name '") << getName() << "'";
+
+  /// Verify that all of the attributes are okay.
+  for (auto attr : getAttrs()) {
+    if (!opVerifier.isValidName(attr.first))
+      return emitError("invalid attribute name '") << attr.first << "'";
+
+    /// Check that the attribute is a dialect attribute, i.e. contains a '.' for
+    /// the namespace.
+    if (!attr.first.strref().contains('.'))
+      return emitError("functions may only have dialect attributes");
+
+    // Verify this attribute with the defining dialect.
+    if (auto *dialect = opVerifier.getDialectForAttribute(attr))
+      if (failed(dialect->verifyFunctionAttribute(this, attr)))
+        return failure();
+  }
+
+  /// Verify that all of the argument attributes are okay.
+  for (unsigned i = 0, e = getNumArguments(); i != e; ++i) {
+    for (auto attr : getArgAttrs(i)) {
+      if (!opVerifier.isValidName(attr.first))
+        return emitError("invalid attribute name '")
+               << attr.first << "' on argument " << i;
+
+      /// Check that the attribute is a dialect attribute, i.e. contains a '.'
+      /// for the namespace.
+      if (!attr.first.strref().contains('.'))
+        return emitError("function arguments may only have dialect attributes");
+
+      // Verify this attribute with the defining dialect.
+      if (auto *dialect = opVerifier.getDialectForAttribute(attr))
+        if (failed(dialect->verifyFunctionArgAttribute(this, i, attr)))
+          return failure();
+    }
+  }
+
+  // External functions have nothing more to check.
+  if (isExternal())
+    return success();
+
+  // Verify that the argument list of the function and the arg list of the first
+  // block line up.
+  auto *firstBB = &front();
+  auto fnInputTypes = getType().getInputs();
+  if (fnInputTypes.size() != firstBB->getNumArguments())
+    return emitError("first block of function must have ")
+           << fnInputTypes.size() << " arguments to match function signature";
+  for (unsigned i = 0, e = firstBB->getNumArguments(); i != e; ++i)
+    if (fnInputTypes[i] != firstBB->getArgument(i)->getType())
+      return emitError("type of argument #")
+             << i << " must match corresponding argument in function signature";
+
+  // Finally, verify the body of the function.
+  return opVerifier.verify(*this);
+}
 
 /// Perform (potentially expensive) checks of invariants, used to detect
 /// compiler bugs.  On error, this reports the error through the MLIRContext and
index 9825648..20d7a3b 100644 (file)
@@ -134,7 +134,7 @@ func @block_arg_no_close_paren() {
 // -----
 
 func @block_first_has_predecessor() {
-// expected-error@-1 {{entry block of function may not have predecessors}}
+// expected-error@-1 {{entry block of region may not have predecessors}}
 ^bb42:
   br ^bb43
 ^bb43: