[mlir] Generate parser/printers for enums
authorRiver Riddle <riddleriver@gmail.com>
Thu, 20 Oct 2022 23:31:01 +0000 (16:31 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 21 Oct 2022 22:32:36 +0000 (15:32 -0700)
This greatly simplifies composing enums in attribute/type printers,
which currently reimplement these functions as needed.

Differential Revision: https://reviews.llvm.org/D136407

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/func.mlir
mlir/test/mlir-tblgen/enums-gen.td
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/tools/mlir-tblgen/FormatGen.cpp

index 95d4a90..4c40060 100644 (file)
@@ -21,7 +21,7 @@ def LinkageAttr : LLVM_Attr<"Linkage"> {
   let parameters = (ins
     "linkage::Linkage":$linkage
   );
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "`<` $linkage `>`";
 }
 
 // Attribute definition for the LLVM Linkage enum.
@@ -30,7 +30,7 @@ def CConvAttr : LLVM_Attr<"CConv"> {
   let parameters = (ins
     "CConv":$CallingConv
   );
-  let hasCustomAssemblyFormat = 1;
+  let assemblyFormat = "`<` $CallingConv `>`";
 }
 
 def LoopOptionsAttr : LLVM_Attr<"LoopOptions"> {
index 3ec212a..980fc19 100644 (file)
@@ -2797,54 +2797,6 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
 
-void LinkageAttr::print(AsmPrinter &printer) const {
-  printer << "<";
-  if (static_cast<uint64_t>(getLinkage()) <= getMaxEnumValForLinkage())
-    printer << stringifyEnum(getLinkage());
-  else
-    printer << static_cast<uint64_t>(getLinkage());
-  printer << ">";
-}
-
-Attribute LinkageAttr::parse(AsmParser &parser, Type type) {
-  StringRef elemName;
-  if (parser.parseLess() || parser.parseKeyword(&elemName) ||
-      parser.parseGreater())
-    return {};
-  auto elem = linkage::symbolizeLinkage(elemName);
-  if (!elem) {
-    parser.emitError(parser.getNameLoc(), "Unknown linkage: ") << elemName;
-    return {};
-  }
-  Linkage linkage = *elem;
-  return LinkageAttr::get(parser.getContext(), linkage);
-}
-
-void CConvAttr::print(AsmPrinter &printer) const {
-  printer << "<";
-  if (static_cast<uint64_t>(getCallingConv()) <= cconv::getMaxEnumValForCConv())
-    printer << stringifyEnum(getCallingConv());
-  else
-    printer << "INVALID_cc_" << static_cast<uint64_t>(getCallingConv());
-  printer << ">";
-}
-
-Attribute CConvAttr::parse(AsmParser &parser, Type type) {
-  StringRef convName;
-
-  if (parser.parseLess() || parser.parseKeyword(&convName) ||
-      parser.parseGreater())
-    return {};
-  auto cconv = cconv::symbolizeCConv(convName);
-  if (!cconv) {
-    parser.emitError(parser.getNameLoc(), "unknown calling convention: ")
-        << convName;
-    return {};
-  }
-  CConv cconvVal = *cconv;
-  return CConvAttr::get(parser.getContext(), cconvVal);
-}
-
 LoopOptionsAttrBuilder::LoopOptionsAttrBuilder(LoopOptionsAttr attr)
     : options(attr.getOptions().begin(), attr.getOptions().end()) {}
 
index da46908..17cc6bf 100644 (file)
@@ -273,8 +273,9 @@ module {
 // -----
 
 module {
-  // expected-error@+2 {{unknown calling convention: cc_12}}
   "llvm.func"() ({
+  // expected-error @below {{invalid Calling Conventions specification: cc_12}}
+  // expected-error @below {{failed to parse CConvAttr parameter 'CallingConv' which is to be a `CConv`}}
   }) {sym_name = "generic_unknown_calling_convention", CConv = #llvm.cconv<cc_12>, function_type = !llvm.func<i64 (i64, i64)>} : () -> ()
 }
 
index ebe1264..ed1b8f5 100644 (file)
@@ -28,6 +28,24 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
 // DECL: std::string stringifyMyBitEnum(MyBitEnum);
 // DECL: ::llvm::Optional<MyBitEnum> symbolizeMyBitEnum(::llvm::StringRef);
 
+// DECL: struct FieldParser<::MyBitEnum, ::MyBitEnum> {
+// DECL:   template <typename ParserT>
+// DECL:   static FailureOr<::MyBitEnum> parse(ParserT &parser) {
+// DECL:     // Parse the keyword/string containing the enum.
+// DECL:     std::string enumKeyword;
+// DECL:     auto loc = parser.getCurrentLocation();
+// DECL:     if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
+// DECL:       return parser.emitError(loc, "expected keyword for An example bit enum");
+// DECL:     // Symbolize the keyword.
+// DECL:     if (::llvm::Optional<::MyBitEnum> attr = ::symbolizeEnum<::MyBitEnum>(enumKeyword))
+// DECL:       return *attr;
+// DECL:     return parser.emitError(loc, "invalid An example bit enum specification: ") << enumKeyword;
+// DECL:   }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyBitEnum value) {
+// DECL:   auto valueStr = stringifyEnum(value);
+// DECL:   return p << valueStr;
+
 // DEF-LABEL: std::string stringifyMyBitEnum
 // DEF: auto val = static_cast<uint32_t>
 // DEF: if (val == 0) return "None";
@@ -40,3 +58,34 @@ def MyBitEnum: I32BitEnumAttr<"MyBitEnum", "An example bit enum",
 // DEF: if (str == "None") return MyBitEnum::None;
 // DEF: .Case("tagged", 1)
 // DEF: .Case("Bit1", 2)
+
+// Test enum printer generation for non non-keyword enums.
+
+def NonKeywordBit: I32BitEnumAttrCaseBit<"Bit0", 0, "tag-ged">;
+def MyMixedNonKeywordBitEnum: I32BitEnumAttr<"MyMixedNonKeywordBitEnum", "An example bit enum", [
+    NonKeywordBit,
+    Bit1
+  ]> {
+  let genSpecializedAttr = 0;
+}
+
+def MyNonKeywordBitEnum: I32BitEnumAttr<"MyNonKeywordBitEnum", "An example bit enum", [
+    NonKeywordBit
+  ]> {
+  let genSpecializedAttr = 0;
+}
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyMixedNonKeywordBitEnum value) {
+// DECL: auto valueStr = stringifyEnum(value);
+// DECL:   switch (value) {
+// DECL:   case ::MyMixedNonKeywordBitEnum::Bit1:
+// DECL:     break;
+// DECL:   default:
+// DECL:     return p << '"' << valueStr << '"';
+// DECL:   }
+// DECL:   return p << valueStr;
+// DECL: }
+
+// DECL: inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, ::MyNonKeywordBitEnum value) {
+// DECL:   auto valueStr = stringifyEnum(value);
+// DECL:   return p << '"' << valueStr << '"';
index 60dde06..c84995e 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "FormatGen.h"
 #include "mlir/TableGen/Attribute.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -65,10 +67,92 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName,
   os << "};\n\n";
 }
 
-static void emitDenseMapInfo(StringRef enumName, std::string underlyingType,
+static void emitParserPrinter(const EnumAttr &enumAttr, StringRef qualName,
+                              StringRef cppNamespace, raw_ostream &os) {
+  if (enumAttr.getUnderlyingType().empty() ||
+      enumAttr.getConstBuilderTemplate().empty())
+    return;
+  auto cases = enumAttr.getAllCases();
+
+  // Check which cases shouldn't be printed using a keyword.
+  llvm::BitVector nonKeywordCases(cases.size());
+  for (auto [index, caseVal] : llvm::enumerate(cases))
+    if (!mlir::tblgen::canFormatStringAsKeyword(caseVal.getStr()))
+      nonKeywordCases.set(index);
+
+  // If this is a bit enum attribute, don't allow cases that may overlap with
+  // other cases. For simplicity sake, only allow cases with a single bit value.
+  if (enumAttr.isBitEnum()) {
+    for (auto [index, caseVal] : llvm::enumerate(cases)) {
+      int64_t value = caseVal.getValue();
+      if (value < 0 || (value != 0 && !llvm::isPowerOf2_64(value)))
+        nonKeywordCases.set(index);
+    }
+  }
+
+  // Generate the parser and the start of the printer for the enum.
+  const char *parsedAndPrinterStart = R"(
+namespace mlir {
+template <typename T, typename>
+struct FieldParser;
+
+template<>
+struct FieldParser<{0}, {0}> {{
+  template <typename ParserT>
+  static FailureOr<{0}> parse(ParserT &parser) {{
+    // Parse the keyword/string containing the enum.
+    std::string enumKeyword;
+    auto loc = parser.getCurrentLocation();
+    if (failed(parser.parseOptionalKeywordOrString(&enumKeyword)))
+      return parser.emitError(loc, "expected keyword for {2}");
+
+    // Symbolize the keyword.
+    if (::llvm::Optional<{0}> attr = {1}::symbolizeEnum<{0}>(enumKeyword))
+      return *attr;
+    return parser.emitError(loc, "invalid {2} specification: ") << enumKeyword;
+  }
+};
+} // namespace mlir
+
+namespace llvm {
+inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{
+  auto valueStr = stringifyEnum(value);
+)";
+  os << formatv(parsedAndPrinterStart, qualName, cppNamespace,
+                enumAttr.getSummary());
+
+  // If all cases require a string, always wrap.
+  if (nonKeywordCases.all()) {
+    os << "  return p << '\"' << valueStr << '\"';\n"
+          "}\n"
+          "} // namespace llvm\n";
+    return;
+  }
+
+  // If there are any cases that can't be used with a keyword, switch on the
+  // case value to determine when to print in the string form.
+  if (nonKeywordCases.any()) {
+    os << "  switch (value) {\n";
+    for (auto &it : llvm::enumerate(cases)) {
+      if (nonKeywordCases.test(it.index()))
+        continue;
+      StringRef symbol = it.value().getSymbol();
+      os << llvm::formatv("    case {0}::{1}:\n", qualName,
+                          llvm::isDigit(symbol.front()) ? ("_" + symbol)
+                                                        : symbol);
+    }
+    os << "    break;\n"
+          "  default:\n"
+          "    return p << '\"' << valueStr << '\"';\n"
+          "  }\n";
+  }
+  os << "  return p << valueStr;\n"
+        "}\n"
+        "} // namespace llvm\n";
+}
+
+static void emitDenseMapInfo(StringRef qualName, std::string underlyingType,
                              StringRef cppNamespace, raw_ostream &os) {
-  std::string qualName =
-      std::string(formatv("{0}::{1}", cppNamespace, enumName));
   if (underlyingType.empty())
     underlyingType =
         std::string(formatv("std::underlying_type_t<{0}>", qualName));
@@ -529,8 +613,13 @@ public:
   for (auto ns : llvm::reverse(namespaces))
     os << "} // namespace " << ns << "\n";
 
+  // Generate a generic parser and printer for the enum.
+  std::string qualName =
+      std::string(formatv("{0}::{1}", cppNamespace, enumName));
+  emitParserPrinter(enumAttr, qualName, cppNamespace, os);
+
   // Emit DenseMapInfo for this enum class
-  emitDenseMapInfo(enumName, underlyingType, cppNamespace, os);
+  emitDenseMapInfo(qualName, underlyingType, cppNamespace, os);
 }
 
 static bool emitEnumDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
index a756587..7d2e03e 100644 (file)
@@ -444,6 +444,11 @@ bool mlir::tblgen::shouldEmitSpaceBefore(StringRef value,
 
 bool mlir::tblgen::canFormatStringAsKeyword(
     StringRef value, function_ref<void(Twine)> emitError) {
+  if (value.empty()) {
+    if (emitError)
+      emitError("keywords cannot be empty");
+    return false;
+  }
   if (!isalpha(value.front()) && value.front() != '_') {
     if (emitError)
       emitError("valid keyword starts with a letter or '_'");