This enables providing a default implementation of an interface method. This method is defined on the Trait that is attached to the operation, and thus has all of the same constraints and properties as any other interface method. This allows for interface authors to provide a conservative default implementation for certain methods, without requiring that all users explicitly define it. The default implementation can be specified via the argument directly after the interface method body:
StaticInterfaceMethod<
/*desc=*/"Returns whether two array of types are compatible result types for an op.",
/*retTy=*/"bool",
/*methodName=*/"isCompatibleReturnTypes",
/*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
/*methodBody=*/[{
return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
}],
/*defaultImplementation=*/[{
/// Returns whether two arrays are equal as strongest check for
/// compatibility by default.
return lhs == rhs;
}]
PiperOrigin-RevId:
286226054
to the type of the derived operation currently being operated on.
- In non-static methods, a variable 'ConcreteOp op' is defined and may be
used to refer to an instance of the derived operation.
+* DefaultImplementation (Optional)
+ - An optional explicit default implementation of the interface method.
+ - This method is placed within the `Trait` class that is attached to the
+ operation. As such, this method has the same characteristics as any
+ other [`Trait`](Traits.md) method.
+ - `ConcreteOp` is an implicitly defined typename that can be used to refer
+ to the type of the derived operation currently being operated on.
ODS also allows generating the declarations for the `InterfaceMethod` of the op
if one specifies the interface with `DeclareOpInterfaceMethods` (see example
"unsigned", "getNumInputsAndOutputs", (ins), [{
return op.getNumInputs() + op.getNumOutputs();
}]>,
+
+ // Provide only a default definition of the method.
+ // Note: `ConcreteOp` corresponds to the derived operation typename.
+ InterfaceMethod<"/*insert doc here*/",
+ "unsigned", "getNumInputsAndOutputs", (ins), /*methodBody=*/[{}], [{
+ ConcreteOp op = cast<ConcreteOp>(getOperation());
+ return op.getNumInputs() + op.getNumOutputs();
+ }]>,
];
}
/*retTy=*/"bool",
/*methodName=*/"isCompatibleReturnTypes",
/*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
- [{
+ /*methodBody=*/[{
return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
+ }],
+ /*defaultImplementation=*/[{
+ /// Returns whether two arrays are equal as strongest check for
+ /// compatibility by default.
+ return lhs == rhs;
}]
>,
];
}
-// Default implementations for some of the interface methods above:
-// - compatibleReturnTypes returns whether strictly true.
-def InferTypeOpInterfaceDefault : NativeOpTrait<"TypeOpInterfaceDefault">;
-
#endif // MLIR_INFERTYPEOPINTERFACE
// Note: non-static interface methods have an implicit 'op' parameter
// corresponding to an instance of the derived operation.
class InterfaceMethod<string desc, string retTy, string methodName,
- dag args = (ins), code methodBody = [{}]> {
+ dag args = (ins), code methodBody = [{}],
+ code defaultImplementation = [{}]> {
// A human-readable description of what this method does.
string description = desc;
// An optional body to the method.
code body = methodBody;
+
+ // An optional default implementation of the method.
+ code defaultBody = defaultImplementation;
}
// This class represents a single static interface method.
class StaticInterfaceMethod<string desc, string retTy, string methodName,
- dag args = (ins), code methodBody = [{}]>
- : InterfaceMethod<desc, retTy, methodName, args, methodBody>;
+ dag args = (ins), code methodBody = [{}],
+ code defaultImplementation = [{}]>
+ : InterfaceMethod<desc, retTy, methodName, args, methodBody,
+ defaultImplementation>;
// OpInterface represents an interface regarding an op.
class OpInterface<string name> : OpInterfaceTrait<name> {
// Return the body for this method if it has one.
llvm::Optional<StringRef> getBody() const;
+ // Return the default implementation for this method if it has one.
+ llvm::Optional<StringRef> getDefaultImplementation() const;
+
// Return the description of this method if it has one.
llvm::Optional<StringRef> getDescription() const;
return value.empty() ? llvm::Optional<StringRef>() : value;
}
+// Return the default implementation for this method if it has one.
+llvm::Optional<StringRef> OpInterfaceMethod::getDefaultImplementation() const {
+ auto value = def->getValueAsString("defaultBody");
+ return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
// Return the description of this method if it has one.
llvm::Optional<StringRef> OpInterfaceMethod::getDescription() const {
auto value = def->getValueAsString("description");
}
def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
- DeclareOpInterfaceMethods<InferTypeOpInterface>,
- InferTypeOpInterfaceDefault]> {
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins AnyTensor, AnyTensor);
let results = (outs AnyTensor);
}
os << " };\n";
}
+static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
+ StringRef interfaceName,
+ StringRef interfaceTraitsName) {
+ os << " template <typename ConcreteOp>\n "
+ << llvm::formatv("struct Trait : public OpInterface<{0},"
+ " detail::{1}>::Trait<ConcreteOp> {{\n",
+ interfaceName, interfaceTraitsName);
+
+ // Insert the default implementation for any methods.
+ for (auto &method : interface.getMethods()) {
+ auto defaultImpl = method.getDefaultImplementation();
+ if (!defaultImpl)
+ continue;
+
+ os << " " << (method.isStatic() ? "static " : "") << method.getReturnType()
+ << " ";
+ emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
+ os << " {\n" << defaultImpl.getValue() << " }\n";
+ }
+
+ os << " };\n";
+}
+
static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
StringRef interfaceName = interface.getName();
auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
" using OpInterface<{1}, detail::{2}>::OpInterface;\n",
interfaceName, interfaceName, interfaceTraitsName);
+ // Emit the derived trait for the interface.
+ emitTraitDecl(interface, os, interfaceName, interfaceTraitsName);
+
// Insert the method declarations.
for (auto &method : interface.getMethods()) {
os << " " << method.getReturnType() << " ";