From a4f81b2054c30954c6739532b923f2b223bc7d77 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 16 Jun 2021 16:31:17 +0200 Subject: [PATCH] [mlir] ODS: emit interface traits outside of the interface class ODS currently emits the interface trait class as a nested class inside the interface class. As an unintended consequence, the default implementations of interface methods have implicit access to static fields of the interface class, e.g. those declared in `extraClassDeclaration`, including private methods (!), or in the parent class. This may break the use of default implementations for external models, which are not defined in the interface class, and generally complexifies the abstraction. Emit intraface traits outside of the interface class itself to avoid accidental implicit visibility. Public static fields can still be accessed via explicit qualification with a class name, e.g., `MyOpInterface::staticMethod()` instead of `staticMethod`. Update the documentation to clarify the role of `extraClassDeclaration` in interfaces. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D104384 --- mlir/docs/Interfaces.md | 5 ++++ .../mlir/Interfaces/ControlFlowInterfaces.td | 4 +-- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 35 ++++++++++------------ 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index 633e492..8e75146 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -384,6 +384,9 @@ comprised of the following components: - Additional C++ code that is generated in the declaration of the interface class. This allows for defining methods and more on the user facing interface class, that do not need to hook into the IR entity. + These declarations are _not_ implicitly visible in default + implementations of interface methods, but static declarations may be + accessed with full name qualification. `OpInterface` classes may additionally contain the following: @@ -430,6 +433,8 @@ Interface methods are comprised of the following components: - `ConcreteAttr`/`ConcreteOp`/`ConcreteType` is an implicitly defined `typename` that can be used to refer to the type of the derived IR entity currently being operated on. + - This may refer to static fields of the interface class using the + qualified name, e.g., `TestOpInterface::staticMethod()`. ODS also allows for generating declarations for the `InterfaceMethod`s of an operation if the operation specifies the interface with diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 3c568a0..bb69ed8 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -78,7 +78,7 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> { ]; let verify = [{ - auto concreteOp = cast($_op); + auto concreteOp = cast($_op); for (unsigned i = 0, e = $_op->getNumSuccessors(); i != e; ++i) { Optional operands = concreteOp.getSuccessorOperands(i); if (failed(detail::verifyBranchSuccessorOperands($_op, i, operands))) @@ -154,7 +154,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { ]; let verify = [{ - static_assert(!ConcreteOpType::template hasTrait(), + static_assert(!ConcreteOp::template hasTrait(), "expected operation to have non-zero regions"); return success(); }]; diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index d4b4cd1..a365b42 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -241,17 +241,10 @@ void InterfaceGenerator::emitModelDecl(Interface &interface) { os << " };\n"; } + // Emit the template for the external model. os << " template\n"; os << " class ExternalModel : public FallbackModel {\n"; - - // Emit the template for the external model if there are no extra class - // declarations. - if (interface.getExtraClassDeclaration()) { - os << " };\n"; - return; - } - os << " public:\n"; // Emit declarations for methods that have default implementations. Other @@ -345,9 +338,6 @@ void InterfaceGenerator::emitModelMethodsDef(Interface &interface) { } // Emit default implementations for the external model. - if (interface.getExtraClassDeclaration()) - return; - for (auto &method : interface.getMethods()) { if (!method.getDefaultImplementation()) continue; @@ -427,11 +417,6 @@ void InterfaceGenerator::emitTraitDecl(Interface &interface, os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; - - // Emit a utility wrapper trait class. - os << llvm::formatv(" template \n" - " struct Trait : public {0}Trait<{1}> {{};\n", - interfaceName, valueTemplate); } void InterfaceGenerator::emitInterfaceDecl(Interface interface) { @@ -452,7 +437,13 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { << "struct " << interfaceTraitsName << " {\n"; emitConceptDecl(interface); emitModelDecl(interface); - os << "};\n} // end namespace detail\n"; + os << "};"; + + // Emit the derived trait for the interface. + os << "template \n"; + os << "struct " << interface.getName() << "Trait;\n"; + + os << "\n} // end namespace detail\n"; // Emit the main interface class declaration. os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" @@ -461,8 +452,10 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); - // Emit the derived trait for the interface. - emitTraitDecl(interface, interfaceName, interfaceTraitsName); + // Emit a utility wrapper trait class. + os << llvm::formatv(" template \n" + " struct Trait : public detail::{0}Trait<{1}> {{};\n", + interfaceName, valueTemplate); // Insert the method declarations. bool isOpInterface = isa(interface); @@ -479,6 +472,10 @@ void InterfaceGenerator::emitInterfaceDecl(Interface interface) { os << "};\n"; + os << "namespace detail {\n"; + emitTraitDecl(interface, interfaceName, interfaceTraitsName); + os << "}// namespace detail\n"; + emitModelMethodsDef(interface); for (StringRef ns : llvm::reverse(namespaces)) -- 2.7.4