From: Alex Zinenko Date: Sun, 15 Nov 2020 16:49:37 +0000 (+0100) Subject: [mlir] Allow for using interface class name in ODS interface definitions X-Git-Tag: llvmorg-13-init~5892 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=88f25bda1376b68631106c0e1c5cbe3f385204e0;p=platform%2Fupstream%2Fllvm.git [mlir] Allow for using interface class name in ODS interface definitions It may be necessary for interface methods to process or return variables with the interface class type, in particular for attribute and type interfaces that can return modified attributes and types that implement the same interface. However, the code generated by ODS in this case would not compile because the signature (and the body if provided) appear in the definition of the Model class and before the interface class, which derives from the Model. Change the ODS interface method generator to emit only method declarations in the Model class itself, and emit method definitions after the interface class. Mark as "inline" since their definitions are still emitted in the header and are no longer implicitly inline. Add a forward declaration of the interface class before the Concept+Model classes to make the class name usable in declarations. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D91499 --- diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.td b/mlir/test/lib/Dialect/Test/TestInterfaces.td index 08f07cf..19a779d 100644 --- a/mlir/test/lib/Dialect/Test/TestInterfaces.td +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.td @@ -28,6 +28,15 @@ def TestTypeInterface : TypeInterface<"TestTypeInterface"> { InterfaceMethod<"Prints the type name.", "void", "printTypeC", (ins "Location":$loc) >, + // It should be possible to use the interface type name as result type + // as well as in the implementation. + InterfaceMethod<"Prints the type name and returns the type as interface.", + "TestTypeInterface", "printTypeRet", (ins "Location":$loc), + [{}], /*defaultImplementation=*/[{ + emitRemark(loc) << $_type << " - TestRet"; + return $_type; + }] + >, ]; let extraClassDeclaration = [{ /// Prints the type name. diff --git a/mlir/test/lib/IR/TestInterfaces.cpp b/mlir/test/lib/IR/TestInterfaces.cpp index 369001f..3a6e10d 100644 --- a/mlir/test/lib/IR/TestInterfaces.cpp +++ b/mlir/test/lib/IR/TestInterfaces.cpp @@ -25,6 +25,10 @@ struct TestTypeInterfaces testInterface.printTypeB(op->getLoc()); testInterface.printTypeC(op->getLoc()); testInterface.printTypeD(op->getLoc()); + // Just check that we can assign the result to a variable of interface + // type. + TestTypeInterface result = testInterface.printTypeRet(op->getLoc()); + (void)result; } if (auto testType = type.dyn_cast()) testType.printTypeE(op->getLoc()); diff --git a/mlir/test/mlir-tblgen/interfaces.mlir b/mlir/test/mlir-tblgen/interfaces.mlir index 712d934..5c1ec61 100644 --- a/mlir/test/mlir-tblgen/interfaces.mlir +++ b/mlir/test/mlir-tblgen/interfaces.mlir @@ -4,6 +4,7 @@ // expected-remark@below {{'!test.test_type' - TestB}} // expected-remark@below {{'!test.test_type' - TestC}} // expected-remark@below {{'!test.test_type' - TestD}} +// expected-remark@below {{'!test.test_type' - TestRet}} // expected-remark@below {{'!test.test_type' - TestE}} %foo0 = "foo.test"() : () -> (!test.test_type) diff --git a/mlir/test/mlir-tblgen/op-interface.td b/mlir/test/mlir-tblgen/op-interface.td index 4ca2b79..7f5ae6c 100644 --- a/mlir/test/mlir-tblgen/op-interface.td +++ b/mlir/test/mlir-tblgen/op-interface.td @@ -41,9 +41,11 @@ def DeclareMethodsWithDefaultOp : Op + // DECL: int foo(int input); -// DECL-NOT: TestOpInterface +// DECL: template +// DECL: int detail::TestOpInterfaceInterfaceTraits::Model::foo // OP_DECL-LABEL: class DeclareMethodsOp : public // OP_DECL: int foo(int input); diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 1b0cb27..1a8f6b7 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -82,6 +82,7 @@ protected: void emitConceptDecl(Interface &interface); void emitModelDecl(Interface &interface); + void emitModelMethodsDef(Interface &interface); void emitTraitDecl(Interface &interface, StringRef interfaceName, StringRef interfaceTraitsName); void emitInterfaceDecl(Interface interface); @@ -217,11 +218,25 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) { // Insert each of the virtual method overrides. for (auto &method : interface.getMethods()) { - emitCPPType(method.getReturnType(), os << " static "); + emitCPPType(method.getReturnType(), os << " static inline "); emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); - os << " {\n "; + os << ";\n"; + } + os << " };\n"; +} + +void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { + for (auto &method : interface.getMethods()) { + os << "template\n"; + emitCPPType(method.getReturnType(), os); + os << "detail::" << interface.getName() << "InterfaceTraits::Model<" + << valueTemplate << ">::"; + emitMethodNameAndArgs(method, os, valueType, + /*addThisArg=*/!method.isStatic(), + /*addConst=*/false); + os << " {\n "; // Check for a provided body to the function. if (Optional body = method.getBody()) { @@ -229,7 +244,7 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) { os << body->trim(); else os << tblgen::tgfmt(body->trim(), &nonStaticMethodFmt); - os << "\n }\n"; + os << "\n}\n"; continue; } @@ -244,9 +259,8 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) { llvm::interleaveComma( method.getArguments(), os, [&](const InterfaceMethod::Argument &arg) { os << arg.name; }); - os << ");\n }\n"; + os << ");\n}\n"; } - os << " };\n"; } void InterfaceGenerator::emitTraitDecl(Interface &interface, @@ -308,6 +322,10 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { StringRef interfaceName = interface.getName(); auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); + // Emit a forward declaration of the interface class so that it becomes usable + // in the signature of its methods. + os << "class " << interfaceName << ";\n"; + // Emit the traits struct containing the concept and model declarations. os << "namespace detail {\n" << "struct " << interfaceTraitsName << " {\n"; @@ -340,6 +358,8 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { os << "};\n"; + emitModelMethodsDef(interface); + for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; }