[mlir] EnumsGen: dissociate string form of integer enum from C++ symbol name
authorAlex Zinenko <zinenko@google.com>
Fri, 24 Jan 2020 16:51:10 +0000 (17:51 +0100)
committerAlex Zinenko <zinenko@google.com>
Thu, 30 Jan 2020 16:04:00 +0000 (17:04 +0100)
Summary:
In some cases, one may want to use different names for C++ symbol of an
enumerand from its string representation. In particular, in the LLVM dialect
for, e.g., Linkage, we would like to preserve the same enumerand names as LLVM
API and the same textual IR form as LLVM IR, yet the two are different
(CamelCase vs snake_case with additional limitations on not being a C++
keyword).

Modify EnumAttrCaseInfo in OpBase.td to include both the integer value and its
string representation. By default, this representation is the same as C++
symbol name. Introduce new IntStrAttrCaseBase that allows one to use different
names. Exercise it for LLVM Dialect Linkage attribute. Other attributes will
follow as separate changes.

Differential Revision: https://reviews.llvm.org/D73362

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/lib/TableGen/Attribute.cpp
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/unittests/TableGen/EnumsGenTest.cpp
mlir/unittests/TableGen/enums.td

index 987a77e..8982588 100644 (file)
@@ -62,4 +62,15 @@ class LLVM_Op<string mnemonic, list<OpTrait> traits = []> :
 class LLVM_IntrOp<string mnemonic, list<OpTrait> traits = []> :
     LLVM_Op<"intr."#mnemonic, traits>;
 
+// Case of the LLVM enum attribute backed by I64Attr with customized string
+// representation that corresponds to what is visible in the textual IR form.
+class LLVM_EnumAttrCase<string cppSym, string irSym, int val> :
+    I64EnumAttrCase<cppSym, val, irSym>;
+
+// LLVM enum attribute backed by I64Attr with string representation
+// corresponding to what is visible in the textual IR form.
+class LLVM_EnumAttr<string name, string description,
+                    list<LLVM_EnumAttrCase> cases> :
+    I64EnumAttr<name, description, cases>;
+
 #endif  // LLVMIR_OP_BASE
index 6b74ae3..42635e4 100644 (file)
@@ -494,18 +494,21 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
 // https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to
 // visible names in the IR rather than to enum values names in llvm::GlobalValue
 // since the latter is easier to change.
-def LinkagePrivate             : I64EnumAttrCase<"Private", 0>;
-def LinkageInternal            : I64EnumAttrCase<"Internal", 1>;
-def LinkageAvailableExternally : I64EnumAttrCase<"AvailableExternally", 2>;
-def LinkageLinkonce            : I64EnumAttrCase<"Linkonce", 3>;
-def LinkageWeak                : I64EnumAttrCase<"Weak", 4>;
-def LinkageCommon              : I64EnumAttrCase<"Common", 5>;
-def LinkageAppending           : I64EnumAttrCase<"Appending", 6>;
-def LinkageExternWeak          : I64EnumAttrCase<"ExternWeak", 7>;
-def LinkageLinkonceODR         : I64EnumAttrCase<"LinkonceODR", 8>;
-def LinkageWeakODR             : I64EnumAttrCase<"WeakODR", 9>;
-def LinkageExternal            : I64EnumAttrCase<"External", 10>;
-def Linkage : I64EnumAttr<
+def LinkagePrivate             : LLVM_EnumAttrCase<"Private", "private", 0>;
+def LinkageInternal            : LLVM_EnumAttrCase<"Internal", "internal", 1>;
+def LinkageAvailableExternally : LLVM_EnumAttrCase<"AvailableExternally",
+                                                   "available_externally", 2>;
+def LinkageLinkonce            : LLVM_EnumAttrCase<"Linkonce", "linkonce", 3>;
+def LinkageWeak                : LLVM_EnumAttrCase<"Weak", "weak", 4>;
+def LinkageCommon              : LLVM_EnumAttrCase<"Common", "common", 5>;
+def LinkageAppending           : LLVM_EnumAttrCase<"Appending", "appending", 6>;
+def LinkageExternWeak          : LLVM_EnumAttrCase<"ExternWeak",
+                                                   "extern_weak", 7>;
+def LinkageLinkonceODR         : LLVM_EnumAttrCase<"LinkonceODR",
+                                                   "linkonce_odr", 8>;
+def LinkageWeakODR             : LLVM_EnumAttrCase<"WeakODR", "weak_odr", 9>;
+def LinkageExternal            : LLVM_EnumAttrCase<"External", "external", 10>;
+def Linkage : LLVM_EnumAttr<
     "Linkage",
     "LLVM linkage types",
     [LinkagePrivate, LinkageInternal, LinkageAvailableExternally,
index ecf1df7..f5855d3 100644 (file)
@@ -808,39 +808,47 @@ def UnitAttr : Attr<CPred<"$_self.isa<UnitAttr>()">, "unit attribute"> {
 // Enum attribute kinds
 
 // Additional information for an enum attribute case.
-class EnumAttrCaseInfo<string sym, int val> {
-  // The C++ enumerant symbol
+class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
+  // The C++ enumerant symbol.
   string symbol = sym;
 
-  // The C++ enumerant value
+  // The C++ enumerant value.
   // If less than zero, there will be no explicit discriminator values assigned
   // to enumerators in the generated enum class.
-  int value = val;
+  int value = intVal;
+
+  // The string representation of the enumerant. May be the same as symbol.
+  string str = strVal;
 }
 
 // An enum attribute case stored with StringAttr.
 class StrEnumAttrCase<string sym, int val = -1> :
-    EnumAttrCaseInfo<sym, val>,
+    EnumAttrCaseInfo<sym, val, sym>,
     StringBasedAttr<
       CPred<"$_self.cast<StringAttr>().getValue() == \"" # sym # "\"">,
       "case " # sym>;
 
-// An enum attribute case stored with IntegerAttr.
-class IntEnumAttrCaseBase<I intType, string sym, int val> :
-    EnumAttrCaseInfo<sym, val>,
-    IntegerAttrBase<intType, "case " # sym> {
+// An enum attribute case stored with IntegerAttr, which has an integer value,
+// its representation as a string and a C++ symbol name which may be different.
+class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
+    EnumAttrCaseInfo<sym, intVal, strVal>,
+    IntegerAttrBase<intType, "case " # strVal> {
   let predicate =
-    CPred<"$_self.cast<IntegerAttr>().getInt() == " # val>;
+    CPred<"$_self.cast<IntegerAttr>().getInt() == " # intVal>;
 }
 
-class I32EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I32, sym, val>;
-class I64EnumAttrCase<string sym, int val> : IntEnumAttrCaseBase<I64, sym, val>;
+// Cases of integer enum attributes with a specific type. By default, the string
+// representation is the same as the C++ symbol name.
+class I32EnumAttrCase<string sym, int val, string str = sym>
+    : IntEnumAttrCaseBase<I32, sym, str, val>;
+class I64EnumAttrCase<string sym, int val, string str = sym>
+    : IntEnumAttrCaseBase<I64, sym, str, val>;
 
 // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
 // ordinal number of the bit that is set. It is the 32-bit integer with only
 // one bit set.
 class BitEnumAttrCase<string sym, int val> :
-    EnumAttrCaseInfo<sym, val>,
+    EnumAttrCaseInfo<sym, val, sym>,
     IntegerAttrBase<I32, "case " # sym> {
   let predicate = CPred<
     "$_self.cast<IntegerAttr>().getValue().getZExtValue() & " # val # "u">;
index 06ee892..cd41109 100644 (file)
@@ -134,6 +134,9 @@ public:
   // Returns the symbol of this enum attribute case.
   StringRef getSymbol() const;
 
+  // Returns the textual representation of this enum attribute case.
+  StringRef getStr() const;
+
   // Returns the value of this enum attribute case.
   int64_t getValue() const;
 
index b0bfbfc..c92c81e 100644 (file)
@@ -1079,44 +1079,8 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type,
   result.addRegion();
 }
 
-// Returns the textual representation of the given linkage.
-static StringRef linkageToStr(LLVM::Linkage linkage) {
-  switch (linkage) {
-  case LLVM::Linkage::Private:
-    return "private";
-  case LLVM::Linkage::Internal:
-    return "internal";
-  case LLVM::Linkage::AvailableExternally:
-    return "available_externally";
-  case LLVM::Linkage::Linkonce:
-    return "linkonce";
-  case LLVM::Linkage::Weak:
-    return "weak";
-  case LLVM::Linkage::Common:
-    return "common";
-  case LLVM::Linkage::Appending:
-    return "appending";
-  case LLVM::Linkage::ExternWeak:
-    return "extern_weak";
-  case LLVM::Linkage::LinkonceODR:
-    return "linkonce_odr";
-  case LLVM::Linkage::WeakODR:
-    return "weak_odr";
-  case LLVM::Linkage::External:
-    return "external";
-  }
-  llvm_unreachable("unknown linkage type");
-}
-
-// Prints the keyword for the linkage type using the printer.
-static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) {
-  p << linkageToStr(linkage);
-}
-
 static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) {
-  p << op.getOperationName() << ' ';
-  printLinkage(p, op.linkage());
-  p << ' ';
+  p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' ';
   if (op.constant())
     p << "constant ";
   p.printSymbolName(op.sym_name());
@@ -1150,22 +1114,30 @@ static int parseOptionalKeywordAlternative(OpAsmParser &parser,
   return -1;
 }
 
-// Parses one of the linkage keywords and, if succeeded, appends the "linkage"
-// integer attribute with the corresponding value to `result`.
-//
-// linkage ::= `private` | `internal` | `available_externally` | `linkonce`
-//           | `weak` | `common` | `appending` | `extern_weak`
-//           | `linkonce_odr` | `weak_odr` | `external
-static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser,
-                                               OperationState &result) {
-  int index = parseOptionalKeywordAlternative(
-      parser, {"private", "internal", "available_externally", "linkonce",
-               "weak", "common", "appending", "extern_weak", "linkonce_odr",
-               "weak_odr", "external"});
+namespace {
+template <typename Ty> struct EnumTraits {};
+
+#define REGISTER_ENUM_TYPE(Ty)                                                 \
+  template <> struct EnumTraits<Ty> {                                          \
+    static StringRef stringify(Ty value) { return stringify##Ty(value); }      \
+    static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); }         \
+  }
+
+REGISTER_ENUM_TYPE(Linkage);
+} // end namespace
+
+template <typename EnumTy>
+static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser,
+                                            OperationState &result,
+                                            StringRef name) {
+  SmallVector<StringRef, 10> names;
+  for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i)
+    names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
+
+  int index = parseOptionalKeywordAlternative(parser, names);
   if (index == -1)
     return failure();
-  result.addAttribute(getLinkageAttrName(),
-                      parser.getBuilder().getI64IntegerAttr(index));
+  result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index));
   return success();
 }
 
@@ -1175,7 +1147,8 @@ static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser,
 // The type can be omitted for string attributes, in which case it will be
 // inferred from the value of the string as [strlen(value) x i8].
 static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
-  if (failed(parseOptionalLinkageKeyword(parser, result)))
+  if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
+                                               getLinkageAttrName())))
     return parser.emitError(parser.getCurrentLocation(), "expected linkage");
 
   if (succeeded(parser.parseOptionalKeyword("constant")))
@@ -1398,7 +1371,8 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc,
 static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
                                    OperationState &result) {
   // Default to external linkage if no keyword is provided.
-  if (failed(parseOptionalLinkageKeyword(parser, result)))
+  if (failed(parseOptionalLLVMKeyword<Linkage>(parser, result,
+                                               getLinkageAttrName())))
     result.addAttribute(getLinkageAttrName(),
                         parser.getBuilder().getI64IntegerAttr(
                             static_cast<int64_t>(LLVM::Linkage::External)));
@@ -1441,10 +1415,8 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser,
 // the external linkage since it is the default value.
 static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) {
   p << op.getOperationName() << ' ';
-  if (op.linkage() != LLVM::Linkage::External) {
-    printLinkage(p, op.linkage());
-    p << ' ';
-  }
+  if (op.linkage() != LLVM::Linkage::External)
+    p << stringifyLinkage(op.linkage()) << ' ';
   p.printSymbolName(op.getName());
 
   LLVMType fnType = op.getType();
@@ -1510,16 +1482,16 @@ unsigned LLVMFuncOp::getNumFuncResults() {
 static LogicalResult verify(LLVMFuncOp op) {
   if (op.linkage() == LLVM::Linkage::Common)
     return op.emitOpError()
-           << "functions cannot have '" << linkageToStr(LLVM::Linkage::Common)
-           << "' linkage";
+           << "functions cannot have '"
+           << stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
 
   if (op.isExternal()) {
     if (op.linkage() != LLVM::Linkage::External &&
         op.linkage() != LLVM::Linkage::ExternWeak)
       return op.emitOpError()
              << "external functions must have '"
-             << linkageToStr(LLVM::Linkage::External) << "' or '"
-             << linkageToStr(LLVM::Linkage::ExternWeak) << "' linkage";
+             << stringifyLinkage(LLVM::Linkage::External) << "' or '"
+             << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
     return success();
   }
 
index 54ff6da..6f25c84 100644 (file)
@@ -154,6 +154,10 @@ StringRef tblgen::EnumAttrCase::getSymbol() const {
   return def->getValueAsString("symbol");
 }
 
+StringRef tblgen::EnumAttrCase::getStr() const {
+  return def->getValueAsString("str");
+}
+
 int64_t tblgen::EnumAttrCase::getValue() const {
   return def->getValueAsInt("value");
 }
index c25f5fe..a58ea24 100644 (file)
@@ -165,8 +165,9 @@ static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) {
   os << "  switch (val) {\n";
   for (const auto &enumerant : enumerants) {
     auto symbol = enumerant.getSymbol();
+    auto str = enumerant.getStr();
     os << formatv("    case {0}::{1}: return \"{2}\";\n", enumName,
-                  makeIdentifier(symbol), symbol);
+                  makeIdentifier(symbol), str);
   }
   os << "  }\n";
   os << "  return \"\";\n";
@@ -219,7 +220,8 @@ static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) {
                 enumName);
   for (const auto &enumerant : enumerants) {
     auto symbol = enumerant.getSymbol();
-    os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, symbol,
+    auto str = enumerant.getStr();
+    os << formatv("      .Case(\"{1}\", {0}::{2})\n", enumName, str,
                   makeIdentifier(symbol));
   }
   os << "      .Default(llvm::None);\n";
index 1eed177..0c3db73 100644 (file)
@@ -94,3 +94,21 @@ TEST(EnumsGenTest, GeneratedOperator) {
   EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3,
                                BitEnumWithNone::Bit1));
 }
+
+TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) {
+  EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case1), "case_one");
+  EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case2), "case_two");
+}
+
+TEST(EnumsGenTest, GeneratedCustomStringToSymbolFn) {
+  auto one = symbolizePrettyIntEnum("case_one");
+  EXPECT_TRUE(one);
+  EXPECT_EQ(*one, PrettyIntEnum::Case1);
+
+  auto two = symbolizePrettyIntEnum("case_two");
+  EXPECT_TRUE(two);
+  EXPECT_EQ(*two, PrettyIntEnum::Case2);
+
+  auto none = symbolizePrettyIntEnum("Case1");
+  EXPECT_FALSE(none);
+}
index 3a9dcc5..b2c8f6f 100644 (file)
@@ -31,3 +31,9 @@ def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum",
 
 def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum",
                                      [Bit1, Bit3]>;
+
+def PrettyIntEnumCase1: I32EnumAttrCase<"Case1", 1, "case_one">;
+def PrettyIntEnumCase2: I32EnumAttrCase<"Case2", 2, "case_two">;
+
+def PrettyIntEnum: I32EnumAttr<"PrettyIntEnum", "A test enum",
+                               [PrettyIntEnumCase1, PrettyIntEnumCase2]>;