[mlir] Enable specifying verify on OpInterface
authorJacques Pienaar <jpienaar@google.com>
Tue, 21 Jan 2020 17:40:22 +0000 (09:40 -0800)
committerJacques Pienaar <jpienaar@google.com>
Wed, 22 Jan 2020 12:43:22 +0000 (04:43 -0800)
Summary:
Add method in ODS to specify verification for operations implementing a
OpInterface. Use this with infer type op interface to verify that the
inferred type matches the return type and remove special case in
TestPatterns.

This could also have been achieved by using OpInterfaceMethod but verify
seems pretty common and it is not an arbitrary method that just happened
to be named verifyTrait, so having it be defined in special way seems
appropriate/better documenting.

Differential Revision: https://reviews.llvm.org/D73122

mlir/docs/OpDefinitions.md
mlir/include/mlir/Analysis/InferTypeOpInterface.h
mlir/include/mlir/Analysis/InferTypeOpInterface.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/OpInterfaces.h
mlir/lib/Analysis/InferTypeOpInterface.cpp
mlir/lib/TableGen/OpInterfaces.cpp
mlir/test/lib/TestDialect/TestPatterns.cpp
mlir/test/mlir-tblgen/return-types.mlir
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

index a035d9b200bab37c3009b8bef747ad8459dcc8ef..1a9a2104bbc009bcba8f2a0e8ced7f14c29404ba 100644 (file)
@@ -400,6 +400,10 @@ def OpWithInferTypeInterfaceOp : Op<...
     [DeclareOpInterfaceMethods<MyInterface>]> { ... }
 ```
 
+A verification method can also be specified on the `OpInterface` by setting
+`verify`. Setting `verify` results in the generated trait having a `verifyTrait`
+method that is applied to all operations implementing the trait.
+
 ### Builder methods
 
 For each operation, there are a few builders automatically generated based on
index 260787147ed5d3790313252861c57009ce00296d..be68e7db8be564772d3249e8aa2fbb83c35918fc 100644 (file)
@@ -72,8 +72,6 @@ private:
   Attribute attr;
 };
 
-#include "mlir/Analysis/InferTypeOpInterface.h.inc"
-
 namespace detail {
 // Helper function to infer return tensor returns types given element and shape
 // inference function.
@@ -89,8 +87,14 @@ LogicalResult inferReturnTensorTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     ArrayRef<NamedAttribute> attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferedReturnTypes);
+
+/// Verifies that the inferred result types match the actual result types for
+/// the op. Precondition: op implements InferTypeOpInterface.
+LogicalResult verifyInferredResultTypes(Operation *op);
 } // namespace detail
 
+#include "mlir/Analysis/InferTypeOpInterface.h.inc"
+
 namespace OpTrait {
 
 /// Tensor type inference trait that constructs a tensor from the infered
index bc06b29ab81c6189105cff47cecc2bf36f1375ab..58f4508cd03e64eaf4aa0feddc188cda976f26cc 100644 (file)
@@ -60,6 +60,10 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
       }]
     >,
   ];
+
+  let verify = [{
+    return detail::verifyInferredResultTypes($_op);
+  }];
 }
 
 def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
index 775123362bbff46be9c5aabfa0ad6140b659af48..c19c0b4be37c36f9dddac984a38c9dfa652c729a 100644 (file)
@@ -1411,8 +1411,12 @@ def ins;
 // OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
 // C++. The purpose to wrap around C++ symbol string with this class is to make
 // interfaces specified for ops in TableGen less alien and more integrated.
-class OpInterfaceTrait<string name> : NativeOpTrait<""> {
+class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<""> {
   let trait = name # "::Trait";
+
+  // Specify the body of the verification function. `$_op` will be replaced with
+  // the operation being verified.
+  code verify = verifyBody;
 }
 
 // This class represents a single, optionally static, interface method.
index 9bf181615648c6fe8db9aeef9abc99cc1e43e36c..1ee5dde6c7e224f59e9b98dce359156845056ad5 100644 (file)
@@ -86,6 +86,9 @@ public:
   // Return the description of this method if it has one.
   llvm::Optional<StringRef> getDescription() const;
 
+  // Return the verify method body if it has one.
+  llvm::Optional<StringRef> getVerify() const;
+
 private:
   // The TableGen definition of this interface.
   const llvm::Record *def;
index b1637b8bb253590fff44be201ca76d8ce42ab70f..fd929e2b9c3c2cbbfa7f69d9a78284f9d578f5f4 100644 (file)
@@ -45,3 +45,17 @@ LogicalResult mlir::detail::inferReturnTensorTypes(
   }
   return success();
 }
+
+LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
+  SmallVector<Type, 4> inferedReturnTypes;
+  auto retTypeFn = cast<InferTypeOpInterface>(op);
+  if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(),
+                                        op->getOperands(), op->getAttrs(),
+                                        op->getRegions(), inferedReturnTypes)))
+    return failure();
+  SmallVector<Type, 4> resultTypes(op->getResultTypes());
+  if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
+    return op->emitOpError(
+        "inferred type incompatible with return type of operation");
+  return success();
+}
index b1e56efc029133570f598b9f9e7edf13cbb6cec7..5f18c45a87a35caa7098e6afc779b45ea3c22491 100644 (file)
@@ -85,3 +85,9 @@ llvm::Optional<StringRef> OpInterface::getDescription() const {
   auto value = def->getValueAsString("description");
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
+
+// Return the body for this method if it has one.
+llvm::Optional<StringRef> OpInterface::getVerify() const {
+  auto value = def->getValueAsString("verify");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
index 86abf8c14a8d9439028f6ffd280250482f706575..3bbeefd99270c9084f7dbeeb622e471f0371ab92 100644 (file)
@@ -103,26 +103,6 @@ struct TestReturnTypeDriver : public FunctionPass<TestReturnTypeDriver> {
       };
       return;
     }
-
-    // Verification check.
-    // TODO: Move to ops that implement type infer interface.
-    getFunction().walk([this](Operation *op) -> void {
-      auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
-      if (!retTypeFn)
-        return;
-      auto *context = &getContext();
-      SmallVector<Type, 2> inferedReturnTypes;
-      if (failed(retTypeFn.inferReturnTypes(
-              context, op->getLoc(), op->getOperands(), op->getAttrs(),
-              op->getRegions(), inferedReturnTypes)))
-        return;
-      SmallVector<Type, 1> resultTypes(op->getResultTypes());
-      if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
-        op->emitOpError(
-            "inferred type incompatible with return type of operation");
-        return;
-      }
-    });
   }
 };
 } // end anonymous namespace
index 640f06dd6534b8b2baa50a43a6b131df4d095d9b..3fcb22331fa150b981ce5aa40cba06ac78d29f8c 100644 (file)
@@ -23,7 +23,6 @@ func @testCreateFunctions(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
 
 // -----
 
-// CHECK-LABEL: testReturnTypeOpInterface
 func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
   // expected-error@+1 {{incompatible with return type}}
   %bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
@@ -32,7 +31,6 @@ func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
 
 // -----
 
-// CHECK-LABEL: testReturnTypeOpInterface
 func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
   // expected-error@+1 {{operand type mismatch}}
   %bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
index a96736cd2c5b7306cc9e704f45a00dffa06ff68d..b0bc14c2fc53c9a4a86482fba0844c6834a001b2 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "DocGenUtilities.h"
 #include "mlir/Support/STLExtras.h"
+#include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/OpInterfaces.h"
 #include "llvm/ADT/SmallVector.h"
@@ -152,6 +153,12 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
 
   // Insert the default implementation for any methods.
   for (auto &method : interface.getMethods()) {
+    // Flag interface methods named verifyTrait.
+    if (method.getName() == "verifyTrait")
+      PrintFatalError(
+          formatv("'verifyTrait' method cannot be specified as interface "
+                  "method for '{0}'; set 'verify' on OpInterfaceTrait instead",
+                  interfaceName));
     auto defaultImpl = method.getDefaultImplementation();
     if (!defaultImpl)
       continue;
@@ -162,6 +169,13 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
     os << " {\n" << defaultImpl.getValue() << "  }\n";
   }
 
+  tblgen::FmtContext traitCtx;
+  traitCtx.withOp("op");
+  if (auto verify = interface.getVerify()) {
+    os << "  static LogicalResult verifyTrait(Operation* op) {\n"
+       << tblgen::tgfmt(*verify, &traitCtx) << "\n  }\n";
+  }
+
   os << "  };\n";
 }