[TableGen] Add EnumAttrCase and EnumAttr
authorLei Zhang <antiagainst@google.com>
Mon, 1 Apr 2019 15:58:53 +0000 (08:58 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 1 Apr 2019 17:59:31 +0000 (10:59 -0700)
    This CL adds EnumAttr as a general mechanism for modelling enum attributes. Right now
    it is using StringAttr under the hood since MLIR does not have native support for enum
    attributes.

--

PiperOrigin-RevId: 241334043

mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/TableGen/Attribute.h
mlir/include/mlir/TableGen/Pattern.h
mlir/lib/TableGen/Attribute.cpp
mlir/lib/TableGen/Pattern.cpp
mlir/test/mlir-tblgen/attr-enum.td [new file with mode: 0644]
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 590823f..3259346 100644 (file)
@@ -445,6 +445,26 @@ class StringBasedAttr<Pred condition, string descr> :
 
 def StrAttr : StringBasedAttr<CPred<"true">, "string">;
 
+// An enum attribute case.
+class EnumAttrCase<string sym> : StringBasedAttr<
+    CPred<"{0}.cast<StringAttr>().getValue() == \"" # sym # "\"">,
+    "case " # sym> {
+  // The C++ enumerant symbol
+  string symbol = sym;
+}
+
+// An enum attribute. Its value can only be one from the given list of `cases`.
+// Enum attributes are emulated via mlir::StringAttr, plus extra verification
+// on the string: only the symbols of the allowed cases are permitted as the
+// string value.
+class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
+    StringBasedAttr<AnyOf<!foreach(case, cases, case.predicate)>, description> {
+  // The C++ enum class name
+  string className = name;
+  // List of all accepted cases
+  list<EnumAttrCase> enumerants = cases;
+}
+
 class ElementsAttrBase<Pred condition, string description> :
     Attr<condition, description> {
   let storageType = [{ ElementsAttr }];
index 7cd5189..635d714 100644 (file)
@@ -120,6 +120,32 @@ private:
   const llvm::Record *def;
 };
 
+// Wrapper class providing helper methods for accessing enum attribute cases
+// defined in TableGen. This class should closely reflect what is defined as
+// class `EnumAttrCase` in TableGen.
+class EnumAttrCase : public Attribute {
+public:
+  explicit EnumAttrCase(const llvm::DefInit *init);
+
+  // Returns the symbol of this enum attribute case.
+  StringRef getSymbol() const;
+};
+
+// Wrapper class providing helper methods for accessing enum attributes defined
+// in TableGen. This class should closely reflect what is defined as class
+// `EnumAttr` in TableGen.
+class EnumAttr : public Attribute {
+public:
+  explicit EnumAttr(const llvm::Record *record);
+  explicit EnumAttr(const llvm::DefInit *init);
+
+  // Returns the enum class name.
+  StringRef getEnumClassName() const;
+
+  // Returns all allowed cases for this enum attribute.
+  std::vector<EnumAttrCase> getAllCases() const;
+};
+
 } // end namespace tblgen
 } // end namespace mlir
 
index 28800d3..e7856e6 100644 (file)
@@ -77,12 +77,19 @@ public:
   // Returns true if this DAG leaf is specifying a constant attribute.
   bool isConstantAttr() const;
 
+  // Returns true if this DAG leaf is specifying an enum attribute case.
+  bool isEnumAttrCase() const;
+
   // Returns this DAG leaf as a constraint. Asserts if fails.
   Constraint getAsConstraint() const;
 
   // Returns this DAG leaf as an constant attribute. Asserts if fails.
   ConstantAttr getAsConstantAttr() const;
 
+  // Returns this DAG leaf as an enum attribute case.
+  // Precondition: isEnumAttrCase()
+  EnumAttrCase getAsEnumAttrCase() const;
+
   // Returns the matching condition template inside this DAG leaf. Assumes the
   // leaf is an operand/attribute matcher and asserts otherwise.
   std::string getConditionTemplate() const;
@@ -92,6 +99,10 @@ public:
   std::string getTransformationTemplate() const;
 
 private:
+  // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
+  // also a subclass of the given `superclass`.
+  bool isSubClassOf(StringRef superclass) const;
+
   const llvm::Init *def;
 };
 
index 3e94528..26c2204 100644 (file)
@@ -130,3 +130,38 @@ tblgen::Attribute tblgen::ConstantAttr::getAttribute() const {
 StringRef tblgen::ConstantAttr::getConstantValue() const {
   return def->getValueAsString("value");
 }
+
+tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
+    : Attribute(init) {
+  assert(def->isSubClassOf("EnumAttrCase") &&
+         "must be subclass of TableGen 'EnumAttrCase' class");
+}
+
+StringRef tblgen::EnumAttrCase::getSymbol() const {
+  return def->getValueAsString("symbol");
+}
+
+tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
+  assert(def->isSubClassOf("EnumAttr") &&
+         "must be subclass of TableGen 'EnumAttr' class");
+}
+
+tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
+    : EnumAttr(init->getDef()) {}
+
+StringRef tblgen::EnumAttr::getEnumClassName() const {
+  return def->getValueAsString("className");
+}
+
+std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
+  const auto *inits = def->getValueAsListInit("enumerants");
+
+  std::vector<tblgen::EnumAttrCase> cases;
+  cases.reserve(inits->size());
+
+  for (const llvm::Init *init : *inits) {
+    cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init)));
+  }
+
+  return cases;
+}
index a6c4095..de620e1 100644 (file)
@@ -30,33 +30,29 @@ using namespace mlir;
 using mlir::tblgen::Operator;
 
 bool tblgen::DagLeaf::isUnspecified() const {
-  return !def || isa<llvm::UnsetInit>(def);
+  return dyn_cast_or_null<llvm::UnsetInit>(def);
 }
 
 bool tblgen::DagLeaf::isOperandMatcher() const {
-  if (!def || !isa<llvm::DefInit>(def))
-    return false;
   // Operand matchers specify a type constraint.
-  return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("TypeConstraint");
+  return isSubClassOf("TypeConstraint");
 }
 
 bool tblgen::DagLeaf::isAttrMatcher() const {
-  if (!def || !isa<llvm::DefInit>(def))
-    return false;
   // Attribute matchers specify an attribute constraint.
-  return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
+  return isSubClassOf("AttrConstraint");
 }
 
 bool tblgen::DagLeaf::isAttrTransformer() const {
-  if (!def || !isa<llvm::DefInit>(def))
-    return false;
-  return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("tAttr");
+  return isSubClassOf("tAttr");
 }
 
 bool tblgen::DagLeaf::isConstantAttr() const {
-  if (!def || !isa<llvm::DefInit>(def))
-    return false;
-  return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
+  return isSubClassOf("ConstantAttr");
+}
+
+bool tblgen::DagLeaf::isEnumAttrCase() const {
+  return isSubClassOf("EnumAttrCase");
 }
 
 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
@@ -70,6 +66,11 @@ tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const {
   return ConstantAttr(cast<llvm::DefInit>(def));
 }
 
+tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
+  assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
+  return EnumAttrCase(cast<llvm::DefInit>(def));
+}
+
 std::string tblgen::DagLeaf::getConditionTemplate() const {
   return getAsConstraint().getConditionTemplate();
 }
@@ -82,6 +83,12 @@ std::string tblgen::DagLeaf::getTransformationTemplate() const {
       .str();
 }
 
+bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
+  if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
+    return defInit->getDef()->isSubClassOf(superclass);
+  return false;
+}
+
 bool tblgen::DagNode::isAttrTransformer() const {
   auto op = node->getOperator();
   if (!op || !isa<llvm::DefInit>(op))
diff --git a/mlir/test/mlir-tblgen/attr-enum.td b/mlir/test/mlir-tblgen/attr-enum.td
new file mode 100644 (file)
index 0000000..52b88cf
--- /dev/null
@@ -0,0 +1,38 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --check-prefix=PAT
+
+include "mlir/IR/OpBase.td"
+
+def NS_SomeEnum_A : EnumAttrCase<"A">;
+def NS_SomeEnum_B : EnumAttrCase<"B">;
+def NS_SomeEnum_C : EnumAttrCase<"C">;
+
+def NS_SomeEnum : EnumAttr<
+  "SomeEnum", "some enum",
+  [NS_SomeEnum_A, NS_SomeEnum_B, NS_SomeEnum_C]>;
+
+def NS_OpA : Op<"op_a_with_enum_attr", []> {
+  let arguments = (ins NS_SomeEnum:$attr);
+}
+
+// DEF-LABEL: StringRef OpA::attr()
+// DEF-NEXT:    auto attr = this->getAttr("attr").dyn_cast_or_null<StringAttr>();
+// DEF-NEXT:    return attr.getValue();
+
+// DEF-LABEL: OpA::verify()
+// DEF:       if (!(((this->getAttr("attr").cast<StringAttr>().getValue() == "A")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "B")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "C"))))
+// DEF-SAME:    return emitOpError("attribute 'attr' failed to satisfy some enum attribute constraints");
+
+def NS_OpB : Op<"op_b_with_enum_attr", []> {
+  let arguments = (ins NS_SomeEnum:$attr);
+}
+
+def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>;
+
+// PAT-LABEL: struct GeneratedConvert0
+// PAT:       PatternMatchResult match
+// PAT:         if (!((op0->getAttrOfType<StringAttr>("attr").cast<StringAttr>().getValue() == "A"))) return matchFailure();
+// PAT:       void rewrite
+// PAT:         auto vOpB0 = rewriter.create<NS::OpB>(loc,
+// PAT-NEXT:        rewriter.getStringAttr("B")
+// PAT-NEXT:    );
index 025ec91..d9d5f74 100644 (file)
@@ -98,8 +98,9 @@ private:
   // result value name.
   std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
 
-  // Returns the string value of constant attribute as an argument.
-  std::string handleConstantAttr(ConstantAttr constAttr);
+  // Returns the C++ expression to construct a constant attribute of the given
+  // `value` for the given attribute kind `attr`.
+  std::string handleConstantAttr(Attribute attr, StringRef value);
 
   // Returns the C++ expression to build an argument from the given DAG `leaf`.
   // `patArgName` is used to bound the argument to the source pattern.
@@ -128,16 +129,15 @@ PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
     : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
       os(os) {}
 
-std::string PatternEmitter::handleConstantAttr(ConstantAttr constAttr) {
-  auto attr = constAttr.getAttribute();
-
+std::string PatternEmitter::handleConstantAttr(Attribute attr,
+                                               StringRef value) {
   if (!attr.isConstBuildable())
     PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
                              " does not have the 'constBuilderCall' field");
 
   // TODO(jpienaar): Verify the constants here
   return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
-                 constAttr.getConstantValue());
+                 value);
 }
 
 static Twine resultName(const StringRef &name) { return Twine("res_") + name; }
@@ -448,7 +448,13 @@ void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
                                              llvm::StringRef argName) {
   if (leaf.isConstantAttr()) {
-    return handleConstantAttr(leaf.getAsConstantAttr());
+    auto constAttr = leaf.getAsConstantAttr();
+    return handleConstantAttr(constAttr.getAttribute(),
+                              constAttr.getConstantValue());
+  }
+  if (leaf.isEnumAttrCase()) {
+    auto enumCase = leaf.getAsEnumAttrCase();
+    return handleConstantAttr(enumCase, enumCase.getSymbol());
   }
   pattern.ensureArgBoundInSourcePattern(argName);
   std::string result = boundArgNameInRewrite(argName).str();
@@ -587,7 +593,7 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
       auto leaf = tree.getArgAsLeaf(i);
       // The argument in the result DAG pattern.
       auto patArgName = tree.getArgName(i);
-      if (leaf.isConstantAttr()) {
+      if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
         // TODO(jpienaar): Refactor out into map to avoid recomputing these.
         auto argument = resultOp.getArg(i);
         if (!argument.is<NamedAttribute *>())