StrEnumAttr has been deprecated in favour of EnumAttr, a solution based on AttrDef (https://reviews.llvm.org/D115181). This patch removes StrEnumAttr, along with all the custom ODS logic required to handle it.
See https://discourse.llvm.org/t/psa-stop-using-strenumattr-do-use-enumattr/5710 on how to transition to EnumAttr. In short,
```
// Before
def MyEnumAttr : StrEnumAttr<"MyEnum", "", [
StrEnumAttrCase<"A">,
StrEnumAttrCase<"B">
]>;
// After (pick an integer enum of your choice)
def MyEnum : I32EnumAttr<"MyEnum", "", [
I32EnumAttrCase<"A", 0>,
I32EnumAttrCase<"B", 1>
]> {
// Don't generate a C++ class! We want to use the AttrDef
let genSpecializedAttr = 0;
}
// Define the AttrDef
def MyEnum : EnumAttr<MyDialect, MyEnum, "my_enum">;
```
Reviewed By: rriddle, jpienaar
Differential Revision: https://reviews.llvm.org/D120834
Some attributes can only take values from a predefined enum, e.g., the
comparison kind of a comparison op. To define such attributes, ODS provides
-several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`.
+several mechanisms: `IntEnumAttr`, and `BitEnumAttr`.
-* `StrEnumAttr`: each enum case is a string, the attribute is stored as a
- [`StringAttr`][StringAttr] in the op.
* `IntEnumAttr`: each enum case is an integer, the attribute is stored as a
[`IntegerAttr`][IntegerAttr] in the op.
* `BitEnumAttr`: each enum case is a either the empty case, a single bit,
string str = strVal;
}
-// An enum attribute case stored with StringAttr.
-class StrEnumAttrCase<string sym, int val = -1, string str = sym> :
- EnumAttrCaseInfo<sym, val, str>,
- StringBasedAttr<
- CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
- "case " # str>;
-
// An enum attribute case stored with IntegerAttr, which has an integer value,
// its representation as a string and a C++ symbol name which may be different.
class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
let valueType = baseAttrClass.valueType;
}
-// An enum attribute backed by StringAttr.
-//
-// Op attributes of this kind are stored as StringAttr. Extra verification will
-// be generated on the string though: only the symbols of the allowed cases are
-// permitted as the string value.
-class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> :
- EnumAttrInfo<name, cases,
- StringBasedAttr<
- And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
- !if(!empty(summary), "allowed string cases: " #
- !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "),
- summary)>> {
- // Disable specialized Attribute class for `StringAttr` backend by default.
- let genSpecializedAttr = 0;
-}
-
// An enum attribute backed by IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
explicit EnumAttrCase(const llvm::Record *record);
explicit EnumAttrCase(const llvm::DefInit *init);
- // Returns true if this EnumAttrCase is backed by a StringAttr.
- bool isStrCase() const;
-
// Returns the symbol of this enum attribute case.
StringRef getSymbol() const;
EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
: EnumAttrCase(init->getDef()) {}
-bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
-
StringRef EnumAttrCase::getSymbol() const {
return def->getValueAsString("symbol");
}
// -----
//===----------------------------------------------------------------------===//
-// Test StrEnumAttr
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @allowed_cases_pass
-func @allowed_cases_pass() {
- // CHECK: test.str_enum_attr
- %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
- // CHECK: test.str_enum_attr
- %1 = "test.str_enum_attr"() {attr = "B"} : () -> i32
- return
-}
-
-// -----
-
-func @disallowed_case_fail() {
- // expected-error @+1 {{allowed string cases: 'A', 'B'}}
- %0 = "test.str_enum_attr"() {attr = 7: i32} : () -> i32
- return
-}
-
-// -----
-
-//===----------------------------------------------------------------------===//
// Test I32EnumAttr
//===----------------------------------------------------------------------===//
let assemblyFormat = "$attr attr-dict";
}
-def StrCaseA: StrEnumAttrCase<"A">;
-def StrCaseB: StrEnumAttrCase<"B">;
-
-def SomeStrEnum: StrEnumAttr<
- "SomeStrEnum", "", [StrCaseA, StrCaseB]>;
-
-def StrEnumAttrOp : TEST_Op<"str_enum_attr"> {
- let arguments = (ins SomeStrEnum:$attr);
- let results = (outs I32:$val);
-}
-
def I32Case5: I32EnumAttrCase<"case5", 5>;
def I32Case10: I32EnumAttrCase<"case10", 10>;
def OpC : TEST_Op<"op_c">, Arguments<(ins I32)>, Results<(outs I32)>;
def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
-// Test string enum attribute in rewrites.
-def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
// Test integer enum attribute in rewrites.
def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
// Test Legalization
//===----------------------------------------------------------------------===//
-def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">;
-def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">;
-
-def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure",
- [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>;
+def Test_LegalizerEnum_Success : ConstantStrAttr<StrAttr, "Success">;
+def Test_LegalizerEnum_Failure : ConstantStrAttr<StrAttr, "Failure">;
def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32)>;
def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32)>;
def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>;
def LegalOpA : TEST_Op<"legal_op_a">,
- Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
+ Arguments<(ins StrAttr:$status)>, Results<(outs I32)>;
def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
def LegalOpC : TEST_Op<"legal_op_c">,
Arguments<(ins I32)>, Results<(outs I32)>;
// Test Enum Attributes
//===----------------------------------------------------------------------===//
-// CHECK-LABEL: verifyStrEnumAttr
-func @verifyStrEnumAttr() -> i32 {
- // CHECK: "test.str_enum_attr"() {attr = "B"}
- %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
- return %0 : i32
-}
-
// CHECK-LABEL: verifyI32EnumAttr
func @verifyI32EnumAttr() -> i32 {
// CHECK: "test.i32_enum_attr"() {attr = 10 : i32}
// declarations, functions etc.
//
// Some OpenMP/OpenACC clauses accept only a fixed set of values as inputs.
-// These can be represented as a String Enum Attribute (StrEnumAttr) in MLIR
+// These can be represented as a Enum Attributes (EnumAttrDef) in MLIR
// ODS. The emitDecls function below currently generates these enumerations. The
// name of the enumeration is specified in the enumClauseValue field of
// Clause record in OMP.td. This name can be used to specify the type of the
static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
EnumAttr enumAttr(enumDef);
StringRef enumName = enumAttr.getEnumClassName();
- StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
- StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
Attribute baseAttr(baseAttrDef);
os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
attrClassName, enumName);
- if (enumAttr.isSubClassOf("StrEnumAttr")) {
- os << formatv(" ::mlir::StringAttr baseAttr = "
- "::mlir::StringAttr::get(context, {0}(val));\n",
- symToStrFnName);
- } else {
- StringRef underlyingType = enumAttr.getUnderlyingType();
-
- // Assuming that it is IntegerAttr constraint
- int64_t bitwidth = 64;
- if (baseAttrDef->getValue("valueType")) {
- auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
- if (valueTypeDef->getValue("bitwidth"))
- bitwidth = valueTypeDef->getValueAsInt("bitwidth");
- }
+ StringRef underlyingType = enumAttr.getUnderlyingType();
- os << formatv(" ::mlir::IntegerType intType = "
- "::mlir::IntegerType::get(context, {0});\n",
- bitwidth);
- os << formatv(" ::mlir::IntegerAttr baseAttr = "
- "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
- underlyingType);
+ // Assuming that it is IntegerAttr constraint
+ int64_t bitwidth = 64;
+ if (baseAttrDef->getValue("valueType")) {
+ auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
+ if (valueTypeDef->getValue("bitwidth"))
+ bitwidth = valueTypeDef->getValueAsInt("bitwidth");
}
+
+ os << formatv(" ::mlir::IntegerType intType = "
+ "::mlir::IntegerType::get(context, {0});\n",
+ bitwidth);
+ os << formatv(" ::mlir::IntegerAttr baseAttr = "
+ "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
+ underlyingType);
os << formatv(" return baseAttr.cast<{0}>();\n", attrClassName);
os << "}\n";
os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
- if (enumAttr.isSubClassOf("StrEnumAttr")) {
- os << formatv(" const auto res = {0}(::mlir::StringAttr::getValue());\n",
- strToSymFnName);
- os << " return res.getValue();\n";
- } else {
- os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
- enumName);
- }
+ os << formatv(" return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
+ enumName);
os << "}\n";
}
)";
if (enumAttr.genSpecializedAttr()) {
StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
- StringRef baseAttrClassName =
- enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr";
+ StringRef baseAttrClassName = "IntegerAttr";
os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
}
// Get a string containing all of the cases that can't be represented with a
// keyword.
BitVector nonKeywordCases(cases.size());
- bool hasStrCase = false;
for (auto &it : llvm::enumerate(cases)) {
- hasStrCase = it.value().isStrCase();
if (!canFormatStringAsKeyword(it.value().getStr()))
nonKeywordCases.set(it.index());
}
- // If this is a string enum, use the case string to determine which cases
- // need to use the string form.
- if (hasStrCase) {
- if (nonKeywordCases.any()) {
- body << " if (llvm::is_contained(llvm::ArrayRef<llvm::StringRef>(";
- llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) {
- body << '"' << cases[it].getStr() << '"';
- });
- body << ")))\n"
- " _odsPrinter << '\"' << caseValueStr << '\"';\n"
- " else\n ";
- }
- body << " _odsPrinter << caseValueStr;\n"
- " }\n";
- return;
- }
-
// Otherwise if this is a bit enum attribute, don't allow cases that may
// overlap with other cases. For simplicity sake, only allow cases with a
// single bit value.
}
if (leaf.isEnumAttrCase()) {
auto enumCase = leaf.getAsEnumAttrCase();
- if (enumCase.isStrCase())
- return handleConstantAttr(enumCase, "\"" + enumCase.getSymbol() + "\"");
// This is an enum case backed by an IntegerAttr. We need to get its value
// to build the constant.
std::string val = std::to_string(enumCase.getValue());
/// Test namespaces and enum class/utility names.
using Outer::Inner::ConvertToEnum;
using Outer::Inner::ConvertToString;
-using Outer::Inner::StrEnum;
-using Outer::Inner::StrEnumAttr;
+using Outer::Inner::FooEnum;
+using Outer::Inner::FooEnumAttr;
TEST(EnumsGenTest, GeneratedStrEnumDefinition) {
- EXPECT_EQ(0u, static_cast<uint64_t>(StrEnum::CaseA));
- EXPECT_EQ(10u, static_cast<uint64_t>(StrEnum::CaseB));
+ EXPECT_EQ(0u, static_cast<uint64_t>(FooEnum::CaseA));
+ EXPECT_EQ(1u, static_cast<uint64_t>(FooEnum::CaseB));
}
TEST(EnumsGenTest, GeneratedI32EnumDefinition) {
}
TEST(EnumsGenTest, GeneratedDenseMapInfo) {
- llvm::DenseMap<StrEnum, std::string> myMap;
+ llvm::DenseMap<FooEnum, std::string> myMap;
- myMap[StrEnum::CaseA] = "zero";
- myMap[StrEnum::CaseB] = "one";
+ myMap[FooEnum::CaseA] = "zero";
+ myMap[FooEnum::CaseB] = "one";
- EXPECT_EQ(myMap[StrEnum::CaseA], "zero");
- EXPECT_EQ(myMap[StrEnum::CaseB], "one");
+ EXPECT_EQ(myMap[FooEnum::CaseA], "zero");
+ EXPECT_EQ(myMap[FooEnum::CaseB], "one");
}
TEST(EnumsGenTest, GeneratedSymbolToStringFn) {
- EXPECT_EQ(ConvertToString(StrEnum::CaseA), "CaseA");
- EXPECT_EQ(ConvertToString(StrEnum::CaseB), "CaseB");
+ EXPECT_EQ(ConvertToString(FooEnum::CaseA), "CaseA");
+ EXPECT_EQ(ConvertToString(FooEnum::CaseB), "CaseB");
}
TEST(EnumsGenTest, GeneratedStringToSymbolFn) {
- EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseA), ConvertToEnum("CaseA"));
- EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseB), ConvertToEnum("CaseB"));
+ EXPECT_EQ(llvm::Optional<FooEnum>(FooEnum::CaseA), ConvertToEnum("CaseA"));
+ EXPECT_EQ(llvm::Optional<FooEnum>(FooEnum::CaseB), ConvertToEnum("CaseB"));
EXPECT_EQ(llvm::None, ConvertToEnum("X"));
}
EXPECT_EQ(intAttr, enumAttr);
}
-TEST(EnumsGenTest, GeneratedStringAttributeClass) {
- mlir::MLIRContext ctx;
- StrEnum rawVal = StrEnum::CaseA;
-
- StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal);
- EXPECT_NE(enumAttr, nullptr);
- EXPECT_EQ(enumAttr.getValue(), rawVal);
-
- mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA");
- EXPECT_TRUE(strAttr.isa<StrEnumAttr>());
- EXPECT_EQ(strAttr, enumAttr);
-}
-
TEST(EnumsGenTest, GeneratedBitAttributeClass) {
mlir::MLIRContext ctx;
include "mlir/IR/OpBase.td"
-def CaseA: StrEnumAttrCase<"CaseA">;
-def CaseB: StrEnumAttrCase<"CaseB", 10>;
+def CaseA: I32EnumAttrCase<"CaseA", 0>;
+def CaseB: I32EnumAttrCase<"CaseB", 1>;
-def StrEnum: StrEnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> {
+def FooEnum: I32EnumAttr<"FooEnum", "A test enum", [CaseA, CaseB]> {
let cppNamespace = "Outer::Inner";
let stringToSymbolFnName = "ConvertToEnum";
let symbolToStringFnName = "ConvertToString";