[mlir] Cleanup DialectDocGen to check for the dialect early
authorRiver Riddle <riddleriver@gmail.com>
Wed, 12 Oct 2022 20:41:00 +0000 (13:41 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Wed, 12 Oct 2022 21:43:20 +0000 (14:43 -0700)
We only ever generate documentation for one dialect, so there
isn't a good reason to collect every possible dialect entity.

Differential Revision: https://reviews.llvm.org/D135812

mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt
mlir/tools/mlir-tblgen/OpDocGen.cpp

index 47efc6d..be40d7b 100644 (file)
@@ -1,2 +1,2 @@
 add_mlir_dialect(MemRefOps memref)
-add_mlir_doc(MemRefOps MemRefOps Dialects/ -gen-dialect-doc)
+add_mlir_doc(MemRefOps MemRefOps Dialects/ -gen-dialect-doc -dialect=memref)
index ed9e514..8ebcd33 100644 (file)
@@ -360,6 +360,13 @@ static void emitDialectDoc(const Dialect &dialect,
 }
 
 static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
+  std::vector<Record *> dialectDefs =
+      recordKeeper.getAllDerivedDefinitionsIfDefined("Dialect");
+  SmallVector<Dialect> dialects(dialectDefs.begin(), dialectDefs.end());
+  Optional<Dialect> dialect = findDialectToGenerate(dialects);
+  if (!dialect)
+    return true;
+
   std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
   std::vector<Record *> attrDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("DialectAttr");
@@ -370,61 +377,31 @@ static bool emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
   std::vector<Record *> attrDefDefs =
       recordKeeper.getAllDerivedDefinitionsIfDefined("AttrDef");
 
-  llvm::SetVector<Dialect, SmallVector<Dialect>, std::set<Dialect>>
-      dialectsWithDocs;
-
-  llvm::StringMap<std::vector<Attribute>> dialectAttrs;
-  llvm::StringMap<std::vector<AttrDef>> dialectAttrDefs;
-  llvm::StringMap<std::vector<Operator>> dialectOps;
-  llvm::StringMap<std::vector<Type>> dialectTypes;
-  llvm::StringMap<std::vector<TypeDef>> dialectTypeDefs;
+  std::vector<Attribute> dialectAttrs;
+  std::vector<AttrDef> dialectAttrDefs;
+  std::vector<Operator> dialectOps;
+  std::vector<Type> dialectTypes;
+  std::vector<TypeDef> dialectTypeDefs;
   llvm::SmallDenseSet<Record *> seen;
-  for (Record *attrDef : attrDefDefs) {
-    AttrDef attr(attrDef);
-    dialectAttrDefs[attr.getDialect().getName()].push_back(attr);
-    dialectsWithDocs.insert(attr.getDialect());
-    seen.insert(attrDef);
-  }
-  for (Record *attrDef : attrDefs) {
-    if (seen.count(attrDef))
-      continue;
-    Attribute attr(attrDef);
-    if (const Dialect &dialect = attr.getDialect()) {
-      dialectAttrs[dialect.getName()].push_back(attr);
-      dialectsWithDocs.insert(dialect);
-    }
-  }
-  for (Record *opDef : opDefs) {
-    Operator op(opDef);
-    dialectOps[op.getDialect().getName()].push_back(op);
-    dialectsWithDocs.insert(op.getDialect());
-  }
-  for (Record *typeDef : typeDefDefs) {
-    TypeDef type(typeDef);
-    dialectTypeDefs[type.getDialect().getName()].push_back(type);
-    dialectsWithDocs.insert(type.getDialect());
-    seen.insert(typeDef);
-  }
-  for (Record *typeDef : typeDefs) {
-    if (seen.count(typeDef))
-      continue;
-    Type type(typeDef);
-    if (const Dialect &dialect = type.getDialect()) {
-      dialectTypes[dialect.getName()].push_back(type);
-      dialectsWithDocs.insert(dialect);
-    }
-  }
-
-  Optional<Dialect> dialect =
-      findDialectToGenerate(dialectsWithDocs.getArrayRef());
-  if (!dialect)
-    return true;
+  auto addIfInDialect = [&](llvm::Record *record, const auto &def, auto &vec) {
+    if (seen.insert(record).second && def.getDialect() == *dialect)
+      vec.push_back(def);
+  };
+
+  for (Record *def : attrDefDefs)
+    addIfInDialect(def, AttrDef(def), dialectAttrDefs);
+  for (Record *def : attrDefs)
+    addIfInDialect(def, Attribute(def), dialectAttrs);
+  for (Record *def : opDefs)
+    addIfInDialect(def, Operator(def), dialectOps);
+  for (Record *def : typeDefDefs)
+    addIfInDialect(def, TypeDef(def), dialectTypeDefs);
+  for (Record *def : typeDefs)
+    addIfInDialect(def, Type(def), dialectTypes);
 
   os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
-  StringRef dialectName = dialect->getName();
-  emitDialectDoc(*dialect, dialectAttrs[dialectName],
-                 dialectAttrDefs[dialectName], dialectOps[dialectName],
-                 dialectTypes[dialectName], dialectTypeDefs[dialectName], os);
+  emitDialectDoc(*dialect, dialectAttrs, dialectAttrDefs, dialectOps,
+                 dialectTypes, dialectTypeDefs, os);
   return false;
 }