From 3e731af9124cc74d2276da19031e6dd275a7c33f Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 15 Dec 2022 14:02:06 -0800 Subject: [PATCH] [mlir] Limit Interface generation to the top-level input file There are very few instances in which we use multiple files for interface definitions (none upstream), and this allows for including interfaces that shouldn't be generated (for interface inheritance, dependencies, etc.) Differential Revision: https://reviews.llvm.org/D140196 --- mlir/docs/Interfaces.md | 7 +++++++ mlir/tools/mlir-tblgen/OpInterfacesGen.cpp | 25 +++++++++++++++---------- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index b6007d0..9482c5a 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -634,6 +634,13 @@ def OpWithOverrideInferTypeInterfaceOp : Op<... [DeclareOpInterfaceMethods]> { ... } ``` +Once the interfaces have been defined, the C++ header and source files can be +generated using the `--gen--interface-decls` and +`--gen--interface-defs` options with mlir-tblgen. Note that when +generating interfaces, mlir-tblgen will only generate interfaces defined in +the top-level input `.td` file. This means that any interfaces that are +defined within include files will not be considered for generation. + Note: Existing operation interfaces defined in C++ can be accessed in the ODS framework via the `OpInterfaceTrait` class. diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index c832d0a..9e84d19 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -62,12 +62,19 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method, /// Get an array of all OpInterface definitions but exclude those subclassing /// "DeclareOpInterfaceMethods". static std::vector -getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) { +getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, + StringRef name) { std::vector defs = - recordKeeper.getAllDerivedDefinitions("OpInterface"); - - llvm::erase_if(defs, [](const llvm::Record *def) { - return def->isSubClassOf("DeclareOpInterfaceMethods"); + recordKeeper.getAllDerivedDefinitions((name + "Interface").str()); + + std::string declareName = ("Declare" + name + "InterfaceMethods").str(); + llvm::erase_if(defs, [&](const llvm::Record *def) { + // Ignore any "declare methods" interfaces. + if (def->isSubClassOf(declareName)) + return true; + // Ignore interfaces defined outside of the top-level file. + return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != + llvm::SrcMgr.getMainFileID(); }); return defs; } @@ -110,8 +117,7 @@ protected: /// A specialized generator for attribute interfaces. struct AttrInterfaceGenerator : public InterfaceGenerator { AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"), - os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { valueType = "::mlir::Attribute"; interfaceBaseType = "AttributeInterface"; valueTemplate = "ConcreteAttr"; @@ -125,7 +131,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator { /// A specialized generator for operation interfaces. struct OpInterfaceGenerator : public InterfaceGenerator { OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { valueType = "::mlir::Operation *"; interfaceBaseType = "OpInterface"; valueTemplate = "ConcreteOp"; @@ -140,8 +146,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator { /// A specialized generator for type interfaces. struct TypeInterfaceGenerator : public InterfaceGenerator { TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"), - os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { valueType = "::mlir::Type"; interfaceBaseType = "TypeInterface"; valueTemplate = "ConcreteType"; -- 2.7.4