Add support for providing a default implementation for an interface method.
authorRiver Riddle <riverriddle@google.com>
Wed, 18 Dec 2019 19:02:35 +0000 (11:02 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 18 Dec 2019 19:09:11 +0000 (11:09 -0800)
This enables providing a default implementation of an interface method. This method is defined on the Trait that is attached to the operation, and thus has all of the same constraints and properties as any other interface method. This allows for interface authors to provide a conservative default implementation for certain methods, without requiring that all users explicitly define it. The default implementation can be specified via the argument directly after the interface method body:

  StaticInterfaceMethod<
    /*desc=*/"Returns whether two array of types are compatible result types for an op.",
    /*retTy=*/"bool",
    /*methodName=*/"isCompatibleReturnTypes",
    /*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
    /*methodBody=*/[{
      return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
    }],
    /*defaultImplementation=*/[{
      /// Returns whether two arrays are equal as strongest check for
      /// compatibility by default.
      return lhs == rhs;
    }]

PiperOrigin-RevId: 286226054

mlir/g3doc/OpDefinitions.md
mlir/include/mlir/Analysis/InferTypeOpInterface.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/OpInterfaces.h
mlir/lib/TableGen/OpInterfaces.cpp
mlir/test/lib/TestDialect/TestOps.td
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

index 0c00135..1f98671 100644 (file)
@@ -332,6 +332,13 @@ An `InterfaceMethod` is comprised of the following components:
         to the type of the derived operation currently being operated on.
     -   In non-static methods, a variable 'ConcreteOp op' is defined and may be
         used to refer to an instance of the derived operation.
+*   DefaultImplementation (Optional)
+    -   An optional explicit default implementation of the interface method.
+    -   This method is placed within the `Trait` class that is attached to the
+        operation. As such, this method has the same characteristics as any
+        other [`Trait`](Traits.md) method.
+    -   `ConcreteOp` is an implicitly defined typename that can be used to refer
+        to the type of the derived operation currently being operated on.
 
 ODS also allows generating the declarations for the `InterfaceMethod` of the op
 if one specifies the interface with `DeclareOpInterfaceMethods` (see example
@@ -374,6 +381,14 @@ def MyInterface : OpInterface<"MyInterface"> {
       "unsigned", "getNumInputsAndOutputs", (ins), [{
         return op.getNumInputs() + op.getNumOutputs();
     }]>,
+
+    // Provide only a default definition of the method.
+    // Note: `ConcreteOp` corresponds to the derived operation typename.
+    InterfaceMethod<"/*insert doc here*/",
+      "unsigned", "getNumInputsAndOutputs", (ins), /*methodBody=*/[{}], [{
+        ConcreteOp op = cast<ConcreteOp>(getOperation());
+        return op.getNumInputs() + op.getNumOutputs();
+    }]>,
   ];
 }
 
index aae6e83..14d5809 100644 (file)
@@ -59,15 +59,16 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
       /*retTy=*/"bool",
       /*methodName=*/"isCompatibleReturnTypes",
       /*args=*/(ins "ArrayRef<Type>":$lhs, "ArrayRef<Type>":$rhs),
-      [{
+      /*methodBody=*/[{
         return ConcreteOp::isCompatibleReturnTypes(lhs, rhs);
+      }],
+      /*defaultImplementation=*/[{
+        /// Returns whether two arrays are equal as strongest check for
+        /// compatibility by default.
+        return lhs == rhs;
       }]
     >,
   ];
 }
 
-// Default implementations for some of the interface methods above:
-// - compatibleReturnTypes returns whether strictly true.
-def InferTypeOpInterfaceDefault : NativeOpTrait<"TypeOpInterfaceDefault">;
-
 #endif // MLIR_INFERTYPEOPINTERFACE
index 4d5d1fe..8f6770f 100644 (file)
@@ -1425,7 +1425,8 @@ class OpInterfaceTrait<string name> : NativeOpTrait<""> {
 // Note: non-static interface methods have an implicit 'op' parameter
 // corresponding to an instance of the derived operation.
 class InterfaceMethod<string desc, string retTy, string methodName,
-                      dag args = (ins), code methodBody = [{}]> {
+                      dag args = (ins), code methodBody = [{}],
+                      code defaultImplementation = [{}]> {
   // A human-readable description of what this method does.
   string description = desc;
 
@@ -1440,12 +1441,17 @@ class InterfaceMethod<string desc, string retTy, string methodName,
 
   // An optional body to the method.
   code body = methodBody;
+
+  // An optional default implementation of the method.
+  code defaultBody = defaultImplementation;
 }
 
 // This class represents a single static interface method.
 class StaticInterfaceMethod<string desc, string retTy, string methodName,
-                            dag args = (ins), code methodBody = [{}]>
-    : InterfaceMethod<desc, retTy, methodName, args, methodBody>;
+                            dag args = (ins), code methodBody = [{}],
+                            code defaultImplementation = [{}]>
+    : InterfaceMethod<desc, retTy, methodName, args, methodBody,
+                      defaultImplementation>;
 
 // OpInterface represents an interface regarding an op.
 class OpInterface<string name> : OpInterfaceTrait<name> {
index 4a87876..0959f6b 100644 (file)
@@ -58,6 +58,9 @@ public:
   // Return the body for this method if it has one.
   llvm::Optional<StringRef> getBody() const;
 
+  // Return the default implementation for this method if it has one.
+  llvm::Optional<StringRef> getDefaultImplementation() const;
+
   // Return the description of this method if it has one.
   llvm::Optional<StringRef> getDescription() const;
 
index e4e80e0..1687f3a 100644 (file)
@@ -57,6 +57,12 @@ llvm::Optional<StringRef> OpInterfaceMethod::getBody() const {
   return value.empty() ? llvm::Optional<StringRef>() : value;
 }
 
+// Return the default implementation for this method if it has one.
+llvm::Optional<StringRef> OpInterfaceMethod::getDefaultImplementation() const {
+  auto value = def->getValueAsString("defaultBody");
+  return value.empty() ? llvm::Optional<StringRef>() : value;
+}
+
 // Return the description of this method if it has one.
 llvm::Optional<StringRef> OpInterfaceMethod::getDescription() const {
   auto value = def->getValueAsString("description");
index cf15d0e..e33d9c2 100644 (file)
@@ -403,8 +403,7 @@ def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
 }
 
 def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
-    DeclareOpInterfaceMethods<InferTypeOpInterface>,
-    InferTypeOpInterfaceDefault]> {
+    DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
   let arguments = (ins AnyTensor, AnyTensor);
   let results = (outs AnyTensor);
 }
index 4c22d62..a48bd25 100644 (file)
@@ -151,6 +151,29 @@ static void emitModelDecl(OpInterface &interface, raw_ostream &os) {
   os << "  };\n";
 }
 
+static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
+                          StringRef interfaceName,
+                          StringRef interfaceTraitsName) {
+  os << "  template <typename ConcreteOp>\n  "
+     << llvm::formatv("struct Trait : public OpInterface<{0},"
+                      " detail::{1}>::Trait<ConcreteOp> {{\n",
+                      interfaceName, interfaceTraitsName);
+
+  // Insert the default implementation for any methods.
+  for (auto &method : interface.getMethods()) {
+    auto defaultImpl = method.getDefaultImplementation();
+    if (!defaultImpl)
+      continue;
+
+    os << "  " << (method.isStatic() ? "static " : "") << method.getReturnType()
+       << " ";
+    emitMethodNameAndArgs(method, os, /*addOperationArg=*/false);
+    os << " {\n" << defaultImpl.getValue() << "  }\n";
+  }
+
+  os << "  };\n";
+}
+
 static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
   StringRef interfaceName = interface.getName();
   auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
@@ -168,6 +191,9 @@ static void emitInterfaceDecl(OpInterface &interface, raw_ostream &os) {
                       "  using OpInterface<{1}, detail::{2}>::OpInterface;\n",
                       interfaceName, interfaceName, interfaceTraitsName);
 
+  // Emit the derived trait for the interface.
+  emitTraitDecl(interface, os, interfaceName, interfaceTraitsName);
+
   // Insert the method declarations.
   for (auto &method : interface.getMethods()) {
     os << "  " << method.getReturnType() << " ";