[TableGen] Add overload of RecordKeeper::getAllDerivedDefinitions()
authorPaul C. Anagnostopoulos <paul@windfall.com>
Sun, 4 Oct 2020 18:48:44 +0000 (14:48 -0400)
committerPaul C. Anagnostopoulos <paul@windfall.com>
Mon, 12 Oct 2020 20:40:09 +0000 (16:40 -0400)
  and use in PseudoLowering backend.
Now the two getAllDerivedDefinitions() use StringRef and Arrayref.
Use all_of() in getAllDerivedDefinitions().

llvm/docs/TableGen/BackGuide.rst
llvm/include/llvm/TableGen/Record.h
llvm/lib/TableGen/Record.cpp
llvm/utils/TableGen/PseudoLoweringEmitter.cpp

index 4ee5453..829f5c8 100644 (file)
@@ -569,9 +569,9 @@ The ``RecordKeeper`` class provides four functions for getting the
   ``Record`` references for the concrete records that derive from the
   given class.
 
-* ``getAllDerivedDefinitionsTwo(``\ *classname1*\ ``,`` *classname2*\ ``)`` returns
+* ``getAllDerivedDefinitions(``\ *classnames*\ ``)`` returns
   a vector of ``Record`` references for the concrete records that derive from
-  *both* of the given classes. [function to come]
+  *all* of the given classes.
 
 This statement obtains all the records that derive from the ``Attribute``
 class and iterates over them.
index 2a02093..c7009e4 100644 (file)
@@ -1784,9 +1784,17 @@ public:
   //===--------------------------------------------------------------------===//
   // High-level helper methods, useful for tablegen backends.
 
-  /// Get all the concrete records that inherit from the specified
+  /// Get all the concrete records that inherit from all the specified
+  /// classes. The classes must be defined.
+  std::vector<Record *> getAllDerivedDefinitions(
+      const ArrayRef<StringRef> ClassNames) const;
+
+  /// Get all the concrete records that inherit from the one specified
   /// class. The class must be defined.
-  std::vector<Record *> getAllDerivedDefinitions(StringRef ClassName) const;
+  std::vector<Record *> getAllDerivedDefinitions(StringRef ClassName) const {
+
+    return getAllDerivedDefinitions(makeArrayRef(ClassName));
+  }
 
   void dump() const;
 };
index 260cca6..2a46449 100644 (file)
@@ -2470,16 +2470,25 @@ Init *RecordKeeper::getNewAnonymousName() {
   return StringInit::get("anonymous_" + utostr(AnonCounter++));
 }
 
-std::vector<Record *>
-RecordKeeper::getAllDerivedDefinitions(StringRef ClassName) const {
-  Record *Class = getClass(ClassName);
-  if (!Class)
-    PrintFatalError("ERROR: Couldn't find the `" + ClassName + "' class!\n");
+std::vector<Record *> RecordKeeper::getAllDerivedDefinitions(
+    const ArrayRef<StringRef> ClassNames) const {
+  SmallVector<Record *, 2> ClassRecs;
+  std::vector<Record *> Defs;
 
-  std::vector<Record*> Defs;
-  for (const auto &D : getDefs())
-    if (D.second->isSubClassOf(Class))
-      Defs.push_back(D.second.get());
+  assert(ClassNames.size() > 0 && "At least one class must be passed.");
+  for (const auto &ClassName : ClassNames) {
+    Record *Class = getClass(ClassName);
+    if (!Class)
+      PrintFatalError("The class '" + ClassName + "' is not defined\n");
+    ClassRecs.push_back(Class);
+  }
+
+  for (const auto &OneDef : getDefs()) {
+    if (all_of(ClassRecs, [&OneDef](const Record *Class) {
+                            return OneDef.second->isSubClassOf(Class);
+                          }))
+      Defs.push_back(OneDef.second.get());
+  }
 
   return Defs;
 }
index 1f3f93d..0200e86 100644 (file)
@@ -293,17 +293,9 @@ void PseudoLoweringEmitter::emitLoweringEmitter(raw_ostream &o) {
 }
 
 void PseudoLoweringEmitter::run(raw_ostream &o) {
-  Record *ExpansionClass = Records.getClass("PseudoInstExpansion");
-  Record *InstructionClass = Records.getClass("Instruction");
-  assert(ExpansionClass && "PseudoInstExpansion class definition missing!");
-  assert(InstructionClass && "Instruction class definition missing!");
-
-  std::vector<Record*> Insts;
-  for (const auto &D : Records.getDefs()) {
-    if (D.second->isSubClassOf(ExpansionClass) &&
-        D.second->isSubClassOf(InstructionClass))
-      Insts.push_back(D.second.get());
-  }
+  StringRef Classes[] = {"PseudoInstExpansion", "Instruction"};
+  std::vector<Record *> Insts =
+      Records.getAllDerivedDefinitions(makeArrayRef(Classes));
 
   // Process the pseudo expansion definitions, validating them as we do so.
   for (unsigned i = 0, e = Insts.size(); i != e; ++i)