[mlir][ods] Remove StrEnumAttr
authorMogball <jeffniu22@gmail.com>
Wed, 2 Mar 2022 18:00:05 +0000 (18:00 +0000)
committerMogball <jeffniu22@gmail.com>
Wed, 13 Apr 2022 17:49:02 +0000 (17:49 +0000)
StrEnumAttr has been deprecated in favour of EnumAttr, a solution based on AttrDef (https://reviews.llvm.org/D115181). This patch removes StrEnumAttr, along with all the custom ODS logic required to handle it.

See https://discourse.llvm.org/t/psa-stop-using-strenumattr-do-use-enumattr/5710 on how to transition to EnumAttr. In short,

```
// Before
def MyEnumAttr : StrEnumAttr<"MyEnum", "", [
  StrEnumAttrCase<"A">,
  StrEnumAttrCase<"B">
]>;

// After (pick an integer enum of your choice)
def MyEnum : I32EnumAttr<"MyEnum", "", [
  I32EnumAttrCase<"A", 0>,
  I32EnumAttrCase<"B", 1>
]> {
  // Don't generate a C++ class! We want to use the AttrDef
  let genSpecializedAttr = 0;
}
// Define the AttrDef
def MyEnum : EnumAttr<MyDialect, MyEnum, "my_enum">;
```

Reviewed By: rriddle, jpienaar

Differential Revision: https://reviews.llvm.org/D120834

13 files changed:
mlir/docs/OpDefinitions.md
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Attribute.h
mlir/lib/TableGen/Attribute.cpp
mlir/test/IR/attribute.mlir
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/DirectiveCommonGen.cpp
mlir/tools/mlir-tblgen/EnumsGen.cpp
mlir/tools/mlir-tblgen/OpFormatGen.cpp
mlir/tools/mlir-tblgen/RewriterGen.cpp
mlir/unittests/TableGen/EnumsGenTest.cpp
mlir/unittests/TableGen/enums.td

index 1f05bf6..b3aadaa 100644 (file)
@@ -1283,10 +1283,8 @@ optionality, default values, etc.:
 
 Some attributes can only take values from a predefined enum, e.g., the
 comparison kind of a comparison op. To define such attributes, ODS provides
-several mechanisms: `StrEnumAttr`, `IntEnumAttr`, and `BitEnumAttr`.
+several mechanisms: `IntEnumAttr`, and `BitEnumAttr`.
 
-*   `StrEnumAttr`: each enum case is a string, the attribute is stored as a
-    [`StringAttr`][StringAttr] in the op.
 *   `IntEnumAttr`: each enum case is an integer, the attribute is stored as a
     [`IntegerAttr`][IntegerAttr] in the op.
 *   `BitEnumAttr`: each enum case is a either the empty case, a single bit,
index 3479803..d38d16a 100644 (file)
@@ -1230,13 +1230,6 @@ class EnumAttrCaseInfo<string sym, int intVal, string strVal> {
   string str = strVal;
 }
 
-// An enum attribute case stored with StringAttr.
-class StrEnumAttrCase<string sym, int val = -1, string str = sym> :
-    EnumAttrCaseInfo<sym, val, str>,
-    StringBasedAttr<
-      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # str # "\"">,
-      "case " # str>;
-
 // An enum attribute case stored with IntegerAttr, which has an integer value,
 // its representation as a string and a C++ symbol name which may be different.
 class IntEnumAttrCaseBase<I intType, string sym, string strVal, int intVal> :
@@ -1393,22 +1386,6 @@ class EnumAttrInfo<
   let valueType = baseAttrClass.valueType;
 }
 
-// An enum attribute backed by StringAttr.
-//
-// Op attributes of this kind are stored as StringAttr. Extra verification will
-// be generated on the string though: only the symbols of the allowed cases are
-// permitted as the string value.
-class StrEnumAttr<string name, string summary, list<StrEnumAttrCase> cases> :
-  EnumAttrInfo<name, cases,
-    StringBasedAttr<
-      And<[StrAttr.predicate, Or<!foreach(case, cases, case.predicate)>]>,
-      !if(!empty(summary), "allowed string cases: " #
-          !interleave(!foreach(case, cases, "'" # case.symbol # "'"), ", "),
-          summary)>> {
-  // Disable specialized Attribute class for `StringAttr` backend by default.
-  let genSpecializedAttr = 0;
-}
-
 // An enum attribute backed by IntegerAttr.
 //
 // Op attributes of this kind are stored as IntegerAttr. Extra verification will
index 9e6165a..b74c664 100644 (file)
@@ -144,9 +144,6 @@ public:
   explicit EnumAttrCase(const llvm::Record *record);
   explicit EnumAttrCase(const llvm::DefInit *init);
 
-  // Returns true if this EnumAttrCase is backed by a StringAttr.
-  bool isStrCase() const;
-
   // Returns the symbol of this enum attribute case.
   StringRef getSymbol() const;
 
index ae9183f..1d2b8d3 100644 (file)
@@ -157,8 +157,6 @@ EnumAttrCase::EnumAttrCase(const llvm::Record *record) : Attribute(record) {
 EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
     : EnumAttrCase(init->getDef()) {}
 
-bool EnumAttrCase::isStrCase() const { return isSubClassOf("StrEnumAttrCase"); }
-
 StringRef EnumAttrCase::getSymbol() const {
   return def->getValueAsString("symbol");
 }
index 29235df..318168d 100644 (file)
@@ -346,29 +346,6 @@ func @string_attr_custom_type() {
 // -----
 
 //===----------------------------------------------------------------------===//
-// Test StrEnumAttr
-//===----------------------------------------------------------------------===//
-
-// CHECK-LABEL: func @allowed_cases_pass
-func @allowed_cases_pass() {
-  // CHECK: test.str_enum_attr
-  %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
-  // CHECK: test.str_enum_attr
-  %1 = "test.str_enum_attr"() {attr = "B"} : () -> i32
-  return
-}
-
-// -----
-
-func @disallowed_case_fail() {
-  // expected-error @+1 {{allowed string cases: 'A', 'B'}}
-  %0 = "test.str_enum_attr"() {attr = 7: i32} : () -> i32
-  return
-}
-
-// -----
-
-//===----------------------------------------------------------------------===//
 // Test I32EnumAttr
 //===----------------------------------------------------------------------===//
 
index bccca92..85dc1b2 100644 (file)
@@ -191,17 +191,6 @@ def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
   let assemblyFormat = "$attr attr-dict";
 }
 
-def StrCaseA: StrEnumAttrCase<"A">;
-def StrCaseB: StrEnumAttrCase<"B">;
-
-def SomeStrEnum: StrEnumAttr<
-  "SomeStrEnum", "", [StrCaseA, StrCaseB]>;
-
-def StrEnumAttrOp : TEST_Op<"str_enum_attr"> {
-  let arguments = (ins SomeStrEnum:$attr);
-  let results = (outs I32:$val);
-}
-
 def I32Case5:  I32EnumAttrCase<"case5", 5>;
 def I32Case10: I32EnumAttrCase<"case10", 10>;
 
@@ -1260,8 +1249,6 @@ def : Pat<(OpAttrMatch3 $attr), (OpAttrMatch4 ConstUnitAttr, $attr)>;
 def OpC : TEST_Op<"op_c">, Arguments<(ins I32)>, Results<(outs I32)>;
 def : Pat<(OpC $input), (OpB $input, ConstantAttr<I32Attr, "17">:$attr)>;
 
-// Test string enum attribute in rewrites.
-def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
 // Test integer enum attribute in rewrites.
 def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
 def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
@@ -1568,11 +1555,8 @@ def : Pat<(SourceOp $val, ConstantAttr<I32Attr, "66">:$attr),
 // Test Legalization
 //===----------------------------------------------------------------------===//
 
-def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">;
-def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">;
-
-def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure",
-  [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>;
+def Test_LegalizerEnum_Success : ConstantStrAttr<StrAttr, "Success">;
+def Test_LegalizerEnum_Failure : ConstantStrAttr<StrAttr, "Failure">;
 
 def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32)>;
 def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32)>;
@@ -1582,7 +1566,7 @@ def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
 def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
 def ILLegalOpG : TEST_Op<"illegal_op_g">, Results<(outs I32)>;
 def LegalOpA : TEST_Op<"legal_op_a">,
-  Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
+  Arguments<(ins StrAttr:$status)>, Results<(outs I32)>;
 def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;
 def LegalOpC : TEST_Op<"legal_op_c">,
   Arguments<(ins I32)>, Results<(outs I32)>;
index 9d64865..a6525ec 100644 (file)
@@ -356,13 +356,6 @@ func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) {
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: verifyStrEnumAttr
-func @verifyStrEnumAttr() -> i32 {
-  // CHECK: "test.str_enum_attr"() {attr = "B"}
-  %0 = "test.str_enum_attr"() {attr = "A"} : () -> i32
-  return %0 : i32
-}
-
 // CHECK-LABEL: verifyI32EnumAttr
 func @verifyI32EnumAttr() -> i32 {
   // CHECK: "test.i32_enum_attr"() {attr = 10 : i32}
index a5c8934..337b6a5 100644 (file)
@@ -35,7 +35,7 @@ using llvm::RecordKeeper;
 // declarations, functions etc.
 //
 // Some OpenMP/OpenACC clauses accept only a fixed set of values as inputs.
-// These can be represented as a String Enum Attribute (StrEnumAttr) in MLIR
+// These can be represented as a Enum Attributes (EnumAttrDef) in MLIR
 // ODS. The emitDecls function below currently generates these enumerations. The
 // name of the enumeration is specified in the enumClauseValue field of
 // Clause record in OMP.td. This name can be used to specify the type of the
index 3365ff0..15b1719 100644 (file)
@@ -314,8 +314,6 @@ static void emitUnderlyingToSymFnForIntEnum(const Record &enumDef,
 static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
   EnumAttr enumAttr(enumDef);
   StringRef enumName = enumAttr.getEnumClassName();
-  StringRef symToStrFnName = enumAttr.getSymbolToStringFnName();
-  StringRef strToSymFnName = enumAttr.getStringToSymbolFnName();
   StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
   llvm::Record *baseAttrDef = enumAttr.getBaseAttrClass();
   Attribute baseAttr(baseAttrDef);
@@ -341,28 +339,22 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
   os << formatv("{0} {0}::get(::mlir::MLIRContext *context, {1} val) {{\n",
                 attrClassName, enumName);
 
-  if (enumAttr.isSubClassOf("StrEnumAttr")) {
-    os << formatv("  ::mlir::StringAttr baseAttr = "
-                  "::mlir::StringAttr::get(context, {0}(val));\n",
-                  symToStrFnName);
-  } else {
-    StringRef underlyingType = enumAttr.getUnderlyingType();
-
-    // Assuming that it is IntegerAttr constraint
-    int64_t bitwidth = 64;
-    if (baseAttrDef->getValue("valueType")) {
-      auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
-      if (valueTypeDef->getValue("bitwidth"))
-        bitwidth = valueTypeDef->getValueAsInt("bitwidth");
-    }
+  StringRef underlyingType = enumAttr.getUnderlyingType();
 
-    os << formatv("  ::mlir::IntegerType intType = "
-                  "::mlir::IntegerType::get(context, {0});\n",
-                  bitwidth);
-    os << formatv("  ::mlir::IntegerAttr baseAttr = "
-                  "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
-                  underlyingType);
+  // Assuming that it is IntegerAttr constraint
+  int64_t bitwidth = 64;
+  if (baseAttrDef->getValue("valueType")) {
+    auto *valueTypeDef = baseAttrDef->getValueAsDef("valueType");
+    if (valueTypeDef->getValue("bitwidth"))
+      bitwidth = valueTypeDef->getValueAsInt("bitwidth");
   }
+
+  os << formatv("  ::mlir::IntegerType intType = "
+                "::mlir::IntegerType::get(context, {0});\n",
+                bitwidth);
+  os << formatv("  ::mlir::IntegerAttr baseAttr = "
+                "::mlir::IntegerAttr::get(intType, static_cast<{0}>(val));\n",
+                underlyingType);
   os << formatv("  return baseAttr.cast<{0}>();\n", attrClassName);
 
   os << "}\n";
@@ -371,14 +363,8 @@ static void emitSpecializedAttrDef(const Record &enumDef, raw_ostream &os) {
 
   os << formatv("{0} {1}::getValue() const {{\n", enumName, attrClassName);
 
-  if (enumAttr.isSubClassOf("StrEnumAttr")) {
-    os << formatv("  const auto res = {0}(::mlir::StringAttr::getValue());\n",
-                  strToSymFnName);
-    os << "  return res.getValue();\n";
-  } else {
-    os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
-                  enumName);
-  }
+  os << formatv("  return static_cast<{0}>(::mlir::IntegerAttr::getInt());\n",
+                enumName);
 
   os << "}\n";
 }
@@ -483,8 +469,7 @@ public:
 )";
   if (enumAttr.genSpecializedAttr()) {
     StringRef attrClassName = enumAttr.getSpecializedAttrClassName();
-    StringRef baseAttrClassName =
-        enumAttr.isSubClassOf("StrEnumAttr") ? "StringAttr" : "IntegerAttr";
+    StringRef baseAttrClassName = "IntegerAttr";
     os << formatv(attrClassDecl, enumName, attrClassName, baseAttrClassName);
   }
 
index fb54dcb..d6976e0 100644 (file)
@@ -1797,30 +1797,11 @@ static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op,
   // Get a string containing all of the cases that can't be represented with a
   // keyword.
   BitVector nonKeywordCases(cases.size());
-  bool hasStrCase = false;
   for (auto &it : llvm::enumerate(cases)) {
-    hasStrCase = it.value().isStrCase();
     if (!canFormatStringAsKeyword(it.value().getStr()))
       nonKeywordCases.set(it.index());
   }
 
-  // If this is a string enum, use the case string to determine which cases
-  // need to use the string form.
-  if (hasStrCase) {
-    if (nonKeywordCases.any()) {
-      body << "    if (llvm::is_contained(llvm::ArrayRef<llvm::StringRef>(";
-      llvm::interleaveComma(nonKeywordCases.set_bits(), body, [&](unsigned it) {
-        body << '"' << cases[it].getStr() << '"';
-      });
-      body << ")))\n"
-              "      _odsPrinter << '\"' << caseValueStr << '\"';\n"
-              "    else\n  ";
-    }
-    body << "    _odsPrinter << caseValueStr;\n"
-            "  }\n";
-    return;
-  }
-
   // Otherwise if this is a bit enum attribute, don't allow cases that may
   // overlap with other cases. For simplicity sake, only allow cases with a
   // single bit value.
index 0159328..d7dfa73 100644 (file)
@@ -1221,8 +1221,6 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
   }
   if (leaf.isEnumAttrCase()) {
     auto enumCase = leaf.getAsEnumAttrCase();
-    if (enumCase.isStrCase())
-      return handleConstantAttr(enumCase, "\"" + enumCase.getSymbol() + "\"");
     // This is an enum case backed by an IntegerAttr. We need to get its value
     // to build the constant.
     std::string val = std::to_string(enumCase.getValue());
index 82dbe11..a5819c5 100644 (file)
 /// Test namespaces and enum class/utility names.
 using Outer::Inner::ConvertToEnum;
 using Outer::Inner::ConvertToString;
-using Outer::Inner::StrEnum;
-using Outer::Inner::StrEnumAttr;
+using Outer::Inner::FooEnum;
+using Outer::Inner::FooEnumAttr;
 
 TEST(EnumsGenTest, GeneratedStrEnumDefinition) {
-  EXPECT_EQ(0u, static_cast<uint64_t>(StrEnum::CaseA));
-  EXPECT_EQ(10u, static_cast<uint64_t>(StrEnum::CaseB));
+  EXPECT_EQ(0u, static_cast<uint64_t>(FooEnum::CaseA));
+  EXPECT_EQ(1u, static_cast<uint64_t>(FooEnum::CaseB));
 }
 
 TEST(EnumsGenTest, GeneratedI32EnumDefinition) {
@@ -41,23 +41,23 @@ TEST(EnumsGenTest, GeneratedI32EnumDefinition) {
 }
 
 TEST(EnumsGenTest, GeneratedDenseMapInfo) {
-  llvm::DenseMap<StrEnum, std::string> myMap;
+  llvm::DenseMap<FooEnum, std::string> myMap;
 
-  myMap[StrEnum::CaseA] = "zero";
-  myMap[StrEnum::CaseB] = "one";
+  myMap[FooEnum::CaseA] = "zero";
+  myMap[FooEnum::CaseB] = "one";
 
-  EXPECT_EQ(myMap[StrEnum::CaseA], "zero");
-  EXPECT_EQ(myMap[StrEnum::CaseB], "one");
+  EXPECT_EQ(myMap[FooEnum::CaseA], "zero");
+  EXPECT_EQ(myMap[FooEnum::CaseB], "one");
 }
 
 TEST(EnumsGenTest, GeneratedSymbolToStringFn) {
-  EXPECT_EQ(ConvertToString(StrEnum::CaseA), "CaseA");
-  EXPECT_EQ(ConvertToString(StrEnum::CaseB), "CaseB");
+  EXPECT_EQ(ConvertToString(FooEnum::CaseA), "CaseA");
+  EXPECT_EQ(ConvertToString(FooEnum::CaseB), "CaseB");
 }
 
 TEST(EnumsGenTest, GeneratedStringToSymbolFn) {
-  EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseA), ConvertToEnum("CaseA"));
-  EXPECT_EQ(llvm::Optional<StrEnum>(StrEnum::CaseB), ConvertToEnum("CaseB"));
+  EXPECT_EQ(llvm::Optional<FooEnum>(FooEnum::CaseA), ConvertToEnum("CaseA"));
+  EXPECT_EQ(llvm::Optional<FooEnum>(FooEnum::CaseB), ConvertToEnum("CaseB"));
   EXPECT_EQ(llvm::None, ConvertToEnum("X"));
 }
 
@@ -155,19 +155,6 @@ TEST(EnumsGenTest, GeneratedIntAttributeClass) {
   EXPECT_EQ(intAttr, enumAttr);
 }
 
-TEST(EnumsGenTest, GeneratedStringAttributeClass) {
-  mlir::MLIRContext ctx;
-  StrEnum rawVal = StrEnum::CaseA;
-
-  StrEnumAttr enumAttr = StrEnumAttr::get(&ctx, rawVal);
-  EXPECT_NE(enumAttr, nullptr);
-  EXPECT_EQ(enumAttr.getValue(), rawVal);
-
-  mlir::Attribute strAttr = mlir::StringAttr::get(&ctx, "CaseA");
-  EXPECT_TRUE(strAttr.isa<StrEnumAttr>());
-  EXPECT_EQ(strAttr, enumAttr);
-}
-
 TEST(EnumsGenTest, GeneratedBitAttributeClass) {
   mlir::MLIRContext ctx;
 
index 142f414..dcc2313 100644 (file)
@@ -8,10 +8,10 @@
 
 include "mlir/IR/OpBase.td"
 
-def CaseA: StrEnumAttrCase<"CaseA">;
-def CaseB: StrEnumAttrCase<"CaseB", 10>;
+def CaseA: I32EnumAttrCase<"CaseA", 0>;
+def CaseB: I32EnumAttrCase<"CaseB", 1>;
 
-def StrEnum: StrEnumAttr<"StrEnum", "A test enum", [CaseA, CaseB]> {
+def FooEnum: I32EnumAttr<"FooEnum", "A test enum", [CaseA, CaseB]> {
   let cppNamespace = "Outer::Inner";
   let stringToSymbolFnName = "ConvertToEnum";
   let symbolToStringFnName = "ConvertToString";