[DeclareOpInterfaceMethods<MyInterface>]> { ... }
```
+A verification method can also be specified on the `OpInterface` by setting
+`verify`. Setting `verify` results in the generated trait having a `verifyTrait`
+method that is applied to all operations implementing the trait.
+
### Builder methods
For each operation, there are a few builders automatically generated based on
Attribute attr;
};
-#include "mlir/Analysis/InferTypeOpInterface.h.inc"
-
namespace detail {
// Helper function to infer return tensor returns types given element and shape
// inference function.
MLIRContext *context, Optional<Location> location, ValueRange operands,
ArrayRef<NamedAttribute> attributes, RegionRange regions,
SmallVectorImpl<Type> &inferedReturnTypes);
+
+/// Verifies that the inferred result types match the actual result types for
+/// the op. Precondition: op implements InferTypeOpInterface.
+LogicalResult verifyInferredResultTypes(Operation *op);
} // namespace detail
+#include "mlir/Analysis/InferTypeOpInterface.h.inc"
+
namespace OpTrait {
/// Tensor type inference trait that constructs a tensor from the infered
}]
>,
];
+
+ let verify = [{
+ return detail::verifyInferredResultTypes($_op);
+ }];
}
def InferShapedTypeOpInterface : OpInterface<"InferShapedTypeOpInterface"> {
// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
// C++. The purpose to wrap around C++ symbol string with this class is to make
// interfaces specified for ops in TableGen less alien and more integrated.
-class OpInterfaceTrait<string name> : NativeOpTrait<""> {
+class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<""> {
let trait = name # "::Trait";
+
+ // Specify the body of the verification function. `$_op` will be replaced with
+ // the operation being verified.
+ code verify = verifyBody;
}
// This class represents a single, optionally static, interface method.
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;
+ // Return the verify method body if it has one.
+ llvm::Optional<StringRef> getVerify() const;
+
private:
// The TableGen definition of this interface.
const llvm::Record *def;
}
return success();
}
+
+LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
+ SmallVector<Type, 4> inferedReturnTypes;
+ auto retTypeFn = cast<InferTypeOpInterface>(op);
+ if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(),
+ op->getOperands(), op->getAttrs(),
+ op->getRegions(), inferedReturnTypes)))
+ return failure();
+ SmallVector<Type, 4> resultTypes(op->getResultTypes());
+ if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes))
+ return op->emitOpError(
+ "inferred type incompatible with return type of operation");
+ return success();
+}
auto value = def->getValueAsString("description");
return value.empty() ? llvm::Optional<StringRef>() : value;
}
+
+// Return the body for this method if it has one.
+llvm::Optional<StringRef> OpInterface::getVerify() const {
+ auto value = def->getValueAsString("verify");
+ return value.empty() ? llvm::Optional<StringRef>() : value;
+}
};
return;
}
-
- // Verification check.
- // TODO: Move to ops that implement type infer interface.
- getFunction().walk([this](Operation *op) -> void {
- auto retTypeFn = dyn_cast<InferTypeOpInterface>(op);
- if (!retTypeFn)
- return;
- auto *context = &getContext();
- SmallVector<Type, 2> inferedReturnTypes;
- if (failed(retTypeFn.inferReturnTypes(
- context, op->getLoc(), op->getOperands(), op->getAttrs(),
- op->getRegions(), inferedReturnTypes)))
- return;
- SmallVector<Type, 1> resultTypes(op->getResultTypes());
- if (!retTypeFn.isCompatibleReturnTypes(inferedReturnTypes, resultTypes)) {
- op->emitOpError(
- "inferred type incompatible with return type of operation");
- return;
- }
- });
}
};
} // end anonymous namespace
// -----
-// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
// expected-error@+1 {{incompatible with return type}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
// -----
-// CHECK-LABEL: testReturnTypeOpInterface
func @testReturnTypeOpInterfaceMismatch(%arg0 : tensor<10xf32>, %arg1 : tensor<20xf32>) {
// expected-error@+1 {{operand type mismatch}}
%bad = "test.op_with_infer_type_if"(%arg0, %arg1) : (tensor<10xf32>, tensor<20xf32>) -> tensor<*xf32>
#include "DocGenUtilities.h"
#include "mlir/Support/STLExtras.h"
+#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/OpInterfaces.h"
#include "llvm/ADT/SmallVector.h"
// Insert the default implementation for any methods.
for (auto &method : interface.getMethods()) {
+ // Flag interface methods named verifyTrait.
+ if (method.getName() == "verifyTrait")
+ PrintFatalError(
+ formatv("'verifyTrait' method cannot be specified as interface "
+ "method for '{0}'; set 'verify' on OpInterfaceTrait instead",
+ interfaceName));
auto defaultImpl = method.getDefaultImplementation();
if (!defaultImpl)
continue;
os << " {\n" << defaultImpl.getValue() << " }\n";
}
+ tblgen::FmtContext traitCtx;
+ traitCtx.withOp("op");
+ if (auto verify = interface.getVerify()) {
+ os << " static LogicalResult verifyTrait(Operation* op) {\n"
+ << tblgen::tgfmt(*verify, &traitCtx) << "\n }\n";
+ }
+
os << " };\n";
}