[mlir] Allow for using interface class name in ODS interface definitions
authorAlex Zinenko <zinenko@google.com>
Sun, 15 Nov 2020 16:49:37 +0000 (17:49 +0100)
committerAlex Zinenko <zinenko@google.com>
Tue, 17 Nov 2020 13:28:55 +0000 (14:28 +0100)
It may be necessary for interface methods to process or return variables with
the interface class type, in particular for attribute and type interfaces that
can return modified attributes and types that implement the same interface.
However, the code generated by ODS in this case would not compile because the
signature (and the body if provided) appear in the definition of the Model
class and before the interface class, which derives from the Model. Change the ODS
interface method generator to emit only method declarations in the Model class
itself, and emit method definitions after the interface class. Mark as "inline"
since their definitions are still emitted in the header and are no longer
implicitly inline. Add a forward declaration of the interface class before the
Concept+Model classes to make the class name usable in declarations.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D91499

mlir/test/lib/Dialect/Test/TestInterfaces.td
mlir/test/lib/IR/TestInterfaces.cpp
mlir/test/mlir-tblgen/interfaces.mlir
mlir/test/mlir-tblgen/op-interface.td
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

index 08f07cf..19a779d 100644 (file)
@@ -28,6 +28,15 @@ def TestTypeInterface : TypeInterface<"TestTypeInterface"> {
     InterfaceMethod<"Prints the type name.",
       "void", "printTypeC", (ins "Location":$loc)
     >,
+    // It should be possible to use the interface type name as result type
+    // as well as in the implementation.
+    InterfaceMethod<"Prints the type name and returns the type as interface.",
+      "TestTypeInterface", "printTypeRet", (ins "Location":$loc),
+      [{}], /*defaultImplementation=*/[{
+        emitRemark(loc) << $_type << " - TestRet";
+        return $_type;
+      }]
+    >,
   ];
   let extraClassDeclaration = [{
     /// Prints the type name.
index 369001f..3a6e10d 100644 (file)
@@ -25,6 +25,10 @@ struct TestTypeInterfaces
           testInterface.printTypeB(op->getLoc());
           testInterface.printTypeC(op->getLoc());
           testInterface.printTypeD(op->getLoc());
+          // Just check that we can assign the result to a variable of interface
+          // type.
+          TestTypeInterface result = testInterface.printTypeRet(op->getLoc());
+          (void)result;
         }
         if (auto testType = type.dyn_cast<TestType>())
           testType.printTypeE(op->getLoc());
index 712d934..5c1ec61 100644 (file)
@@ -4,6 +4,7 @@
 // expected-remark@below {{'!test.test_type' - TestB}}
 // expected-remark@below {{'!test.test_type' - TestC}}
 // expected-remark@below {{'!test.test_type' - TestD}}
+// expected-remark@below {{'!test.test_type' - TestRet}}
 // expected-remark@below {{'!test.test_type' - TestE}}
 %foo0 = "foo.test"() : () -> (!test.test_type)
 
index 4ca2b79..7f5ae6c 100644 (file)
@@ -41,9 +41,11 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
 
 // DECL-LABEL: TestOpInterfaceInterfaceTraits
 // DECL: class TestOpInterface : public ::mlir::OpInterface<TestOpInterface, detail::TestOpInterfaceInterfaceTraits>
+
 // DECL: int foo(int input);
 
-// DECL-NOT: TestOpInterface
+// DECL: template<typename ConcreteOp>
+// DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
 
 // OP_DECL-LABEL: class DeclareMethodsOp : public
 // OP_DECL: int foo(int input);
index 1b0cb27..1a8f6b7 100644 (file)
@@ -82,6 +82,7 @@ protected:
 
   void emitConceptDecl(Interface &interface);
   void emitModelDecl(Interface &interface);
+  void emitModelMethodsDef(Interface &interface);
   void emitTraitDecl(Interface &interface, StringRef interfaceName,
                      StringRef interfaceTraitsName);
   void emitInterfaceDecl(Interface interface);
@@ -217,11 +218,25 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) {
 
   // Insert each of the virtual method overrides.
   for (auto &method : interface.getMethods()) {
-    emitCPPType(method.getReturnType(), os << "    static ");
+    emitCPPType(method.getReturnType(), os << "    static inline ");
     emitMethodNameAndArgs(method, os, valueType,
                           /*addThisArg=*/!method.isStatic(),
                           /*addConst=*/false);
-    os << " {\n      ";
+    os << ";\n";
+  }
+  os << "  };\n";
+}
+
+void InterfaceGenerator::emitModelMethodsDef(Interface &interface) {
+  for (auto &method : interface.getMethods()) {
+    os << "template<typename " << valueTemplate << ">\n";
+    emitCPPType(method.getReturnType(), os);
+    os << "detail::" << interface.getName() << "InterfaceTraits::Model<"
+       << valueTemplate << ">::";
+    emitMethodNameAndArgs(method, os, valueType,
+                          /*addThisArg=*/!method.isStatic(),
+                          /*addConst=*/false);
+    os << " {\n  ";
 
     // Check for a provided body to the function.
     if (Optional<StringRef> body = method.getBody()) {
@@ -229,7 +244,7 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) {
         os << body->trim();
       else
         os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt);
-      os << "\n    }\n";
+      os << "\n}\n";
       continue;
     }
 
@@ -244,9 +259,8 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) {
     llvm::interleaveComma(
         method.getArguments(), os,
         [&](const InterfaceMethod::Argument &arg) { os << arg.name; });
-    os << ");\n    }\n";
+    os << ");\n}\n";
   }
-  os << "  };\n";
 }
 
 void InterfaceGenerator::emitTraitDecl(Interface &interface,
@@ -308,6 +322,10 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
   StringRef interfaceName = interface.getName();
   auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str();
 
+  // Emit a forward declaration of the interface class so that it becomes usable
+  // in the signature of its methods.
+  os << "class " << interfaceName << ";\n";
+
   // Emit the traits struct containing the concept and model declarations.
   os << "namespace detail {\n"
      << "struct " << interfaceTraitsName << " {\n";
@@ -340,6 +358,8 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) {
 
   os << "};\n";
 
+  emitModelMethodsDef(interface);
+
   for (StringRef ns : llvm::reverse(namespaces))
     os << "} // namespace " << ns << "\n";
 }