[mlir] Limit Interface generation to the top-level input file
authorRiver Riddle <riddleriver@gmail.com>
Thu, 15 Dec 2022 22:02:06 +0000 (14:02 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Thu, 19 Jan 2023 03:16:30 +0000 (19:16 -0800)
There are very few instances in which we use multiple files
for interface definitions (none upstream), and this allows
for including interfaces that shouldn't be generated (for
interface inheritance, dependencies, etc.)

Differential Revision: https://reviews.llvm.org/D140196

mlir/docs/Interfaces.md
mlir/tools/mlir-tblgen/OpInterfacesGen.cpp

index b6007d0..9482c5a 100644 (file)
@@ -634,6 +634,13 @@ def OpWithOverrideInferTypeInterfaceOp : Op<...
     [DeclareOpInterfaceMethods<MyInterface, ["getNumWithDefault"]>]> { ... }
 ```
 
+Once the interfaces have been defined, the C++ header and source files can be
+generated using the `--gen-<attr|op|type>-interface-decls` and
+`--gen-<attr|op|type>-interface-defs` options with mlir-tblgen. Note that when
+generating interfaces, mlir-tblgen will only generate interfaces defined in
+the top-level input `.td` file. This means that any interfaces that are
+defined within include files will not be considered for generation.
+
 Note: Existing operation interfaces defined in C++ can be accessed in the ODS
 framework via the `OpInterfaceTrait` class.
 
index c832d0a..9e84d19 100644 (file)
@@ -62,12 +62,19 @@ static void emitMethodNameAndArgs(const InterfaceMethod &method,
 /// Get an array of all OpInterface definitions but exclude those subclassing
 /// "DeclareOpInterfaceMethods".
 static std::vector<llvm::Record *>
-getAllOpInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper) {
+getAllInterfaceDefinitions(const llvm::RecordKeeper &recordKeeper,
+                           StringRef name) {
   std::vector<llvm::Record *> defs =
-      recordKeeper.getAllDerivedDefinitions("OpInterface");
-
-  llvm::erase_if(defs, [](const llvm::Record *def) {
-    return def->isSubClassOf("DeclareOpInterfaceMethods");
+      recordKeeper.getAllDerivedDefinitions((name + "Interface").str());
+
+  std::string declareName = ("Declare" + name + "InterfaceMethods").str();
+  llvm::erase_if(defs, [&](const llvm::Record *def) {
+    // Ignore any "declare methods" interfaces.
+    if (def->isSubClassOf(declareName))
+      return true;
+    // Ignore interfaces defined outside of the top-level file.
+    return llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
+           llvm::SrcMgr.getMainFileID();
   });
   return defs;
 }
@@ -110,8 +117,7 @@ protected:
 /// A specialized generator for attribute interfaces.
 struct AttrInterfaceGenerator : public InterfaceGenerator {
   AttrInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
-      : InterfaceGenerator(records.getAllDerivedDefinitions("AttrInterface"),
-                           os) {
+      : InterfaceGenerator(getAllInterfaceDefinitions(records, "Attr"), os) {
     valueType = "::mlir::Attribute";
     interfaceBaseType = "AttributeInterface";
     valueTemplate = "ConcreteAttr";
@@ -125,7 +131,7 @@ struct AttrInterfaceGenerator : public InterfaceGenerator {
 /// A specialized generator for operation interfaces.
 struct OpInterfaceGenerator : public InterfaceGenerator {
   OpInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
-      : InterfaceGenerator(getAllOpInterfaceDefinitions(records), os) {
+      : InterfaceGenerator(getAllInterfaceDefinitions(records, "Op"), os) {
     valueType = "::mlir::Operation *";
     interfaceBaseType = "OpInterface";
     valueTemplate = "ConcreteOp";
@@ -140,8 +146,7 @@ struct OpInterfaceGenerator : public InterfaceGenerator {
 /// A specialized generator for type interfaces.
 struct TypeInterfaceGenerator : public InterfaceGenerator {
   TypeInterfaceGenerator(const llvm::RecordKeeper &records, raw_ostream &os)
-      : InterfaceGenerator(records.getAllDerivedDefinitions("TypeInterface"),
-                           os) {
+      : InterfaceGenerator(getAllInterfaceDefinitions(records, "Type"), os) {
     valueType = "::mlir::Type";
     interfaceBaseType = "TypeInterface";
     valueTemplate = "ConcreteType";