From: River Riddle Date: Thu, 15 Dec 2022 22:02:06 +0000 (-0800) Subject: [mlir] Limit Interface generation to the top-level input file X-Git-Tag: upstream/17.0.6~20490 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=3e731af9124cc74d2276da19031e6dd275a7c33f;p=platform%2Fupstream%2Fllvm.git [mlir] Limit Interface generation to the top-level input file There are very few instances in which we use multiple files for interface definitions (none upstream), and this allows for including interfaces that shouldn't be generated (for interface inheritance, dependencies, etc.) Differential Revision: https://reviews.llvm.org/D140196 --- diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index b6007d0..9482c5a 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -634,6 +634,13 @@ def OpWithOverrideInferTypeInterfaceOp : Op<... [DeclareOpInterfaceMethods]> { ... } ``` +Once the interfaces have been defined, the C++ header and source files can be +generated using the `--gen--interface-decls` and +`--gen--interface-defs` options with mlir-tblgen. Note that when +generating interfaces, mlir-tblgen will only generate interfaces defined in +the top-level input `.td` file. This means that any interfaces that are +defined within include files will not be considered for generation. + Note: Existing operation interfaces defined in C++ can be accessed in the ODS framework via the `OpInterfaceTrait` class. diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index c832d0a..9e84d19 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -62,12 +62,19 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method, /// Get an array of all OpInterface definitions but exclude those subclassing /// "DeclareOpInterfaceMethods". static std::vector -getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) { +getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper, + StringRef name) { std::vector defs = - recordKeeper.getAllDerivedDefinitions("OpInterface"); - - llvm::erase_if(defs, [](const llvm::Record *def) { - return def->isSubClassOf("DeclareOpInterfaceMethods"); + recordKeeper.getAllDerivedDefinitions((name + "Interface").str()); + + std::string declareName = ("Declare" + name + "InterfaceMethods").str(); + llvm::erase_if(defs, [&](const llvm::Record *def) { + // Ignore any "declare methods" interfaces. + if (def->isSubClassOf(declareName)) + return true; + // Ignore interfaces defined outside of the top-level file. + return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) != + llvm::SrcMgr.getMainFileID(); }); return defs; } @@ -110,8 +117,7 @@ protected: /// A specialized generator for attribute interfaces. struct AttrInterfaceGenerator : public InterfaceGenerator { AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"), - os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) { valueType = "::mlir::Attribute"; interfaceBaseType = "AttributeInterface"; valueTemplate = "ConcreteAttr"; @@ -125,7 +131,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator { /// A specialized generator for operation interfaces. struct OpInterfaceGenerator : public InterfaceGenerator { OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) { valueType = "::mlir::Operation *"; interfaceBaseType = "OpInterface"; valueTemplate = "ConcreteOp"; @@ -140,8 +146,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator { /// A specialized generator for type interfaces. struct TypeInterfaceGenerator : public InterfaceGenerator { TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"), - os) { + : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) { valueType = "::mlir::Type"; interfaceBaseType = "TypeInterface"; valueTemplate = "ConcreteType";