[mlir] ODS: emit interface model method at the end of the header
authorAlex Zinenko <zinenko@google.com>
Fri, 21 Oct 2022 00:53:05 +0000 (00:53 +0000)
committerDiego Caballero <diegocaballero@google.com>
Thu, 27 Oct 2022 22:54:19 +0000 (22:54 +0000)
Previously, ODS interface generator was placing implementations of the
interface's internal "Model" class template immediately after the class
definitions in the header. This doesn't allow this implementation, and
consequently the interface itself, to return an instance of another
interface if its class definition is emitted below. This creates
undesired ordering effects and makes it impossible for two or more
interfaces to return instances of each other. Change the interface
generator to place the implementations of these methods after all
interface classes.

Reviewed By: dcaballe

Differential Revision: https://reviews.llvm.org/D136322

mlir/test/mlir-tblgen/op-interface.td
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

index 901d9ea..8ac0d6d 100644 (file)
@@ -76,15 +76,16 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
 // DECL: /// some function comment
 // DECL: int foo(int input);
 
-// DECL: template<typename ConcreteOp>
-// DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
-
 // DECL-LABEL: struct TestOpInterfaceVerifyTrait
 // DECL: verifyTrait
 
 // DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait
 // DECL: verifyRegionTrait
 
+// Method implementations come last, after all class definitions.
+// DECL: template<typename ConcreteOp>
+// DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
+
 // OP_DECL-LABEL: class DeclareMethodsOp : public
 // OP_DECL: int foo(int input);
 // OP_DECL-NOT: int default_foo(int input);
index 340759d..3abd828 100644 (file)
@@ -289,6 +289,11 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) {
 }
 
 void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
+  llvm::SmallVector<StringRef, 2> namespaces;
+  llvm::SplitString(interface.getCppNamespace(), namespaces, "::");
+  for (StringRef ns : namespaces)
+    os << "namespace " << ns << " {\n";
+
   for (auto &method : interface.getMethods()) {
     os << "template<typename " << valueTemplate << ">\n";
     emitCPPType(method.getReturnType(), os);
@@ -384,6 +389,9 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) {
                         method.isStatic() ? &ctx : &nonStaticMethodFmt);
     os << "\n}\n";
   }
+
+  for (StringRef ns : llvm::reverse(namespaces))
+    os << "} // namespace " << ns << "\n";
 }
 
 void InterfaceGenerator::emitTraitDecl(const Interface &interface,
@@ -498,8 +506,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
   emitTraitDecl(interface, interfaceName, interfaceTraitsName);
   os << "}// namespace detail\n";
 
-  emitModelMethodsDef(interface);
-
   for (StringRef ns : llvm::reverse(namespaces))
     os << "} // namespace " << ns << "\n";
 }
@@ -507,8 +513,10 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) {
 bool InterfaceGenerator::emitInterfaceDecls() {
   llvm::emitSourceFileHeader("Interface Declarations", os);
 
-  for (const auto *def : defs)
+  for (const llvm::Record *def : defs)
     emitInterfaceDecl(Interface(def));
+  for (const llvm::Record *def : defs)
+    emitModelMethodsDef(Interface(def));
   return false;
 }