[mlir] Support verification order (2/3)
authorChia-hung Duan <chiahungduan@google.com>
Fri, 25 Feb 2022 18:17:30 +0000 (18:17 +0000)
committerChia-hung Duan <chiahungduan@google.com>
Fri, 25 Feb 2022 19:04:56 +0000 (19:04 +0000)
    This change gives explicit order of verifier execution and adds
    `hasRegionVerifier` and `verifyWithRegions` to increase the granularity
    of verifier classification. The orders are as below,

    1. InternalOpTrait will be verified first, they can be run independently.
    2. `verifyInvariants` which is constructed by ODS, it verifies the type,
       attributes, .etc.
    3. Other Traits/Interfaces that have marked their verifier as
       `verifyTrait` or `verifyWithRegions=0`.
    4. Custom verifier which is defined in the op and has marked
       `hasVerifier=1`

    If an operation has regions, then it may have the second phase,

    5. Traits/Interfaces that have marked their verifier as
       `verifyRegionTrait` or
       `verifyWithRegions=1`. This implies the verifier needs to access the
       operations in its regions.
    6. Custom verifier which is defined in the op and has marked
       `hasRegionVerifier=1`

    Note that the second phase will be run after the operations in the
    region are verified. Based on the verification order, you will be able to
    avoid verifying duplicate things.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116789

28 files changed:
mlir/docs/OpDefinitions.md
mlir/docs/Traits.md
mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/OperationSupport.h
mlir/include/mlir/TableGen/Interfaces.h
mlir/include/mlir/TableGen/Trait.h
mlir/lib/Dialect/Affine/IR/AffineOps.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Verifier.cpp
mlir/lib/TableGen/Interfaces.cpp
mlir/lib/TableGen/Trait.cpp
mlir/test/Dialect/Arithmetic/invalid.mlir
mlir/test/Dialect/GPU/invalid.mlir
mlir/test/Dialect/LLVMIR/global.mlir
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/named-ops.mlir
mlir/test/Dialect/SPIRV/IR/bit-ops.mlir
mlir/test/Dialect/Shape/invalid.mlir
mlir/test/Dialect/traits.mlir
mlir/test/IR/invalid-module-op.mlir
mlir/test/IR/traits.mlir
mlir/test/mlir-tblgen/op-decl-and-defs.td
mlir/test/mlir-tblgen/op-interface.td
mlir/test/mlir-tblgen/types.mlir
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

index e9aa37f..09a3ca5 100644 (file)
@@ -567,10 +567,39 @@ _additional_ verification, you can use
 let hasVerifier = 1;
 ```
 
-This will generate a `LogicalResult verify()` method declaration on the op class
-that can be defined with any additional verification constraints. This method
-will be invoked after the auto-generated verification code. The order of trait
-verification excluding those of `hasVerifier` should not be relied upon.
+or
+
+```tablegen
+let hasRegionVerifier = 1;
+```
+
+This will generate either `LogicalResult verify()` or
+`LogicalResult verifyRegions()` method declaration on the op class
+that can be defined with any additional verification constraints. These method
+will be invoked on its verification order.
+
+#### Verification Ordering
+
+The verification of an operation involves several steps,
+
+1. StructuralOpTrait will be verified first, they can be run independently.
+1. `verifyInvariants` which is constructed by ODS, it verifies the type,
+   attributes, .etc.
+1. Other Traits/Interfaces that have marked their verifier as `verifyTrait` or
+   `verifyWithRegions=0`.
+1. Custom verifier which is defined in the op and has marked `hasVerifier=1`
+
+If an operation has regions, then it may have the second phase,
+
+1. Traits/Interfaces that have marked their verifier as `verifyRegionTrait` or
+   `verifyWithRegions=1`. This implies the verifier needs to access the
+   operations in its regions.
+1. Custom verifier which is defined in the op and has marked
+   `hasRegionVerifier=1`
+
+Note that the second phase will be run after the operations in the region are
+verified. Verifiers further down the order can rely on certain invariants being
+verified by a previous verifier and do not need to re-verify them.
 
 ### Declarative Assembly Format
 
index 065c515..4a6915c 100644 (file)
@@ -36,9 +36,12 @@ class MyTrait : public TraitBase<ConcreteType, MyTrait> {
 };
 ```
 
-Operation traits may also provide a `verifyTrait` hook, that is called when
-verifying the concrete operation. The trait verifiers will currently always be
-invoked before the main `Op::verify`.
+Operation traits may also provide a `verifyTrait` or `verifyRegionTrait` hook
+that is called when verifying the concrete operation. The difference between
+these two is that whether the verifier needs to access the regions, if so, the
+operations in the regions will be verified before the verification of this
+trait. The [verification order](OpDefinitions.md/#verification-ordering)
+determines when a verifier will be invoked.
 
 ```c++
 template <typename ConcreteType>
@@ -53,8 +56,9 @@ public:
 ```
 
 Note: It is generally good practice to define the implementation of the
-`verifyTrait` hook out-of-line as a free function when possible to avoid
-instantiating the implementation for every concrete operation type.
+`verifyTrait` or `verifyRegionTrait` hook out-of-line as a free function when
+possible to avoid instantiating the implementation for every concrete operation
+type.
 
 Operation traits may also provide a `foldTrait` hook that is called when folding
 the concrete operation. The trait folders will only be invoked if the concrete
index d2bef9d..96d134d 100644 (file)
@@ -76,7 +76,7 @@ bool isTopLevelValue(Value value);
 class AffineDmaStartOp
     : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
                 OpTrait::VariadicOperands, OpTrait::ZeroResult,
-                AffineMapAccessInterface::Trait> {
+                OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
   static ArrayRef<StringRef> getAttributeNames() { return {}; }
@@ -227,7 +227,8 @@ public:
   static StringRef getOperationName() { return "affine.dma_start"; }
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
-  LogicalResult verifyInvariants();
+  LogicalResult verifyInvariantsImpl();
+  LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
 
@@ -268,7 +269,7 @@ public:
 class AffineDmaWaitOp
     : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
                 OpTrait::VariadicOperands, OpTrait::ZeroResult,
-                AffineMapAccessInterface::Trait> {
+                OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
 public:
   using Op::Op;
   static ArrayRef<StringRef> getAttributeNames() { return {}; }
@@ -315,7 +316,8 @@ public:
   static StringRef getTagMapAttrName() { return "tag_map"; }
   static ParseResult parse(OpAsmParser &parser, OperationState &result);
   void print(OpAsmPrinter &p);
-  LogicalResult verifyInvariants();
+  LogicalResult verifyInvariantsImpl();
+  LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
   LogicalResult fold(ArrayRef<Attribute> cstOperands,
                      SmallVectorImpl<OpFoldResult> &results);
 };
index 5ad8bd4..11bc1a6 100644 (file)
@@ -2023,6 +2023,10 @@ class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
 // OpTrait definitions
 //===----------------------------------------------------------------------===//
 
+// A trait that describes the structure of operation will be marked with
+// `StructuralOpTrait` and they will be verified first.
+class StructuralOpTrait;
+
 // These classes are used to define operation specific traits.
 class NativeOpTrait<string name, list<Trait> traits = []>
     : NativeTrait<name, "Op"> {
@@ -2053,7 +2057,8 @@ class PredOpTrait<string descr, Pred pred, list<Trait> traits = []>
 // Op defines an affine scope.
 def AffineScope : NativeOpTrait<"AffineScope">;
 // Op defines an automatic allocation scope.
-def AutomaticAllocationScope : NativeOpTrait<"AutomaticAllocationScope">;
+def AutomaticAllocationScope :
+  NativeOpTrait<"AutomaticAllocationScope">;
 // Op supports operand broadcast behavior.
 def ResultsBroadcastableShape :
   NativeOpTrait<"ResultsBroadcastableShape">;
@@ -2074,9 +2079,11 @@ def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
 // Op has same shape for all operands.
 def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
 // Op has same operand and result shape.
-def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
+def SameOperandsAndResultShape :
+  NativeOpTrait<"SameOperandsAndResultShape">;
 // Op has the same element type (or type itself, if scalar) for all operands.
-def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
+def SameOperandsElementType :
+  NativeOpTrait<"SameOperandsElementType">;
 // Op has the same operand and result element type (or type itself, if scalar).
 def SameOperandsAndResultElementType :
   NativeOpTrait<"SameOperandsAndResultElementType">;
@@ -2104,21 +2111,23 @@ def ElementwiseMappable : TraitList<[
 ]>;
 
 // Op's regions have a single block.
-def SingleBlock : NativeOpTrait<"SingleBlock">;
+def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait;
 
 // Op's regions have a single block with the specified terminator.
 class SingleBlockImplicitTerminator<string op>
-    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
+    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>,
+      StructuralOpTrait;
 
 // Op's regions don't have terminator.
-def NoTerminator : NativeOpTrait<"NoTerminator">;
+def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
 
 // Op's parent operation is the provided one.
 class HasParent<string op>
-    : ParamNativeOpTrait<"HasParent", op>;
+    : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;
 
 class ParentOneOf<list<string> ops>
-    : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>;
+    : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
+      StructuralOpTrait;
 
 // Op result type is derived from the first attribute. If the attribute is an
 // subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
@@ -2147,13 +2156,15 @@ def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
 // vector that has the same number of elements as the number of ODS declared
 // operands. That means even if some operands are non-variadic, the attribute
 // still need to have an element for its size, which is always 1.
-def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">;
+def AttrSizedOperandSegments :
+  NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait;
 // Similar to AttrSizedOperandSegments, but used for results. The attribute
 // should be named as `result_segment_sizes`.
-def AttrSizedResultSegments  : NativeOpTrait<"AttrSizedResultSegments">;
+def AttrSizedResultSegments  :
+  NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait;
 
 // Op attached regions have no arguments
-def NoRegionArguments : NativeOpTrait<"NoRegionArguments">;
+def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;
 
 //===----------------------------------------------------------------------===//
 // OpInterface definitions
@@ -2191,6 +2202,11 @@ class OpInterfaceTrait<string name, code verifyBody = [{}],
   // the operation being verified.
   code verify = verifyBody;
 
+  // A bit indicating if the verifier needs to access the ops in the regions. If
+  // it set to `1`, the region ops will be verified before invoking this
+  // verifier.
+  bit verifyWithRegions = 0;
+
   // Specify the list of traits that need to be verified before the verification
   // of this OpInterfaceTrait.
   list<Trait> dependentTraits = traits;
@@ -2467,6 +2483,16 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
   // operation class. The operation should implement this method and verify the
   // additional necessary invariants.
   bit hasVerifier = 0;
+
+  // A bit indicating if the operation has additional invariants that need to
+  // verified and which associate with regions (aside from those verified by the
+  // traits). If set to `1`, an additional `LogicalResult verifyRegions()`
+  // declaration will be generated on the operation class. The operation should
+  // implement this method and verify the additional necessary invariants
+  // associated with regions. Note that this method is invoked after all the
+  // region ops are verified.
+  bit hasRegionVerifier = 0;
+
   // A custom code block corresponding to the extra verification code of the
   // operation.
   // NOTE: This field is deprecated in favor of `hasVerifier` and is slated for
index e38b0cb..268b666 100644 (file)
@@ -200,7 +200,8 @@ public:
 protected:
   /// If the concrete type didn't implement a custom verifier hook, just fall
   /// back to this one which accepts everything.
-  LogicalResult verifyInvariants() { return success(); }
+  LogicalResult verify() { return success(); }
+  LogicalResult verifyRegions() { return success(); }
 
   /// Parse the custom form of an operation. Unless overridden, this method will
   /// first try to get an operation parser from the op's dialect. Otherwise the
@@ -376,6 +377,18 @@ struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
 };
 } // namespace detail
 
+/// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc.
+/// It should be run after core traits and before any other user defined traits.
+/// In order to run it in the correct order, wrap it with OpInvariants trait so
+/// that tblgen will be able to put it in the right order.
+template <typename ConcreteType>
+class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    return cast<ConcreteType>(op).verifyInvariantsImpl();
+  }
+};
+
 /// This class provides the API for ops that are known to have no
 /// SSA operand.
 template <typename ConcreteType>
@@ -1572,6 +1585,14 @@ using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
 template <typename T>
 using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
 
+/// Trait to check if T provides a `verifyTrait` method.
+template <typename T, typename... Args>
+using has_verify_region_trait =
+    decltype(T::verifyRegionTrait(std::declval<Operation *>()));
+template <typename T>
+using detect_has_verify_region_trait =
+    llvm::is_detected<has_verify_region_trait, T>;
+
 /// The internal implementation of `verifyTraits` below that returns the result
 /// of verifying the current operation with all of the provided trait types
 /// `Ts`.
@@ -1589,6 +1610,26 @@ template <typename TraitTupleT>
 static LogicalResult verifyTraits(Operation *op) {
   return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
 }
+
+/// The internal implementation of `verifyRegionTraits` below that returns the
+/// result of verifying the current operation with all of the provided trait
+/// types `Ts`.
+template <typename... Ts>
+static LogicalResult verifyRegionTraitsImpl(Operation *op,
+                                            std::tuple<Ts...> *) {
+  LogicalResult result = success();
+  (void)std::initializer_list<int>{
+      (result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
+       0)...};
+  return result;
+}
+
+/// Given a tuple type containing a set of traits that contain a
+/// `verifyTrait` method, return the result of verifying the given operation.
+template <typename TraitTupleT>
+static LogicalResult verifyRegionTraits(Operation *op) {
+  return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
+}
 } // namespace op_definition_impl
 
 //===----------------------------------------------------------------------===//
@@ -1603,7 +1644,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
 public:
   /// Inherit getOperation from `OpState`.
   using OpState::getOperation;
-  using OpState::verifyInvariants;
+  using OpState::verify;
+  using OpState::verifyRegions;
 
   /// Return if this operation contains the provided trait.
   template <template <typename T> class Trait>
@@ -1704,6 +1746,10 @@ private:
   using VerifiableTraitsTupleT =
       typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
                                    Traits<ConcreteType>...>::type;
+  /// A tuple type containing the region traits that have a verify function.
+  using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
+      op_definition_impl::detect_has_verify_region_trait,
+      Traits<ConcreteType>...>::type;
 
   /// Returns an interface map containing the interfaces registered to this
   /// operation.
@@ -1839,11 +1885,22 @@ private:
                   "Op class shouldn't define new data members");
     return failure(
         failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
-        failed(cast<ConcreteType>(op).verifyInvariants()));
+        failed(cast<ConcreteType>(op).verify()));
   }
   static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
     return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
   }
+  /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
+  static LogicalResult verifyRegionInvariants(Operation *op) {
+    static_assert(hasNoDataMembers(),
+                  "Op class shouldn't define new data members");
+    return failure(failed(op_definition_impl::verifyRegionTraits<
+                          VerifiableRegionTraitsTupleT>(op)) ||
+                   failed(cast<ConcreteType>(op).verifyRegions()));
+  }
+  static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
+    return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
+  }
 
   static constexpr bool hasNoDataMembers() {
     // Checking that the derived class does not define any member by comparing
index 5ae0728..f72c244 100644 (file)
@@ -73,6 +73,8 @@ public:
       llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
   using VerifyInvariantsFn =
       llvm::unique_function<LogicalResult(Operation *) const>;
+  using VerifyRegionInvariantsFn =
+      llvm::unique_function<LogicalResult(Operation *) const>;
 
 protected:
   /// This class represents a type erased version of an operation. It contains
@@ -112,6 +114,7 @@ protected:
     ParseAssemblyFn parseAssemblyFn;
     PrintAssemblyFn printAssemblyFn;
     VerifyInvariantsFn verifyInvariantsFn;
+    VerifyRegionInvariantsFn verifyRegionInvariantsFn;
 
     /// A list of attribute names registered to this operation in StringAttr
     /// form. This allows for operation classes to use StringAttr for attribute
@@ -238,16 +241,18 @@ public:
   static void insert(Dialect &dialect) {
     insert(T::getOperationName(), dialect, TypeID::get<T>(),
            T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
-           T::getVerifyInvariantsFn(), T::getFoldHookFn(),
-           T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
-           T::getHasTraitFn(), T::getAttributeNames());
+           T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
+           T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
+           T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
   }
   /// The use of this method is in general discouraged in favor of
   /// 'insert<CustomOp>(dialect)'.
   static void
   insert(StringRef name, Dialect &dialect, TypeID typeID,
          ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
-         VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+         VerifyInvariantsFn &&verifyInvariants,
+         VerifyRegionInvariantsFn &&verifyRegionInvariants,
+         FoldHookFn &&foldHook,
          GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
          detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
          ArrayRef<StringRef> attrNames);
@@ -272,12 +277,15 @@ public:
     return impl->printAssemblyFn(op, p, defaultDialect);
   }
 
-  /// This hook implements the verifier for this operation.  It should emits an
-  /// error message and returns failure if a problem is detected, or returns
+  /// These hooks implement the verifiers for this operation.  It should emits
+  /// an error message and returns failure if a problem is detected, or returns
   /// success if everything is ok.
   LogicalResult verifyInvariants(Operation *op) const {
     return impl->verifyInvariantsFn(op);
   }
+  LogicalResult verifyRegionInvariants(Operation *op) const {
+    return impl->verifyRegionInvariantsFn(op);
+  }
 
   /// This hook implements a generalized folder for this operation.  Operations
   /// can implement this to provide simplifications rules that are applied by
index 74a15b8..9d50b26 100644 (file)
@@ -98,6 +98,10 @@ public:
   // Return the verify method body if it has one.
   llvm::Optional<StringRef> getVerify() const;
 
+  // If there's a verify method, return if it needs to access the ops in the
+  // regions.
+  bool verifyWithRegions() const;
+
   // Returns the Tablegen definition this interface was constructed from.
   const llvm::Record &getDef() const { return *def; }
 
index c3d0d2a..8da5303 100644 (file)
@@ -65,6 +65,9 @@ public:
   // Returns the trait corresponding to a C++ trait class.
   std::string getFullyQualifiedTraitName() const;
 
+  // Returns if this is a structural op trait.
+  bool isStructuralOpTrait() const;
+
   static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
 };
 
index 2fa2623..90aa7ba 100644 (file)
@@ -1117,7 +1117,7 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
   return success();
 }
 
-LogicalResult AffineDmaStartOp::verifyInvariants() {
+LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
   if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
     return emitOpError("expected DMA source to be of memref type");
   if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
@@ -1219,7 +1219,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
   return success();
 }
 
-LogicalResult AffineDmaWaitOp::verifyInvariants() {
+LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
   if (!getOperand(0).getType().isa<MemRefType>())
     return emitOpError("expected DMA tag to be of memref type");
   Region *scope = getAffineScope(*this);
index 6c4dcac..1b6d354 100644 (file)
@@ -693,7 +693,8 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser,
 void RegisteredOperationName::insert(
     StringRef name, Dialect &dialect, TypeID typeID,
     ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
-    VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
+    VerifyInvariantsFn &&verifyInvariants,
+    VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
     GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
     detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
     ArrayRef<StringRef> attrNames) {
@@ -749,6 +750,7 @@ void RegisteredOperationName::insert(
   impl.parseAssemblyFn = std::move(parseAssembly);
   impl.printAssemblyFn = std::move(printAssembly);
   impl.verifyInvariantsFn = std::move(verifyInvariants);
+  impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
   impl.attributeNames = cachedAttrNames;
 }
 
index bbc560d..0c8724d 100644 (file)
@@ -217,6 +217,11 @@ LogicalResult OperationVerifier::verifyOperation(
     }
   }
 
+  // After the region ops are verified, run the verifiers that have additional
+  // region invariants need to veirfy.
+  if (registeredInfo && failed(registeredInfo->verifyRegionInvariants(&op)))
+    return failure();
+
   // If this is a registered operation, there is nothing left to do.
   if (registeredInfo)
     return success();
index d26ca0b..4d72cee 100644 (file)
@@ -125,6 +125,10 @@ llvm::Optional<StringRef> Interface::getVerify() const {
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 
+bool Interface::verifyWithRegions() const {
+  return def->getValueAsBit("verifyWithRegions");
+}
+
 //===----------------------------------------------------------------------===//
 // AttrInterface
 //===----------------------------------------------------------------------===//
index 4e28e99..ee4b999 100644 (file)
@@ -50,6 +50,10 @@ std::string NativeTrait::getFullyQualifiedTraitName() const {
                               : (cppNamespace + "::" + trait).str();
 }
 
+bool NativeTrait::isStructuralOpTrait() const {
+  return def->isSubClassOf("StructuralOpTrait");
+}
+
 //===----------------------------------------------------------------------===//
 // InternalTrait
 //===----------------------------------------------------------------------===//
index c0e6ebd..a4f1353 100644 (file)
@@ -168,7 +168,7 @@ func @func_with_ops(i32, i32) {
 func @func_with_ops() {
 ^bb0:
   %c = arith.constant dense<0> : vector<42 x i32>
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
+  // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
   %r = "arith.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
 }
 
@@ -249,7 +249,7 @@ func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 {
 // -----
 
 func @cmpf_result_shape_mismatch(%a : vector<42xf32>) {
-  // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
+  // expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
   %r = "arith.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1>
 }
 
@@ -285,7 +285,7 @@ func @index_cast_index_to_index(%arg0: index) {
 // -----
 
 func @index_cast_float(%arg0: index, %arg1: f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
   %0 = arith.index_cast %arg0 : index to f32
   return
 }
@@ -293,7 +293,7 @@ func @index_cast_float(%arg0: index, %arg1: f32) {
 // -----
 
 func @index_cast_float_to_index(%arg0: f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
   %0 = arith.index_cast %arg0 : f32 to index
   return
 }
@@ -301,7 +301,7 @@ func @index_cast_float_to_index(%arg0: f32) {
 // -----
 
 func @sitofp_i32_to_i64(%arg0 : i32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i64'}}
   %0 = arith.sitofp %arg0 : i32 to i64
   return
 }
@@ -309,7 +309,7 @@ func @sitofp_i32_to_i64(%arg0 : i32) {
 // -----
 
 func @sitofp_f32_to_i32(%arg0 : f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'f32'}}
   %0 = arith.sitofp %arg0 : f32 to i32
   return
 }
@@ -333,7 +333,7 @@ func @fpext_f16_to_f16(%arg0 : f16) {
 // -----
 
 func @fpext_i32_to_f32(%arg0 : i32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.extf %arg0 : i32 to f32
   return
 }
@@ -341,7 +341,7 @@ func @fpext_i32_to_f32(%arg0 : i32) {
 // -----
 
 func @fpext_f32_to_i32(%arg0 : f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.extf %arg0 : f32 to i32
   return
 }
@@ -373,7 +373,7 @@ func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) {
 // -----
 
 func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.extf %arg0 : vector<2xi32> to vector<2xf32>
   return
 }
@@ -381,7 +381,7 @@ func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
 // -----
 
 func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.extf %arg0 : vector<2xf32> to vector<2xi32>
   return
 }
@@ -405,7 +405,7 @@ func @fptrunc_f32_to_f32(%arg0 : f32) {
 // -----
 
 func @fptrunc_i32_to_f32(%arg0 : i32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.truncf %arg0 : i32 to f32
   return
 }
@@ -413,7 +413,7 @@ func @fptrunc_i32_to_f32(%arg0 : i32) {
 // -----
 
 func @fptrunc_f32_to_i32(%arg0 : f32) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}}
   %0 = arith.truncf %arg0 : f32 to i32
   return
 }
@@ -445,7 +445,7 @@ func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) {
 // -----
 
 func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.truncf %arg0 : vector<2xi32> to vector<2xf32>
   return
 }
@@ -453,7 +453,7 @@ func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
 // -----
 
 func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}}
   %0 = arith.truncf %arg0 : vector<2xf32> to vector<2xi32>
   return
 }
@@ -461,7 +461,7 @@ func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
 // -----
 
 func @sexti_index_as_operand(%arg0 : index) {
-  // expected-error@+1 {{are cast incompatible}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extsi %arg0 : index to i128
   return
 }
@@ -469,7 +469,7 @@ func @sexti_index_as_operand(%arg0 : index) {
 // -----
 
 func @zexti_index_as_operand(%arg0 : index) {
-  // expected-error@+1 {{operand type 'index' and result type}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extui %arg0 : index to i128
   return
 }
@@ -477,7 +477,7 @@ func @zexti_index_as_operand(%arg0 : index) {
 // -----
 
 func @trunci_index_as_operand(%arg0 : index) {
-  // expected-error@+1 {{operand type 'index' and result type}}
+  // expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %2 = arith.trunci %arg0 : index to i128
   return
 }
@@ -485,7 +485,7 @@ func @trunci_index_as_operand(%arg0 : index) {
 // -----
 
 func @sexti_index_as_result(%arg0 : i1) {
-  // expected-error@+1 {{result type 'index' are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extsi %arg0 : i1 to index
   return
 }
@@ -493,7 +493,7 @@ func @sexti_index_as_result(%arg0 : i1) {
 // -----
 
 func @zexti_index_as_operand(%arg0 : i1) {
-  // expected-error@+1 {{result type 'index' are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %0 = arith.extui %arg0 : i1 to index
   return
 }
@@ -501,7 +501,7 @@ func @zexti_index_as_operand(%arg0 : i1) {
 // -----
 
 func @trunci_index_as_result(%arg0 : i128) {
-  // expected-error@+1 {{result type 'index' are cast incompatible}}
+  // expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
   %2 = arith.trunci %arg0 : i128 to index
   return
 }
index 006e118..cdaef01 100644 (file)
@@ -301,7 +301,7 @@ func @reduce_incorrect_yield(%arg0 : f32) {
 // -----
 
 func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
-  // expected-error@+1 {{inferred type(s) 'f32', 'i1' are incompatible with return type(s) of operation 'i32', 'i1'}}
+  // expected-error@+1 {{op failed to verify that all of {value, result} have same type}}
   %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = #gpu<"shuffle_mode xor"> } : (f32, i32, i32) -> (i32, i1)
   return
 }
index b831f3e..84f1d9e 100644 (file)
@@ -80,8 +80,8 @@ llvm.mlir.global internal constant @sectionvar("teststring")  {section = ".mysec
 
 // -----
 
-// expected-error @+1 {{requires string attribute 'sym_name'}}
-"llvm.mlir.global"() ({}) {type = i64, constant, value = 42 : i64} : () -> ()
+// expected-error @+1 {{op requires attribute 'sym_name'}}
+"llvm.mlir.global"() ({}) {type = i64, constant, global_type = i64, value = 42 : i64} : () -> ()
 
 // -----
 
index 40defe4..c45c58a 100644 (file)
@@ -214,15 +214,15 @@ func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
 
 // -----
 
-func @generic_scalar_operand_block_arg_type(%arg0: f32) {
+func @generic_scalar_operand_block_arg_type(%arg0: tensor<f32>) {
   // expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}}
   linalg.generic {
     indexing_maps =  [ affine_map<() -> ()> ],
     iterator_types = []}
-      outs(%arg0 : f32) {
+      outs(%arg0 : tensor<f32>) {
     ^bb(%i: i1):
     linalg.yield %i : i1
-  }
+  } -> tensor<f32>
 }
 
 // -----
@@ -243,7 +243,7 @@ func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(o
 
 func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
                                  %arg1: tensor<?xf32>) {
-  // expected-error @+1 {{expected type of operand #1 ('tensor<?xf32>') to match type of corresponding result ('f32')}}
+  // expected-error @+1 {{expected type of operand #1 ('tensor<?xf32>') to match type of corresponding result ('tensor<f32>')}}
   %0 = linalg.generic {
     indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ],
     iterator_types = ["parallel"]}
@@ -251,7 +251,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
       outs(%arg1 : tensor<?xf32>) {
     ^bb(%i: f32, %j: f32):
       linalg.yield %i: f32
-  } -> f32
+  } -> tensor<f32>
 }
 
 // -----
@@ -362,11 +362,11 @@ func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
 
 // -----
 
-func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
+func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
 {
-  // expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}}
-  %0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> memref<?x?xf32>
-  return %0 : memref<?x?xf32>
+  // expected-error @+1 {{op expected the number of results (1) to be equal to the number of output tensors (0)}}
+  %0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
 }
 
 // -----
@@ -384,7 +384,7 @@ func @illegal_fill_memref_with_tensor_return
 func @illegal_fill_tensor_with_memref_return
   (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
 {
-  // expected-error @+1 {{expected type of operand #1 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
+  // expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'memref<?x?xf32>'}}
   %0 = linalg.fill(%arg1, %arg0) : f32, tensor<?x?xf32> -> memref<?x?xf32>
   return %0 : memref<?x?xf32>
 }
@@ -477,7 +477,7 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
   %c0 = arith.constant 0 : index
   %c192 = arith.constant 192 : index
   // expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}}
-  %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ({
+  %0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( {
     ^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>,
          %B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>,
          %C_: memref<192x192xf32>):
@@ -502,7 +502,7 @@ func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) {
   %c192 = arith.constant 192 : index
   %c24 = arith.constant 24 : index
   // expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}}
-  "linalg.tiled_loop"(%c0, %c192, %c24, %A) ({
+  "linalg.tiled_loop"(%c0, %c192, %c24, %A) ( {
     ^bb0(%arg4: index, %A_: memref<100xf32>):
       call @foo(%A_) : (memref<100xf32>)-> ()
       linalg.yield
index 8de70c2..a06ca5f 100644 (file)
@@ -111,7 +111,7 @@ func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x11
 // -----
 
 func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
-  // expected-error @+1 {{incorrect element type for index attribute 'strides'}}
+  // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
   linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
     ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
     outs(%output: memref<1x56x56x96xf32>)
@@ -121,7 +121,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
 // -----
 
 func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
-  // expected-error @+1 {{incorrect shape for index attribute 'strides'}}
+  // expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
   linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
     ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
     outs(%output: memref<1x56x56x96xf32>)
index cd775b4..607d4b1 100644 (file)
@@ -59,7 +59,7 @@ func @bit_field_u_extract_vec(%base: vector<3xi32>, %offset: i8, %count: i8) ->
 // -----
 
 func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> {
-  // expected-error @+1 {{inferred type(s) 'vector<3xi32>' are incompatible with return type(s) of operation 'vector<4xi32>'}}
+  // expected-error @+1 {{failed to verify that all of {base, result} have same type}}
   %0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32>
   spv.ReturnValue %0 : vector<4xi32>
 }
@@ -181,7 +181,7 @@ func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 {
 // -----
 
 func @shift_left_logical_invalid_result_type(%arg0: i32, %arg1 : i16) -> i16 {
-  // expected-error @+1 {{op inferred type(s) 'i32' are incompatible with return type(s) of operation 'i16'}}
+  // expected-error @+1 {{op failed to verify that all of {operand1, result} have same type}}
   %0 = "spv.ShiftLeftLogical" (%arg0, %arg1) : (i32, i16) -> (i16)
   spv.ReturnValue %0 : i16
 }
index cdbe57a..648786e 100644 (file)
@@ -98,8 +98,8 @@ func @shape_of(%value_arg : !shape.value_shape,
 // -----
 
 func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
-  // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
-  %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
+  // expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xindex>'}}
+  %0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xindex>
   return
 }
 
index daf09eb..74bbcf7 100644 (file)
@@ -58,7 +58,7 @@ func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) ->
 // Check incompatible vector and tensor result type
 func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
 ^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
-  // expected-error @+1 {{cannot broadcast vector with tensor}}
+  // expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
   %0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
   return %0 : vector<4xf32>
 }
index 8caace9..8db29bf 100644 (file)
@@ -3,7 +3,7 @@
 // -----
 
 func @module_op() {
-  // expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
+  // expected-error@+1 {{'builtin.module' op expects region #0 to have 0 or 1 blocks}}
   builtin.module {
   ^bb1:
     "test.dummy"() : () -> ()
index e6283b5..7432069 100644 (file)
@@ -332,12 +332,12 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() {
 
 // Test the invariants of operations with the Symbol Trait.
 
-// expected-error@+1 {{requires string attribute 'sym_name'}}
+// expected-error@+1 {{op requires attribute 'sym_name'}}
 "test.symbol"() {} : () -> ()
 
 // -----
 
-// expected-error@+1 {{requires visibility attribute 'sym_visibility' to be a string attribute}}
+// expected-error@+1 {{op attribute 'sym_visibility' failed to satisfy constraint: string attribute}}
 "test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> ()
 
 // -----
@@ -364,7 +364,7 @@ func private @foo()
 // -----
 
 // Test that operation with the SymbolTable Trait fails with  too many blocks.
-// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
+// expected-error@+1 {{op expects region #0 to have 0 or 1 blocks}}
 "test.symbol_scope"() ({
   ^entry:
     "test.finish" () : () -> ()
@@ -668,4 +668,4 @@ func @failed_attr_traits() {
   // expected-error@+1 {{'attr' attribute should have trait 'TestAttrTrait'}}
   "test.attr_with_trait"() {attr = 42 : i32} : () -> ()
   return
-}
\ No newline at end of file
+}
index 2c97422..047a237 100644 (file)
@@ -68,7 +68,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
 // CHECK:   ::mlir::ValueRange odsOperands;
 // CHECK: };
 
-// CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::IsIsolatedFromAbove
+// CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsIsolatedFromAbove
 // CHECK-NOT: ::mlir::OpTrait::IsIsolatedFromAbove
 // CHECK: public:
 // CHECK:   using Op::Op;
index 64392c3..51403fb 100644 (file)
@@ -42,6 +42,19 @@ def TestOpInterface : OpInterface<"TestOpInterface"> {
   ];
 }
 
+def TestOpInterfaceVerify : OpInterface<"TestOpInterfaceVerify"> {
+  let verify = [{
+    return foo();
+  }];
+}
+
+def TestOpInterfaceVerifyRegion : OpInterface<"TestOpInterfaceVerifyRegion"> {
+  let verify = [{
+    return foo();
+  }];
+  let verifyWithRegions = 1;
+}
+
 // Define Ops with TestOpInterface and
 // DeclareOpInterfaceMethods<TestOpInterface> traits to check that there
 // are not duplicated C++ classes generated.
@@ -65,6 +78,12 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
 // DECL: template<typename ConcreteOp>
 // DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
 
+// DECL-LABEL: struct TestOpInterfaceVerifyTrait
+// DECL: verifyTrait
+
+// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait
+// DECL: verifyRegionTrait
+
 // OP_DECL-LABEL: class DeclareMethodsOp : public
 // OP_DECL: int foo(int input);
 // OP_DECL-NOT: int default_foo(int input);
index 33ad65e..a8dbf60 100644 (file)
@@ -58,7 +58,7 @@ func @complex_f64_tensor_success() {
 // -----
 
 func @complex_f64_failure() {
-  // expected-error@+1 {{op inferred type(s) 'complex<f64>' are incompatible with return type(s) of operation 'f64'}}
+  // expected-error@+1 {{op result #0 must be complex type with 64-bit float elements, but got 'f64'}}
   "test.complex_f64"() : () -> (f64)
   return
 }
@@ -438,7 +438,7 @@ func @operand_rank_equals_result_size_failure(%arg : tensor<1x2x3x4xi32>) {
 // -----
 
 func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
-  // expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<*xf32>'}}
+  // expected-error@+1 {{op failed to verify that all of {x, res} have same type}}
   "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>
   return
 }
@@ -446,7 +446,7 @@ func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>
 // -----
 
 func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
-  // expected-error@+1 {{op inferred type(s) 'tensor<1x2xi32>' are incompatible with return type(s) of operation 'tensor<2x1xi32>'}}
+  // expected-error@+1 {{op failed to verify that all of {x, res} have same type}}
   "test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32>
   return
 }
index 6a45e1f..2c56c92 100644 (file)
@@ -394,6 +394,9 @@ private:
   // Generates verify method for the operation.
   void genVerifier();
 
+  // Generates custom verify methods for the operation.
+  void genCustomVerifier();
+
   // Generates verify statements for operands and results in the operation.
   // The generated code will be attached to `body`.
   void genOperandResultVerifier(MethodBody &body,
@@ -593,6 +596,7 @@ OpEmitter::OpEmitter(const Operator &op,
   genParser();
   genPrinter();
   genVerifier();
+  genCustomVerifier();
   genCanonicalizerDecls();
   genFolderDecls();
   genTypeInterfaceMethods();
@@ -2236,47 +2240,76 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
 }
 
 void OpEmitter::genVerifier() {
-  auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
-  ERROR_IF_PRUNED(method, "verifyInvariants", op);
-  auto &body = method->body();
+  auto *implMethod =
+      opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
+  ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
+  auto &implBody = implMethod->body();
 
   OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
-  genNativeTraitAttrVerifier(body, emitHelper);
+  genNativeTraitAttrVerifier(implBody, emitHelper);
 
-  auto *valueInit = def.getValueInit("verifier");
-  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
-  bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
   populateSubstitutions(emitHelper, verifyCtx);
 
-  genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
-  genOperandResultVerifier(body, op.getOperands(), "operand");
-  genOperandResultVerifier(body, op.getResults(), "result");
+  genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
+  genOperandResultVerifier(implBody, op.getOperands(), "operand");
+  genOperandResultVerifier(implBody, op.getResults(), "result");
 
   for (auto &trait : op.getTraits()) {
     if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
-      body << tgfmt("  if (!($0))\n    "
-                    "return emitOpError(\"failed to verify that $1\");\n",
-                    &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
-                    t->getSummary());
+      implBody << tgfmt("  if (!($0))\n    "
+                        "return emitOpError(\"failed to verify that $1\");\n",
+                        &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
+                        t->getSummary());
     }
   }
 
-  genRegionVerifier(body);
-  genSuccessorVerifier(body);
+  genRegionVerifier(implBody);
+  genSuccessorVerifier(implBody);
+
+  implBody << "  return ::mlir::success();\n";
+
+  // TODO: Some places use the `verifyInvariants` to do operation verification.
+  // This may not act as their expectation because this doesn't call any
+  // verifiers of native/interface traits. Needs to review those use cases and
+  // see if we should use the mlir::verify() instead.
+  auto *valueInit = def.getValueInit("verifier");
+  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
+  bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
+
+  auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
+  ERROR_IF_PRUNED(method, "verifyInvariants", op);
+  auto &body = method->body();
+  if (hasCustomVerifyCodeBlock || def.getValueAsBit("hasVerifier")) {
+    body << "  if(::mlir::succeeded(verifyInvariantsImpl()) && "
+            "::mlir::succeeded(verify()))\n";
+    body << "    return ::mlir::success();\n";
+    body << "  return ::mlir::failure();";
+  } else {
+    body << "  return verifyInvariantsImpl();";
+  }
+}
+
+void OpEmitter::genCustomVerifier() {
+  auto *valueInit = def.getValueInit("verifier");
+  StringInit *stringInit = dyn_cast<StringInit>(valueInit);
+  bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
 
   if (def.getValueAsBit("hasVerifier")) {
-    auto *method = opClass.declareMethod<Method::Private>(
-        "::mlir::LogicalResult", "verify");
+    auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify");
     ERROR_IF_PRUNED(method, "verify", op);
-    body << "  return verify();\n";
-
+  } else if (def.getValueAsBit("hasRegionVerifier")) {
+    auto *method =
+        opClass.declareMethod("::mlir::LogicalResult", "verifyRegions");
+    ERROR_IF_PRUNED(method, "verifyRegions", op);
   } else if (hasCustomVerifyCodeBlock) {
+    auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
+    ERROR_IF_PRUNED(method, "verify", op);
+    auto &body = method->body();
+
     FmtContext fctx;
     fctx.addSubst("cppClass", opClass.getClassName());
     auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
     body << "  " << tgfmt(printer, &fctx);
-  } else {
-    body << "  return ::mlir::success();\n";
   }
 }
 
@@ -2508,12 +2541,27 @@ void OpEmitter::genTraits() {
     }
   }
 
+  // The op traits defined internal are ensured that they can be verified
+  // earlier.
+  for (const auto &trait : op.getTraits()) {
+    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      if (opTrait->isStructuralOpTrait())
+        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    }
+  }
+
+  // OpInvariants wrapps the verifyInvariants which needs to be run before
+  // native/interface traits and after all the traits with `StructuralOpTrait`.
+  opClass.addTrait("::mlir::OpTrait::OpInvariants");
+
   // Add the native and interface traits.
   for (const auto &trait : op.getTraits()) {
-    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
-      opClass.addTrait(opTrait->getFullyQualifiedTraitName());
-    else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
+    if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
+      if (!opTrait->isStructuralOpTrait())
+        opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    } else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
       opClass.addTrait(opTrait->getFullyQualifiedTraitName());
+    }
   }
 }
 
index 0513c5e..76a5164 100644 (file)
@@ -413,9 +413,12 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
 
     tblgen::FmtContext verifyCtx;
     verifyCtx.withOp("op");
-    os << "    static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) "
-          "{\n      "
-       << tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n    }\n";
+    os << llvm::formatv(
+              "    static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
+              (interface.verifyWithRegions() ? "verifyRegionTrait"
+                                             : "verifyTrait"))
+       << "{\n      " << tblgen::tgfmt(verify->trim(), &verifyCtx)
+       << "\n    }\n";
   }
   if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
     os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";