[TableGen] Add a !listremove() bang operator
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 9 Dec 2022 15:03:11 +0000 (15:03 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 9 Dec 2022 15:03:18 +0000 (15:03 +0000)
This patch proposes to add a !listremove() bang operator to allow us to prune list entries by removing any entries from the first list arg that are also contained in the second list arg.

The particular use case I have in mind is for improved analysis of x86 scheduler models for which I'm hoping to start using the CodeGenProcModel 'Unsupported' features lists, which lists the ISA features a particular model DOESN'T support - with such a diverse and growing list of x86 ISAs, I don't want to have to update all these lists with every ISA change to every model - so I'm intending to keep a single central list of all x86 features, and then have the each model "remove" the features that it supports via a !listremove() - leaving just the unsupported ones.

Differential Revision: https://reviews.llvm.org/D139642

llvm/docs/TableGen/ProgRef.rst
llvm/include/llvm/TableGen/Record.h
llvm/lib/TableGen/Record.cpp
llvm/lib/TableGen/TGLexer.cpp
llvm/lib/TableGen/TGLexer.h
llvm/lib/TableGen/TGParser.cpp
llvm/test/TableGen/listremove.td [new file with mode: 0644]
llvm/utils/kate/llvm-tablegen.xml

index 3ced39e..16d24a9 100644 (file)
@@ -223,10 +223,10 @@ TableGen provides "bang operators" that have a wide variety of uses:
                : !div        !empty       !eq          !filter      !find
                : !foldl      !foreach     !ge          !getdagop    !gt
                : !head       !if          !interleave  !isa         !le
-               : !listconcat !listsplat   !logtwo      !lt          !mul
-               : !ne         !not         !or          !setdagop    !shl
-               : !size       !sra         !srl         !strconcat   !sub
-               : !subst      !substr      !tail        !xor
+               : !listconcat !listremove  !listsplat   !logtwo      !lt
+               : !mul        !ne          !not         !or          !setdagop
+               : !shl        !size        !sra         !srl         !strconcat
+               : !sub        !subst       !substr      !tail        !xor
 
 The ``!cond`` operator has a slightly different
 syntax compared to other bang operators, so it is defined separately:
@@ -1740,6 +1740,10 @@ and non-0 as true.
     This operator concatenates the list arguments *list1*, *list2*, etc., and
     produces the resulting list. The lists must have the same element type.
 
+``!listremove(``\ *list1*\ ``,`` *list2*\ ``)``
+    This operator returns a copy of *list1* removing all elements that also occur in
+    *list2*. The lists must have the same element type.
+
 ``!listsplat(``\ *value*\ ``,`` *count*\ ``)``
     This operator produces a list of length *count* whose elements are all
     equal to the *value*. For example, ``!listsplat(42, 3)`` results in
index d89dbfc..d49c1b0 100644 (file)
@@ -847,6 +847,7 @@ public:
     SRL,
     LISTCONCAT,
     LISTSPLAT,
+    LISTREMOVE,
     STRCONCAT,
     INTERLEAVE,
     CONCAT,
@@ -900,6 +901,8 @@ public:
   Init *getLHS() const { return LHS; }
   Init *getRHS() const { return RHS; }
 
+  std::optional<bool> CompareInit(unsigned Opc, Init *LHS, Init *RHS) const;
+
   // Fold - If possible, fold this to a simpler init.  Return this if not
   // possible to fold.
   Init *Fold(Record *CurRec) const;
index af67eba..112ae08 100644 (file)
@@ -1037,7 +1037,83 @@ Init *BinOpInit::getListConcat(TypedInit *LHS, Init *RHS) {
    return BinOpInit::get(BinOpInit::LISTCONCAT, LHS, RHS, LHS->getType());
 }
 
-Init *BinOpInit::Fold(Record *CurRec) const {
+std::optional<bool> BinOpInit::CompareInit(unsigned Opc, Init *LHS, Init *RHS) const {
+   // First see if we have two bit, bits, or int.
+   IntInit *LHSi = dyn_cast_or_null<IntInit>(
+       LHS->convertInitializerTo(IntRecTy::get(getRecordKeeper())));
+   IntInit *RHSi = dyn_cast_or_null<IntInit>(
+       RHS->convertInitializerTo(IntRecTy::get(getRecordKeeper())));
+
+   if (LHSi && RHSi) {
+     bool Result;
+     switch (Opc) {
+     case EQ:
+       Result = LHSi->getValue() == RHSi->getValue();
+       break;
+     case NE:
+       Result = LHSi->getValue() != RHSi->getValue();
+       break;
+     case LE:
+       Result = LHSi->getValue() <= RHSi->getValue();
+       break;
+     case LT:
+       Result = LHSi->getValue() < RHSi->getValue();
+       break;
+     case GE:
+       Result = LHSi->getValue() >= RHSi->getValue();
+       break;
+     case GT:
+       Result = LHSi->getValue() > RHSi->getValue();
+       break;
+     default:
+       llvm_unreachable("unhandled comparison");
+     }
+     return Result;
+   }
+
+   // Next try strings.
+   StringInit *LHSs = dyn_cast<StringInit>(LHS);
+   StringInit *RHSs = dyn_cast<StringInit>(RHS);
+
+   if (LHSs && RHSs) {
+     bool Result;
+     switch (Opc) {
+     case EQ:
+       Result = LHSs->getValue() == RHSs->getValue();
+       break;
+     case NE:
+       Result = LHSs->getValue() != RHSs->getValue();
+       break;
+     case LE:
+       Result = LHSs->getValue() <= RHSs->getValue();
+       break;
+     case LT:
+       Result = LHSs->getValue() < RHSs->getValue();
+       break;
+     case GE:
+       Result = LHSs->getValue() >= RHSs->getValue();
+       break;
+     case GT:
+       Result = LHSs->getValue() > RHSs->getValue();
+       break;
+     default:
+       llvm_unreachable("unhandled comparison");
+     }
+     return Result;
+   }
+
+   // Finally, !eq and !ne can be used with records.
+   if (Opc == EQ || Opc == NE) {
+     DefInit *LHSd = dyn_cast<DefInit>(LHS);
+     DefInit *RHSd = dyn_cast<DefInit>(RHS);
+     if (LHSd && RHSd)
+       return (Opc == EQ) ? LHSd == RHSd : LHSd != RHSd;
+   }
+
+   return std::nullopt;
+}
+
+ Init *BinOpInit::Fold(Record *CurRec) const {
   switch (getOpcode()) {
   case CONCAT: {
     DagInit *LHSs = dyn_cast<DagInit>(LHS);
@@ -1091,6 +1167,28 @@ Init *BinOpInit::Fold(Record *CurRec) const {
     }
     break;
   }
+  case LISTREMOVE: {
+    ListInit *LHSs = dyn_cast<ListInit>(LHS);
+    ListInit *RHSs = dyn_cast<ListInit>(RHS);
+    if (LHSs && RHSs) {
+      SmallVector<Init *, 8> Args;
+      for (Init *EltLHS : *LHSs) {
+        bool Found = false;
+        for (Init *EltRHS : *RHSs) {
+          if (std::optional<bool> Result = CompareInit(EQ, EltLHS, EltRHS)) {
+            if (*Result) {
+              Found = true;
+              break;
+            }
+          }
+        }
+        if (!Found)
+          Args.push_back(EltLHS);
+      }
+      return ListInit::get(Args, LHSs->getElementType());
+    }
+    break;
+  }
   case STRCONCAT: {
     StringInit *LHSs = dyn_cast<StringInit>(LHS);
     StringInit *RHSs = dyn_cast<StringInit>(RHS);
@@ -1118,53 +1216,8 @@ Init *BinOpInit::Fold(Record *CurRec) const {
   case LT:
   case GE:
   case GT: {
-    // First see if we have two bit, bits, or int.
-    IntInit *LHSi = dyn_cast_or_null<IntInit>(
-        LHS->convertInitializerTo(IntRecTy::get(getRecordKeeper())));
-    IntInit *RHSi = dyn_cast_or_null<IntInit>(
-        RHS->convertInitializerTo(IntRecTy::get(getRecordKeeper())));
-
-    if (LHSi && RHSi) {
-      bool Result;
-      switch (getOpcode()) {
-      case EQ: Result = LHSi->getValue() == RHSi->getValue(); break;
-      case NE: Result = LHSi->getValue() != RHSi->getValue(); break;
-      case LE: Result = LHSi->getValue() <= RHSi->getValue(); break;
-      case LT: Result = LHSi->getValue() <  RHSi->getValue(); break;
-      case GE: Result = LHSi->getValue() >= RHSi->getValue(); break;
-      case GT: Result = LHSi->getValue() >  RHSi->getValue(); break;
-      default: llvm_unreachable("unhandled comparison");
-      }
-      return BitInit::get(getRecordKeeper(), Result);
-    }
-
-    // Next try strings.
-    StringInit *LHSs = dyn_cast<StringInit>(LHS);
-    StringInit *RHSs = dyn_cast<StringInit>(RHS);
-
-    if (LHSs && RHSs) {
-      bool Result;
-      switch (getOpcode()) {
-      case EQ: Result = LHSs->getValue() == RHSs->getValue(); break;
-      case NE: Result = LHSs->getValue() != RHSs->getValue(); break;
-      case LE: Result = LHSs->getValue() <= RHSs->getValue(); break;
-      case LT: Result = LHSs->getValue() <  RHSs->getValue(); break;
-      case GE: Result = LHSs->getValue() >= RHSs->getValue(); break;
-      case GT: Result = LHSs->getValue() >  RHSs->getValue(); break;
-      default: llvm_unreachable("unhandled comparison");
-      }
-      return BitInit::get(getRecordKeeper(), Result);
-    }
-
-    // Finally, !eq and !ne can be used with records.
-    if (getOpcode() == EQ || getOpcode() == NE) {
-      DefInit *LHSd = dyn_cast<DefInit>(LHS);
-      DefInit *RHSd = dyn_cast<DefInit>(RHS);
-      if (LHSd && RHSd)
-        return BitInit::get(getRecordKeeper(),
-                            (getOpcode() == EQ) ? LHSd == RHSd : LHSd != RHSd);
-    }
-
+    if (std::optional<bool> Result = CompareInit(getOpcode(), LHS, RHS))
+      return BitInit::get(getRecordKeeper(), *Result);
     break;
   }
   case SETDAGOP: {
@@ -1260,6 +1313,7 @@ std::string BinOpInit::getAsString() const {
   case GT: Result = "!gt"; break;
   case LISTCONCAT: Result = "!listconcat"; break;
   case LISTSPLAT: Result = "!listsplat"; break;
+  case LISTREMOVE: Result = "!listremove"; break;
   case STRCONCAT: Result = "!strconcat"; break;
   case INTERLEAVE: Result = "!interleave"; break;
   case SETDAGOP: Result = "!setdagop"; break;
index 34a04e4..f2148b4 100644 (file)
@@ -584,6 +584,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
     .Case("filter", tgtok::XFilter)
     .Case("listconcat", tgtok::XListConcat)
     .Case("listsplat", tgtok::XListSplat)
+    .Case("listremove", tgtok::XListRemove)
     .Case("strconcat", tgtok::XStrConcat)
     .Case("interleave", tgtok::XInterleave)
     .Case("substr", tgtok::XSubstr)
index cf1746f..284f1ba 100644 (file)
@@ -56,7 +56,7 @@ namespace tgtok {
     XSHL, XListConcat, XListSplat, XStrConcat, XInterleave, XSubstr, XFind,
     XCast, XSubst, XForEach, XFilter, XFoldl, XHead, XTail, XSize, XEmpty, XIf,
     XCond, XEq, XIsA, XDag, XNe, XLe, XLt, XGe, XGt, XSetDagOp, XGetDagOp,
-    XExists,
+    XExists, XListRemove,
 
     // Boolean literals.
     TrueVal, FalseVal,
index ed20ffd..7fc46a8 100644 (file)
@@ -1179,6 +1179,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
   case tgtok::XGt:
   case tgtok::XListConcat:
   case tgtok::XListSplat:
+  case tgtok::XListRemove:
   case tgtok::XStrConcat:
   case tgtok::XInterleave:
   case tgtok::XSetDagOp: { // Value ::= !binop '(' Value ',' Value ')'
@@ -1208,6 +1209,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
     case tgtok::XGt:     Code = BinOpInit::GT; break;
     case tgtok::XListConcat: Code = BinOpInit::LISTCONCAT; break;
     case tgtok::XListSplat:  Code = BinOpInit::LISTSPLAT; break;
+    case tgtok::XListRemove: Code = BinOpInit::LISTREMOVE; break;
     case tgtok::XStrConcat:  Code = BinOpInit::STRCONCAT; break;
     case tgtok::XInterleave: Code = BinOpInit::INTERLEAVE; break;
     case tgtok::XSetDagOp:   Code = BinOpInit::SETDAGOP; break;
@@ -1246,12 +1248,16 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       // ArgType for the comparison operators is not yet known.
       break;
     case tgtok::XListConcat:
-      // We don't know the list type until we parse the first argument
+      // We don't know the list type until we parse the first argument.
       ArgType = ItemType;
       break;
     case tgtok::XListSplat:
       // Can't do any typechecking until we parse the first argument.
       break;
+    case tgtok::XListRemove:
+      // We don't know the list type until we parse the first argument.
+      ArgType = ItemType;
+      break;
     case tgtok::XStrConcat:
       Type = StringRecTy::get(Records);
       ArgType = StringRecTy::get(Records);
@@ -1329,6 +1335,13 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
           }
           ArgType = nullptr; // Broken invariant: types not identical.
           break;
+        case BinOpInit::LISTREMOVE:
+          if (!isa<ListRecTy>(ArgType)) {
+            Error(InitLoc, Twine("expected a list, got value of type '") +
+                               ArgType->getAsString() + "'");
+            return nullptr;
+          }
+          break;
         case BinOpInit::EQ:
         case BinOpInit::NE:
           if (!ArgType->typeIsConvertibleTo(IntRecTy::get(Records)) &&
@@ -1423,6 +1436,9 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
     // listsplat returns a list of type of the *first* argument.
     if (Code == BinOpInit::LISTSPLAT)
       Type = cast<TypedInit>(InitList.front())->getType()->getListTy();
+    // listremove returns a list with type of the argument.
+    if (Code == BinOpInit::LISTREMOVE)
+      Type = ArgType;
 
     // We allow multiple operands to associative operators like !strconcat as
     // shorthand for nesting them.
@@ -2154,6 +2170,7 @@ Init *TGParser::ParseOperationCond(Record *CurRec, RecTy *ItemType) {
 ///   SimpleValue ::= SRLTOK '(' Value ',' Value ')'
 ///   SimpleValue ::= LISTCONCATTOK '(' Value ',' Value ')'
 ///   SimpleValue ::= LISTSPLATTOK '(' Value ',' Value ')'
+///   SimpleValue ::= LISTREMOVETOK '(' Value ',' Value ')'
 ///   SimpleValue ::= STRCONCATTOK '(' Value ',' Value ')'
 ///   SimpleValue ::= COND '(' [Value ':' Value,]+ ')'
 ///
@@ -2453,6 +2470,7 @@ Init *TGParser::ParseSimpleValue(Record *CurRec, RecTy *ItemType,
   case tgtok::XGt:
   case tgtok::XListConcat:
   case tgtok::XListSplat:
+  case tgtok::XListRemove:
   case tgtok::XStrConcat:
   case tgtok::XInterleave:
   case tgtok::XSetDagOp: // Value ::= !binop '(' Value ',' Value ')'
diff --git a/llvm/test/TableGen/listremove.td b/llvm/test/TableGen/listremove.td
new file mode 100644 (file)
index 0000000..d0f18ad
--- /dev/null
@@ -0,0 +1,17 @@
+// RUN: llvm-tblgen %s | FileCheck %s
+
+// CHECK: class X {
+// CHECK:   list<string> T0 = ["foo", "bar"];
+// CHECK:   list<string> T1 = ["foo", "bar"];
+// CHECK:   list<string> T2 = ["bar"];
+// CHECK:   list<string> T3 = ["foo"];
+// CHECK:   list<string> T4 = [];
+// CHECK: }
+
+class X {
+  list<string> T0 = !listremove(["foo", "bar"], []);
+  list<string> T1 = !listremove(["foo", "bar"], ["baz"]);
+  list<string> T2 = !listremove(["foo", "bar"], ["foo"]);
+  list<string> T3 = !listremove(["foo", "bar"], ["bar", "bar"]);
+  list<string> T4 = !listremove(["foo", "bar"], ["bar", "foo"]);
+}
index 2a3f040..2c0c6bc 100644 (file)
@@ -31,6 +31,7 @@
       <item> !strconcat </item>
       <item> !cast </item>
       <item> !listconcat </item>
+      <item> !listreplace </item>
       <item> !listsplat </item>
       <item> !size </item>
       <item> !foldl </item>