From fdc496a3d30d2d82814965a6aa987b7ef0b136ef Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 24 Jan 2020 17:51:10 +0100 Subject: [PATCH] [mlir] EnumsGen: dissociate string form of integer enum from C++ symbol name Summary: In some cases, one may want to use different names for C++ symbol of an enumerand from its string representation. In particular, in the LLVM dialect for, e.g., Linkage, we would like to preserve the same enumerand names as LLVM API and the same textual IR form as LLVM IR, yet the two are different (CamelCase vs snake_case with additional limitations on not being a C++ keyword). Modify EnumAttrCaseInfo in OpBase.td to include both the integer value and its string representation. By default, this representation is the same as C++ symbol name. Introduce new IntStrAttrCaseBase that allows one to use different names. Exercise it for LLVM Dialect Linkage attribute. Other attributes will follow as separate changes. Differential Revision: https://reviews.llvm.org/D73362 --- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 11 +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 27 ++++---- mlir/include/mlir/IR/OpBase.td | 34 ++++++---- mlir/include/mlir/TableGen/Attribute.h | 3 + mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 94 +++++++++----------------- mlir/lib/TableGen/Attribute.cpp | 4 ++ mlir/tools/mlir-tblgen/EnumsGen.cpp | 6 +- mlir/unittests/TableGen/EnumsGenTest.cpp | 18 +++++ mlir/unittests/TableGen/enums.td | 6 ++ 9 files changed, 115 insertions(+), 88 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 987a77e..8982588 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -62,4 +62,15 @@ class LLVM_Op traits = []> : class LLVM_IntrOp traits = []> : LLVM_Op<"intr."#mnemonic, traits>; +// Case of the LLVM enum attribute backed by I64Attr with customized string +// representation that corresponds to what is visible in the textual IR form. +class LLVM_EnumAttrCase : + I64EnumAttrCase; + +// LLVM enum attribute backed by I64Attr with string representation +// corresponding to what is visible in the textual IR form. +class LLVM_EnumAttr cases> : + I64EnumAttr; + #endif // LLVMIR_OP_BASE diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 6b74ae3..42635e4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -494,18 +494,21 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { // https://llvm.org/docs/LangRef.html#linkage-types. The names are equivalent to // visible names in the IR rather than to enum values names in llvm::GlobalValue // since the latter is easier to change. -def LinkagePrivate : I64EnumAttrCase<"Private", 0>; -def LinkageInternal : I64EnumAttrCase<"Internal", 1>; -def LinkageAvailableExternally : I64EnumAttrCase<"AvailableExternally", 2>; -def LinkageLinkonce : I64EnumAttrCase<"Linkonce", 3>; -def LinkageWeak : I64EnumAttrCase<"Weak", 4>; -def LinkageCommon : I64EnumAttrCase<"Common", 5>; -def LinkageAppending : I64EnumAttrCase<"Appending", 6>; -def LinkageExternWeak : I64EnumAttrCase<"ExternWeak", 7>; -def LinkageLinkonceODR : I64EnumAttrCase<"LinkonceODR", 8>; -def LinkageWeakODR : I64EnumAttrCase<"WeakODR", 9>; -def LinkageExternal : I64EnumAttrCase<"External", 10>; -def Linkage : I64EnumAttr< +def LinkagePrivate : LLVM_EnumAttrCase<"Private", "private", 0>; +def LinkageInternal : LLVM_EnumAttrCase<"Internal", "internal", 1>; +def LinkageAvailableExternally : LLVM_EnumAttrCase<"AvailableExternally", + "available_externally", 2>; +def LinkageLinkonce : LLVM_EnumAttrCase<"Linkonce", "linkonce", 3>; +def LinkageWeak : LLVM_EnumAttrCase<"Weak", "weak", 4>; +def LinkageCommon : LLVM_EnumAttrCase<"Common", "common", 5>; +def LinkageAppending : LLVM_EnumAttrCase<"Appending", "appending", 6>; +def LinkageExternWeak : LLVM_EnumAttrCase<"ExternWeak", + "extern_weak", 7>; +def LinkageLinkonceODR : LLVM_EnumAttrCase<"LinkonceODR", + "linkonce_odr", 8>; +def LinkageWeakODR : LLVM_EnumAttrCase<"WeakODR", "weak_odr", 9>; +def LinkageExternal : LLVM_EnumAttrCase<"External", "external", 10>; +def Linkage : LLVM_EnumAttr< "Linkage", "LLVM linkage types", [LinkagePrivate, LinkageInternal, LinkageAvailableExternally, diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index ecf1df7..f5855d3 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -808,39 +808,47 @@ def UnitAttr : Attr()">, "unit attribute"> { // Enum attribute kinds // Additional information for an enum attribute case. -class EnumAttrCaseInfo { - // The C++ enumerant symbol +class EnumAttrCaseInfo { + // The C++ enumerant symbol. string symbol = sym; - // The C++ enumerant value + // The C++ enumerant value. // If less than zero, there will be no explicit discriminator values assigned // to enumerators in the generated enum class. - int value = val; + int value = intVal; + + // The string representation of the enumerant. May be the same as symbol. + string str = strVal; } // An enum attribute case stored with StringAttr. class StrEnumAttrCase : - EnumAttrCaseInfo, + EnumAttrCaseInfo, StringBasedAttr< CPred<"$_self.cast().getValue() == \"" # sym # "\"">, "case " # sym>; -// An enum attribute case stored with IntegerAttr. -class IntEnumAttrCaseBase : - EnumAttrCaseInfo, - IntegerAttrBase { +// An enum attribute case stored with IntegerAttr, which has an integer value, +// its representation as a string and a C++ symbol name which may be different. +class IntEnumAttrCaseBase : + EnumAttrCaseInfo, + IntegerAttrBase { let predicate = - CPred<"$_self.cast().getInt() == " # val>; + CPred<"$_self.cast().getInt() == " # intVal>; } -class I32EnumAttrCase : IntEnumAttrCaseBase; -class I64EnumAttrCase : IntEnumAttrCaseBase; +// Cases of integer enum attributes with a specific type. By default, the string +// representation is the same as the C++ symbol name. +class I32EnumAttrCase + : IntEnumAttrCaseBase; +class I64EnumAttrCase + : IntEnumAttrCaseBase; // A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the // ordinal number of the bit that is set. It is the 32-bit integer with only // one bit set. class BitEnumAttrCase : - EnumAttrCaseInfo, + EnumAttrCaseInfo, IntegerAttrBase { let predicate = CPred< "$_self.cast().getValue().getZExtValue() & " # val # "u">; diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 06ee892..cd41109 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -134,6 +134,9 @@ public: // Returns the symbol of this enum attribute case. StringRef getSymbol() const; + // Returns the textual representation of this enum attribute case. + StringRef getStr() const; + // Returns the value of this enum attribute case. int64_t getValue() const; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index b0bfbfc..c92c81e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1079,44 +1079,8 @@ void GlobalOp::build(Builder *builder, OperationState &result, LLVMType type, result.addRegion(); } -// Returns the textual representation of the given linkage. -static StringRef linkageToStr(LLVM::Linkage linkage) { - switch (linkage) { - case LLVM::Linkage::Private: - return "private"; - case LLVM::Linkage::Internal: - return "internal"; - case LLVM::Linkage::AvailableExternally: - return "available_externally"; - case LLVM::Linkage::Linkonce: - return "linkonce"; - case LLVM::Linkage::Weak: - return "weak"; - case LLVM::Linkage::Common: - return "common"; - case LLVM::Linkage::Appending: - return "appending"; - case LLVM::Linkage::ExternWeak: - return "extern_weak"; - case LLVM::Linkage::LinkonceODR: - return "linkonce_odr"; - case LLVM::Linkage::WeakODR: - return "weak_odr"; - case LLVM::Linkage::External: - return "external"; - } - llvm_unreachable("unknown linkage type"); -} - -// Prints the keyword for the linkage type using the printer. -static void printLinkage(OpAsmPrinter &p, LLVM::Linkage linkage) { - p << linkageToStr(linkage); -} - static void printGlobalOp(OpAsmPrinter &p, GlobalOp op) { - p << op.getOperationName() << ' '; - printLinkage(p, op.linkage()); - p << ' '; + p << op.getOperationName() << ' ' << stringifyLinkage(op.linkage()) << ' '; if (op.constant()) p << "constant "; p.printSymbolName(op.sym_name()); @@ -1150,22 +1114,30 @@ static int parseOptionalKeywordAlternative(OpAsmParser &parser, return -1; } -// Parses one of the linkage keywords and, if succeeded, appends the "linkage" -// integer attribute with the corresponding value to `result`. -// -// linkage ::= `private` | `internal` | `available_externally` | `linkonce` -// | `weak` | `common` | `appending` | `extern_weak` -// | `linkonce_odr` | `weak_odr` | `external -static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser, - OperationState &result) { - int index = parseOptionalKeywordAlternative( - parser, {"private", "internal", "available_externally", "linkonce", - "weak", "common", "appending", "extern_weak", "linkonce_odr", - "weak_odr", "external"}); +namespace { +template struct EnumTraits {}; + +#define REGISTER_ENUM_TYPE(Ty) \ + template <> struct EnumTraits { \ + static StringRef stringify(Ty value) { return stringify##Ty(value); } \ + static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \ + } + +REGISTER_ENUM_TYPE(Linkage); +} // end namespace + +template +static ParseResult parseOptionalLLVMKeyword(OpAsmParser &parser, + OperationState &result, + StringRef name) { + SmallVector names; + for (unsigned i = 0, e = getMaxEnumValForLinkage(); i <= e; ++i) + names.push_back(EnumTraits::stringify(static_cast(i))); + + int index = parseOptionalKeywordAlternative(parser, names); if (index == -1) return failure(); - result.addAttribute(getLinkageAttrName(), - parser.getBuilder().getI64IntegerAttr(index)); + result.addAttribute(name, parser.getBuilder().getI64IntegerAttr(index)); return success(); } @@ -1175,7 +1147,8 @@ static ParseResult parseOptionalLinkageKeyword(OpAsmParser &parser, // The type can be omitted for string attributes, in which case it will be // inferred from the value of the string as [strlen(value) x i8]. static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) { - if (failed(parseOptionalLinkageKeyword(parser, result))) + if (failed(parseOptionalLLVMKeyword(parser, result, + getLinkageAttrName()))) return parser.emitError(parser.getCurrentLocation(), "expected linkage"); if (succeeded(parser.parseOptionalKeyword("constant"))) @@ -1398,7 +1371,8 @@ static Type buildLLVMFunctionType(OpAsmParser &parser, llvm::SMLoc loc, static ParseResult parseLLVMFuncOp(OpAsmParser &parser, OperationState &result) { // Default to external linkage if no keyword is provided. - if (failed(parseOptionalLinkageKeyword(parser, result))) + if (failed(parseOptionalLLVMKeyword(parser, result, + getLinkageAttrName()))) result.addAttribute(getLinkageAttrName(), parser.getBuilder().getI64IntegerAttr( static_cast(LLVM::Linkage::External))); @@ -1441,10 +1415,8 @@ static ParseResult parseLLVMFuncOp(OpAsmParser &parser, // the external linkage since it is the default value. static void printLLVMFuncOp(OpAsmPrinter &p, LLVMFuncOp op) { p << op.getOperationName() << ' '; - if (op.linkage() != LLVM::Linkage::External) { - printLinkage(p, op.linkage()); - p << ' '; - } + if (op.linkage() != LLVM::Linkage::External) + p << stringifyLinkage(op.linkage()) << ' '; p.printSymbolName(op.getName()); LLVMType fnType = op.getType(); @@ -1510,16 +1482,16 @@ unsigned LLVMFuncOp::getNumFuncResults() { static LogicalResult verify(LLVMFuncOp op) { if (op.linkage() == LLVM::Linkage::Common) return op.emitOpError() - << "functions cannot have '" << linkageToStr(LLVM::Linkage::Common) - << "' linkage"; + << "functions cannot have '" + << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; if (op.isExternal()) { if (op.linkage() != LLVM::Linkage::External && op.linkage() != LLVM::Linkage::ExternWeak) return op.emitOpError() << "external functions must have '" - << linkageToStr(LLVM::Linkage::External) << "' or '" - << linkageToStr(LLVM::Linkage::ExternWeak) << "' linkage"; + << stringifyLinkage(LLVM::Linkage::External) << "' or '" + << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; return success(); } diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 54ff6da..6f25c84 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -154,6 +154,10 @@ StringRef tblgen::EnumAttrCase::getSymbol() const { return def->getValueAsString("symbol"); } +StringRef tblgen::EnumAttrCase::getStr() const { + return def->getValueAsString("str"); +} + int64_t tblgen::EnumAttrCase::getValue() const { return def->getValueAsInt("value"); } diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index c25f5fe..a58ea24 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -165,8 +165,9 @@ static void emitSymToStrFnForIntEnum(const Record &enumDef, raw_ostream &os) { os << " switch (val) {\n"; for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); + auto str = enumerant.getStr(); os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, - makeIdentifier(symbol), symbol); + makeIdentifier(symbol), str); } os << " }\n"; os << " return \"\";\n"; @@ -219,7 +220,8 @@ static void emitStrToSymFnForIntEnum(const Record &enumDef, raw_ostream &os) { enumName); for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); - os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, + auto str = enumerant.getStr(); + os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, str, makeIdentifier(symbol)); } os << " .Default(llvm::None);\n"; diff --git a/mlir/unittests/TableGen/EnumsGenTest.cpp b/mlir/unittests/TableGen/EnumsGenTest.cpp index 1eed177..0c3db73 100644 --- a/mlir/unittests/TableGen/EnumsGenTest.cpp +++ b/mlir/unittests/TableGen/EnumsGenTest.cpp @@ -94,3 +94,21 @@ TEST(EnumsGenTest, GeneratedOperator) { EXPECT_FALSE(bitEnumContains(BitEnumWithNone::Bit1 & BitEnumWithNone::Bit3, BitEnumWithNone::Bit1)); } + +TEST(EnumsGenTest, GeneratedSymbolToCustomStringFn) { + EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case1), "case_one"); + EXPECT_EQ(stringifyPrettyIntEnum(PrettyIntEnum::Case2), "case_two"); +} + +TEST(EnumsGenTest, GeneratedCustomStringToSymbolFn) { + auto one = symbolizePrettyIntEnum("case_one"); + EXPECT_TRUE(one); + EXPECT_EQ(*one, PrettyIntEnum::Case1); + + auto two = symbolizePrettyIntEnum("case_two"); + EXPECT_TRUE(two); + EXPECT_EQ(*two, PrettyIntEnum::Case2); + + auto none = symbolizePrettyIntEnum("Case1"); + EXPECT_FALSE(none); +} diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td index 3a9dcc5..b2c8f6f 100644 --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -31,3 +31,9 @@ def BitEnumWithNone : BitEnumAttr<"BitEnumWithNone", "A test enum", def BitEnumWithoutNone : BitEnumAttr<"BitEnumWithoutNone", "A test enum", [Bit1, Bit3]>; + +def PrettyIntEnumCase1: I32EnumAttrCase<"Case1", 1, "case_one">; +def PrettyIntEnumCase2: I32EnumAttrCase<"Case2", 2, "case_two">; + +def PrettyIntEnum: I32EnumAttr<"PrettyIntEnum", "A test enum", + [PrettyIntEnumCase1, PrettyIntEnumCase2]>; -- 2.7.4