[mlir] Better handling for bit groups in enum parser/printer
authorRiver Riddle <riddleriver@gmail.com>
Sat, 22 Oct 2022 20:57:15 +0000 (13:57 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Mon, 24 Oct 2022 06:59:55 +0000 (23:59 -0700)
We currently wrap all multi-bit cases with a string, but this is
overly restrictive. This commit refactors to use keywords when
we know they are valid, and only degrade to string when the validity
of the bitgroup is unknown.

Differential Revision: https://reviews.llvm.org/D136540

mlir/test/mlir-tblgen/enums-gen.td
mlir/tools/mlir-tblgen/EnumsGen.cpp

index ed1b8f5..977647c 100644 (file)
@@ -10,9 +10,12 @@ def Bit0: I32BitEnumAttrCaseBit<"Bit0", 0, "tagged">;
 def Bit1: I32BitEnumAttrCaseBit<"Bit1", 1>;
 def Bit2: I32BitEnumAttrCaseBit<"Bit2", 2>;
 def Bit3: I32BitEnumAttrCaseBit<"Bit3", 3>;
+def BitGroup: I32BitEnumAttrCaseGroup<"BitGroup", [
+  Bit0, Bit1
+]>;
 
 def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
-                           [None, Bit0, Bit1, Bit2, Bit3]> {
+                           [None, Bit0, Bit1, Bit2, Bit3, BitGroup]> {
   let genSpecializedAttr = 0;
 }
 
@@ -44,6 +47,15 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
 
 // DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
 // DECL:   auto valueStr = stringifyEnum(value);
+// DECL:   switch (value) {
+// DECL:   case ::MyBitEnum::BitGroup:
+// DECL:     return p << valueStr;
+// DECL:   default:
+// DECL:     break;
+// DECL:   }
+// DECL:   auto underlyingValue = static_cast<std::make_unsigned_t<::MyBitEnum>>(value);
+// DECL:   if (underlyingValue && !llvm::has_single_bit(underlyingValue))
+// DECL:     return p << '"' << valueStr << '"';
 // DECL:   return p << valueStr;
 
 // DEF-LABEL: std::string stringifyMyBitEnum
index c84995e..e95d5d6 100644 (file)
@@ -80,16 +80,6 @@ static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
     if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
       nonKeywordCases.set(index);
 
-  // If this is a bit enum attribute, don't allow cases that may overlap with
-  // other cases. For simplicity sake, only allow cases with a single bit value.
-  if (enumAttr.isBitEnum()) {
-    for (auto [index, caseVal] : llvm::enumerate(cases)) {
-      int64_t value = caseVal.getValue();
-      if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value)))
-        nonKeywordCases.set(index);
-    }
-  }
-
   // Generate the parser and the start of the printer for the enum.
   const char *parsedAndPrinterStart = R"(
 namespace mlir {
@@ -137,7 +127,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
       if (nonKeywordCases.test(it.index()))
         continue;
       StringRef symbol = it.value().getSymbol();
-      os << llvm::formatv("    case {0}::{1}:\n", qualName,
+      os << llvm::formatv("  case {0}::{1}:\n", qualName,
                           llvm::isDigit(symbol.front()) ? ("_" + symbol)
                                                         : symbol);
     }
@@ -145,6 +135,37 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
           "  default:\n"
           "    return p << '\"' << valueStr << '\"';\n"
           "  }\n";
+
+    // If this is a bit enum, conservatively print the string form if the value
+    // is not a power of two (i.e. not a single bit case) and not a known case.
+  } else if (enumAttr.isBitEnum()) {
+    // Process the known multi-bit cases that use valid keywords.
+    llvm::SmallVector<EnumAttrCase *> validMultiBitCases;
+    for (auto [index, caseVal] : llvm::enumerate(cases)) {
+      uint64_t value = caseVal.getValue();
+      if (value && !llvm::has_single_bit(value) && !nonKeywordCases.test(index))
+        validMultiBitCases.push_back(&caseVal);
+    }
+    if (!validMultiBitCases.empty()) {
+      os << "  switch (value) {\n";
+      for (EnumAttrCase *caseVal : validMultiBitCases) {
+        StringRef symbol = caseVal->getSymbol();
+        os << llvm::formatv("  case {0}::{1}:\n", qualName,
+                            llvm::isDigit(symbol.front()) ? ("_" + symbol)
+                                                          : symbol);
+      }
+      os << "    return p << valueStr;\n"
+            "  default:\n"
+            "    break;\n"
+            "  }\n";
+    }
+
+    // All other multi-bit cases should be printed as strings.
+    os << formatv("  auto underlyingValue = "
+                  "static_cast<std::make_unsigned_t<{0}>>(value);\n",
+                  qualName);
+    os << "  if (underlyingValue && !llvm::has_single_bit(underlyingValue))\n"
+          "    return p << '\"' << valueStr << '\"';\n";
   }
   os << "  return p << valueStr;\n"
         "}\n"