TableGen: Let expressions available to list subscriptions and list slices
authorNAKAMURA Takumi <geek4civic@gmail.com>
Sat, 4 Mar 2023 10:53:52 +0000 (19:53 +0900)
committerNAKAMURA Takumi <geek4civic@gmail.com>
Wed, 26 Apr 2023 14:47:16 +0000 (23:47 +0900)
This enables indexing in `!foreach` and permutation with `list[permlist]`.

Enhancements in syntax:

  - `list<int>` is applicable as a slice element.
  - `list[int,]` is evaluated as not `ElemType` but `list<ElemType>`
    with a single element.

Part of D145872

FIXME: I didn't apply new semantics to BitSlice.

llvm/docs/TableGen/ProgRef.rst
llvm/include/llvm/TableGen/Record.h
llvm/lib/TableGen/Record.cpp
llvm/lib/TableGen/TGParser.cpp
llvm/lib/TableGen/TGParser.h
llvm/test/TableGen/ListSlices-fail.td
llvm/test/TableGen/ListSlices.td

index ba3fa29..b4b35c4 100644 (file)
@@ -335,19 +335,25 @@ to an entity of type ``bits<4>``.
    Value: `SimpleValue` `ValueSuffix`*
         :| `Value` "#" [`Value`]
    ValueSuffix: "{" `RangeList` "}"
-              :| "[" `RangeList` "]"
+              :| "[" `SliceElements` "]"
               :| "." `TokIdentifier`
    RangeList: `RangePiece` ("," `RangePiece`)*
    RangePiece: `TokInteger`
              :| `TokInteger` "..." `TokInteger`
              :| `TokInteger` "-" `TokInteger`
              :| `TokInteger` `TokInteger`
+   SliceElements: (`SliceElement` ",")* `SliceElement` ","?
+   SliceElement: `Value`
+               :| `Value` "..." `Value`
+               :| `Value` "-" `Value`
+               :| `Value` `TokInteger`
 
 .. warning::
-  The peculiar last form of :token:`RangePiece` is due to the fact that the
-  "``-``" is included in the :token:`TokInteger`, hence ``1-5`` gets lexed as
-  two consecutive tokens, with values ``1`` and ``-5``, instead of "1", "-",
-  and "5". The use of hyphen as the range punctuation is deprecated.
+  The peculiar last form of :token:`RangePiece` and :token:`SliceElement` is
+  due to the fact that the "``-``" is included in the :token:`TokInteger`,
+  hence ``1-5`` gets lexed as two consecutive tokens, with values ``1`` and
+  ``-5``, instead of "1", "-", and "5".
+  The use of hyphen as the range punctuation is deprecated.
 
 Simple values
 -------------
@@ -505,17 +511,26 @@ primary value. Here are the possible suffixes for some primary *value*.
     The final value is bits 8--15 of the integer *value*. The order of the
     bits can be reversed by specifying ``{15...8}``.
 
-*value*\ ``[4]``
-    The final value is element 4 of the list *value* (note the brackets).
+*value*\ ``[i]``
+    The final value is element `i` of the list *value* (note the brackets).
     In other words, the brackets act as a subscripting operator on the list.
     This is the case only when a single element is specified.
 
+*value*\ ``[i,]``
+    The final value is a list that contains a single element `i` of the list.
+    In short, a list slice with a single element.
+
 *value*\ ``[4...7,17,2...3,4]``
     The final value is a new list that is a slice of the list *value*.
     The new list contains elements 4, 5, 6, 7, 17, 2, 3, and 4.
     Elements may be included multiple times and in any order. This is the result
     only when more than one element is specified.
 
+    *value*\ ``[i,m...n,j,ls]``
+        Each element may be an expression (variables, bang operators).
+        The type of `m` and `n` should be `int`.
+        The type of `i`, `j`, and `ls` should be either `int` or `list<int>`.
+
 *value*\ ``.``\ *field*
     The final value is the value of the specified *field* in the specified
     record *value*.
index 6ffb3a7..e44bb75 100644 (file)
@@ -857,7 +857,10 @@ public:
     LISTCONCAT,
     LISTSPLAT,
     LISTREMOVE,
+    LISTELEM,
+    LISTSLICE,
     RANGE,
+    RANGEC,
     STRCONCAT,
     INTERLEAVE,
     CONCAT,
index 200233c..64fc006 100644 (file)
@@ -1200,7 +1200,36 @@ std::optional<bool> BinOpInit::CompareInit(unsigned Opc, Init *LHS, Init *RHS) c
     }
     break;
   }
-  case RANGE: {
+  case LISTELEM: {
+    auto *TheList = dyn_cast<ListInit>(LHS);
+    auto *Idx = dyn_cast<IntInit>(RHS);
+    if (!TheList || !Idx)
+      break;
+    auto i = Idx->getValue();
+    if (i < 0 || i >= (ssize_t)TheList->size())
+      break;
+    return TheList->getElement(i);
+  }
+  case LISTSLICE: {
+    auto *TheList = dyn_cast<ListInit>(LHS);
+    auto *SliceIdxs = dyn_cast<ListInit>(RHS);
+    if (!TheList || !SliceIdxs)
+      break;
+    SmallVector<Init *, 8> Args;
+    Args.reserve(SliceIdxs->size());
+    for (auto *I : *SliceIdxs) {
+      auto *II = dyn_cast<IntInit>(I);
+      if (!II)
+        goto unresolved;
+      auto i = II->getValue();
+      if (i < 0 || i >= (ssize_t)TheList->size())
+        goto unresolved;
+      Args.push_back(TheList->getElement(i));
+    }
+    return ListInit::get(Args, TheList->getElementType());
+  }
+  case RANGE:
+  case RANGEC: {
     auto *LHSi = dyn_cast<IntInit>(LHS);
     auto *RHSi = dyn_cast<IntInit>(RHS);
     if (!LHSi || !RHSi)
@@ -1209,7 +1238,20 @@ std::optional<bool> BinOpInit::CompareInit(unsigned Opc, Init *LHS, Init *RHS) c
     auto Start = LHSi->getValue();
     auto End = RHSi->getValue();
     SmallVector<Init *, 8> Args;
-    if (Start < End) {
+    if (getOpcode() == RANGEC) {
+      // Closed interval
+      if (Start <= End) {
+        // Ascending order
+        Args.reserve(End - Start + 1);
+        for (auto i = Start; i <= End; ++i)
+          Args.push_back(IntInit::get(getRecordKeeper(), i));
+      } else {
+        // Descending order
+        Args.reserve(Start - End + 1);
+        for (auto i = Start; i >= End; --i)
+          Args.push_back(IntInit::get(getRecordKeeper(), i));
+      }
+    } else if (Start < End) {
       // Half-open interval (excludes `End`)
       Args.reserve(End - Start);
       for (auto i = Start; i < End; ++i)
@@ -1308,6 +1350,7 @@ std::optional<bool> BinOpInit::CompareInit(unsigned Opc, Init *LHS, Init *RHS) c
     break;
   }
   }
+unresolved:
   return const_cast<BinOpInit *>(this);
 }
 
@@ -1324,6 +1367,11 @@ Init *BinOpInit::resolveReferences(Resolver &R) const {
 std::string BinOpInit::getAsString() const {
   std::string Result;
   switch (getOpcode()) {
+  case LISTELEM:
+  case LISTSLICE:
+    return LHS->getAsString() + "[" + RHS->getAsString() + "]";
+  case RANGEC:
+    return LHS->getAsString() + "..." + RHS->getAsString();
   case CONCAT: Result = "!con"; break;
   case ADD: Result = "!add"; break;
   case SUB: Result = "!sub"; break;
index 0603195..a67066b 100644 (file)
@@ -709,6 +709,148 @@ ParseSubMultiClassReference(MultiClass *CurMC) {
   return Result;
 }
 
+/// ParseSliceElement - Parse subscript or range
+///
+///  SliceElement  ::= Value<list<int>>
+///  SliceElement  ::= Value<int>
+///  SliceElement  ::= Value<int> '...' Value<int>
+///  SliceElement  ::= Value<int> '-' Value<int> (deprecated)
+///  SliceElement  ::= Value<int> INTVAL(Negative; deprecated)
+///
+/// SliceElement is either IntRecTy, ListRecTy, or nullptr
+///
+TypedInit *TGParser::ParseSliceElement(Record *CurRec) {
+  auto LHSLoc = Lex.getLoc();
+  auto *CurVal = ParseValue(CurRec);
+  if (!CurVal)
+    return nullptr;
+  auto *LHS = cast<TypedInit>(CurVal);
+
+  TypedInit *RHS = nullptr;
+  switch (Lex.getCode()) {
+  case tgtok::dotdotdot:
+  case tgtok::minus: { // Deprecated
+    Lex.Lex();         // eat
+    auto RHSLoc = Lex.getLoc();
+    CurVal = ParseValue(CurRec);
+    if (!CurVal)
+      return nullptr;
+    RHS = cast<TypedInit>(CurVal);
+    if (!isa<IntRecTy>(RHS->getType())) {
+      Error(RHSLoc,
+            "expected int...int, got " + Twine(RHS->getType()->getAsString()));
+      return nullptr;
+    }
+    break;
+  }
+  case tgtok::IntVal: { // Deprecated "-num"
+    auto *RHSi = IntInit::get(Records, -Lex.getCurIntVal());
+    if (RHSi->getValue() < 0) {
+      TokError("invalid range, cannot be negative");
+      return nullptr;
+    }
+    RHS = RHSi;
+    Lex.Lex(); // eat IntVal
+    break;
+  }
+  default: // Single value (IntRecTy or ListRecTy)
+    return LHS;
+  }
+
+  assert(RHS);
+  assert(isa<IntRecTy>(RHS->getType()));
+
+  // Closed-interval range <LHS:IntRecTy>...<RHS:IntRecTy>
+  if (!isa<IntRecTy>(LHS->getType())) {
+    Error(LHSLoc,
+          "expected int...int, got " + Twine(LHS->getType()->getAsString()));
+    return nullptr;
+  }
+
+  return cast<TypedInit>(BinOpInit::get(BinOpInit::RANGEC, LHS, RHS,
+                                        IntRecTy::get(Records)->getListTy())
+                             ->Fold(CurRec));
+}
+
+/// ParseSliceElements - Parse subscripts in square brackets.
+///
+///  SliceElements ::= ( SliceElement ',' )* SliceElement ','?
+///
+/// SliceElement is either IntRecTy, ListRecTy, or nullptr
+///
+/// Returns ListRecTy by defaut.
+/// Returns IntRecTy if;
+///  - Single=true
+///  - SliceElements is Value<int> w/o trailing comma
+///
+TypedInit *TGParser::ParseSliceElements(Record *CurRec, bool Single) {
+  TypedInit *CurVal;
+  SmallVector<Init *, 2> Elems;       // int
+  SmallVector<TypedInit *, 2> Slices; // list<int>
+
+  auto FlushElems = [&] {
+    if (!Elems.empty()) {
+      Slices.push_back(ListInit::get(Elems, IntRecTy::get(Records)));
+      Elems.clear();
+    }
+  };
+
+  do {
+    auto LHSLoc = Lex.getLoc();
+    CurVal = ParseSliceElement(CurRec);
+    if (!CurVal)
+      return nullptr;
+    auto *CurValTy = CurVal->getType();
+
+    if (auto *ListValTy = dyn_cast<ListRecTy>(CurValTy)) {
+      if (!isa<IntRecTy>(ListValTy->getElementType())) {
+        Error(LHSLoc,
+              "expected list<int>, got " + Twine(ListValTy->getAsString()));
+        return nullptr;
+      }
+
+      FlushElems();
+      Slices.push_back(CurVal);
+      Single = false;
+      CurVal = nullptr;
+    } else if (!isa<IntRecTy>(CurValTy)) {
+      Error(LHSLoc,
+            "unhandled type " + Twine(CurValTy->getAsString()) + " in range");
+      return nullptr;
+    }
+
+    if (Lex.getCode() != tgtok::comma)
+      break;
+
+    Lex.Lex(); // eat comma
+
+    // `[i,]` is not LISTELEM but LISTSLICE
+    Single = false;
+    if (CurVal)
+      Elems.push_back(CurVal);
+    CurVal = nullptr;
+  } while (Lex.getCode() != tgtok::r_square);
+
+  if (CurVal) {
+    // LISTELEM
+    if (Single)
+      return CurVal;
+
+    Elems.push_back(CurVal);
+  }
+
+  FlushElems();
+
+  // Concatenate lists in Slices
+  TypedInit *Result = nullptr;
+  for (auto *Slice : Slices) {
+    Result = (Result ? cast<TypedInit>(BinOpInit::getListConcat(Result, Slice))
+                     : Slice);
+  }
+
+  return Result;
+}
+
 /// ParseRangePiece - Parse a bit/value range.
 ///   RangePiece ::= INTVAL
 ///   RangePiece ::= INTVAL '...' INTVAL
@@ -2593,10 +2735,11 @@ Init *TGParser::ParseSimpleValue(Record *CurRec, RecTy *ItemType,
 ///
 ///   Value       ::= SimpleValue ValueSuffix*
 ///   ValueSuffix ::= '{' BitList '}'
-///   ValueSuffix ::= '[' BitList ']'
+///   ValueSuffix ::= '[' SliceElements ']'
 ///   ValueSuffix ::= '.' ID
 ///
 Init *TGParser::ParseValue(Record *CurRec, RecTy *ItemType, IDParseMode Mode) {
+  SMLoc LHSLoc = Lex.getLoc();
   Init *Result = ParseSimpleValue(CurRec, ItemType, Mode);
   if (!Result) return nullptr;
 
@@ -2631,19 +2774,35 @@ Init *TGParser::ParseValue(Record *CurRec, RecTy *ItemType, IDParseMode Mode) {
       break;
     }
     case tgtok::l_square: {
-      SMLoc SquareLoc = Lex.getLoc();
-      Lex.Lex(); // eat the '['
-      SmallVector<unsigned, 16> Ranges;
-      ParseRangeList(Ranges);
-      if (Ranges.empty())
+      auto *LHS = dyn_cast<TypedInit>(Result);
+      if (!LHS) {
+        Error(LHSLoc, "Invalid value, list expected");
         return nullptr;
+      }
 
-      Result = Result->convertInitListSlice(Ranges);
-      if (!Result) {
-        Error(SquareLoc, "Invalid range for list slice");
+      auto *LHSTy = dyn_cast<ListRecTy>(LHS->getType());
+      if (!LHSTy) {
+        Error(LHSLoc, "Type '" + Twine(LHS->getType()->getAsString()) +
+                          "' is invalid, list expected");
+        return nullptr;
+      }
+
+      Lex.Lex(); // eat the '['
+      TypedInit *RHS = ParseSliceElements(CurRec, /*Single=*/true);
+      if (!RHS)
         return nullptr;
+
+      if (isa<ListRecTy>(RHS->getType())) {
+        Result =
+            BinOpInit::get(BinOpInit::LISTSLICE, LHS, RHS, LHSTy)->Fold(CurRec);
+      } else {
+        Result = BinOpInit::get(BinOpInit::LISTELEM, LHS, RHS,
+                                LHSTy->getElementType())
+                     ->Fold(CurRec);
       }
 
+      assert(Result);
+
       // Eat the ']'.
       if (!consume(tgtok::r_square)) {
         TokError("expected ']' at end of list slice");
index 63a8db0..4569b22 100644 (file)
@@ -265,6 +265,8 @@ private:  // Parser methods.
       Record *CurRec);
   bool ParseOptionalRangeList(SmallVectorImpl<unsigned> &Ranges);
   bool ParseOptionalBitList(SmallVectorImpl<unsigned> &Ranges);
+  TypedInit *ParseSliceElement(Record *CurRec);
+  TypedInit *ParseSliceElements(Record *CurRec, bool Single = false);
   void ParseRangeList(SmallVectorImpl<unsigned> &Result);
   bool ParseRangePiece(SmallVectorImpl<unsigned> &Ranges,
                        TypedInit *FirstItem = nullptr);
index 7eaa284..9e623f5 100644 (file)
@@ -19,37 +19,37 @@ defvar errs = list_str [ , ] ;
 
 #ifdef ERR2
 // RUN: not llvm-tblgen %s -DERR2 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERR2
-// ERR2: [[FILE]]:[[@LINE+1]]:35: error: expected integer or bitrange
+// ERR2: [[FILE]]:[[@LINE+1]]:26: error: expected list<int>, got list<string>
 defvar errs = list_str [ list_str ] ;
 #endif
 
 #ifdef ERR3
 // RUN: not llvm-tblgen %s -DERR3 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERR3
-// ERR3: [[FILE]]:[[@LINE+1]]:35: error: expected integer or bitrange
+// ERR3: [[FILE]]:[[@LINE+1]]:26: error: expected int...int, got list<string>
 defvar errs = list_str [ list_str ... 42 ] ;
 #endif
 
 #ifdef ERR4
 // RUN: not llvm-tblgen %s -DERR4 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERR4
-// ERR4: [[FILE]]:[[@LINE+1]]:41: error: expected integer value as end of range
+// ERR4: [[FILE]]:[[@LINE+1]]:32: error: expected int...int, got list<string>
 defvar errs = list_str [ 0 ... list_str ] ;
 #endif
 
 #ifdef ERR5
 // RUN: not llvm-tblgen %s -DERR5 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERR5
-// ERR5: [[FILE]]:[[@LINE+1]]:30: error: expected integer or bitrange
+// ERR5: [[FILE]]:[[@LINE+1]]:26: error: unhandled type string in range
 defvar errs = list_str [ str ] ;
 #endif
 
 #ifdef ERR6
 // RUN: not llvm-tblgen %s -DERR6 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERR6
-// ERR6: [[FILE]]:[[@LINE+1]]:30: error: invalid range, cannot be negative
+// ERR6: [[FILE]]:[[@LINE+1]]:28: error: invalid range, cannot be negative
 defvar errs = list_str [ 5 1 ] ;
 #endif
 
 #ifdef ERR7
 // RUN: not llvm-tblgen %s -DERR7 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERR7
-// ERR7: [[FILE]]:[[@LINE+1]]:19: error: Invalid range for list slice
+// ERR7: [[FILE]]:[[@LINE+1]]:15: error: Type 'string' is invalid, list expected
 defvar errs = str [ 0 ] ;
 #endif
 
@@ -67,6 +67,6 @@ defvar errs = list_int [ 0 ... ] ;
 
 #ifdef ERRA
 // RUN: not llvm-tblgen %s -DERRA 2>&1 | FileCheck -DFILE=%s %s --check-prefix=ERRA
-// ERRA: [[FILE]]:[[@LINE+1]]:19: error: Invalid range for list slice
+// ERRA: [[FILE]]:[[@LINE+1]]:15: error: Invalid value, list expected
 defvar errs = und [ 0 ] ;
 #endif
index 9e80fb2..0b69a35 100644 (file)
@@ -123,3 +123,42 @@ def Rec10 {
   int Zero = Class1<[?, ?, 2, 3, ?, 5, ?]>.Zero;
   list<int> TwoFive = Class1<[?, ?, 2, 3, ?, 5, ?]>.TwoFive;
 }
+
+// Test list[list] and list[int]
+// CHECK: def Rec11
+def Rec11 {
+  list<int> s5 = Var1[0...4];
+
+  // list[expr]
+  // CHECK:   list<int> rev = [4, 3, 2, 1, 0];
+  list<int> rev = !foreach(i, s5, Var1[!sub(4, i)]);
+
+  // Slice by list[foreach]
+  // CHECK:   list<int> revf = [4, 3, 2, 1, 0];
+  list<int> revf = Var1[!foreach(i, s5, !sub(4, i))];
+
+  // Simple slice
+  // CHECK:   list<int> rr = [0, 1, 2, 3, 4];
+  list<int> rr  = rev[rev];
+
+  // Trailing comma is acceptable
+  // CHECK:   list<int> rr_ = [0, 1, 2, 3, 4];
+  list<int> rr_ = rev[rev,];
+
+  // Concatenation in slice
+  // CHECK:   list<int> rrr = [1, 2, 4, 3, 2, 1, 0, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 8];
+  list<int> empty = [];
+  list<int> rrr = Var1[1, 2, rev, 3...6, 7, empty, rr, 8];
+
+  // Recognized as slice by the trailing comma
+  // CHECK:   list<list<int>> rl1 = {{\[}}[0], [1], [2], [3], [4]];
+  list<list<int>> rl1 = !foreach(i, rev, rev[i,]);
+
+  // Slice by pair<int,int>
+  // CHECK:   list<list<int>> rll = {{\[}}[0, 4], [1, 3], [2, 2], [3, 1], [4, 0]];
+  list<list<int>> rll = !foreach(i, rev, rev[i, !sub(4, i)]);
+
+  // Slice by dynamic range
+  // CHECK:   list<list<int>> rlr = {{\[}}[4, 3, 2, 1, 0], [3, 2, 1], [2], [1, 2, 3], [0, 1, 2, 3, 4]];
+  list<list<int>> rlr = !foreach(i, s5, rev[i...!sub(4, i)]);
+}