def StrAttr : StringBasedAttr<CPred<"true">, "string">;
+// An enum attribute case.
+class EnumAttrCase<string sym> : StringBasedAttr<
+ CPred<"{0}.cast<StringAttr>().getValue() == \"" # sym # "\"">,
+ "case " # sym> {
+ // The C++ enumerant symbol
+ string symbol = sym;
+}
+
+// An enum attribute. Its value can only be one from the given list of `cases`.
+// Enum attributes are emulated via mlir::StringAttr, plus extra verification
+// on the string: only the symbols of the allowed cases are permitted as the
+// string value.
+class EnumAttr<string name, string description, list<EnumAttrCase> cases> :
+ StringBasedAttr<AnyOf<!foreach(case, cases, case.predicate)>, description> {
+ // The C++ enum class name
+ string className = name;
+ // List of all accepted cases
+ list<EnumAttrCase> enumerants = cases;
+}
+
class ElementsAttrBase<Pred condition, string description> :
Attr<condition, description> {
let storageType = [{ ElementsAttr }];
const llvm::Record *def;
};
+// Wrapper class providing helper methods for accessing enum attribute cases
+// defined in TableGen. This class should closely reflect what is defined as
+// class `EnumAttrCase` in TableGen.
+class EnumAttrCase : public Attribute {
+public:
+ explicit EnumAttrCase(const llvm::DefInit *init);
+
+ // Returns the symbol of this enum attribute case.
+ StringRef getSymbol() const;
+};
+
+// Wrapper class providing helper methods for accessing enum attributes defined
+// in TableGen. This class should closely reflect what is defined as class
+// `EnumAttr` in TableGen.
+class EnumAttr : public Attribute {
+public:
+ explicit EnumAttr(const llvm::Record *record);
+ explicit EnumAttr(const llvm::DefInit *init);
+
+ // Returns the enum class name.
+ StringRef getEnumClassName() const;
+
+ // Returns all allowed cases for this enum attribute.
+ std::vector<EnumAttrCase> getAllCases() const;
+};
+
} // end namespace tblgen
} // end namespace mlir
// Returns true if this DAG leaf is specifying a constant attribute.
bool isConstantAttr() const;
+ // Returns true if this DAG leaf is specifying an enum attribute case.
+ bool isEnumAttrCase() const;
+
// Returns this DAG leaf as a constraint. Asserts if fails.
Constraint getAsConstraint() const;
// Returns this DAG leaf as an constant attribute. Asserts if fails.
ConstantAttr getAsConstantAttr() const;
+ // Returns this DAG leaf as an enum attribute case.
+ // Precondition: isEnumAttrCase()
+ EnumAttrCase getAsEnumAttrCase() const;
+
// Returns the matching condition template inside this DAG leaf. Assumes the
// leaf is an operand/attribute matcher and asserts otherwise.
std::string getConditionTemplate() const;
std::string getTransformationTemplate() const;
private:
+ // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
+ // also a subclass of the given `superclass`.
+ bool isSubClassOf(StringRef superclass) const;
+
const llvm::Init *def;
};
StringRef tblgen::ConstantAttr::getConstantValue() const {
return def->getValueAsString("value");
}
+
+tblgen::EnumAttrCase::EnumAttrCase(const llvm::DefInit *init)
+ : Attribute(init) {
+ assert(def->isSubClassOf("EnumAttrCase") &&
+ "must be subclass of TableGen 'EnumAttrCase' class");
+}
+
+StringRef tblgen::EnumAttrCase::getSymbol() const {
+ return def->getValueAsString("symbol");
+}
+
+tblgen::EnumAttr::EnumAttr(const llvm::Record *record) : Attribute(record) {
+ assert(def->isSubClassOf("EnumAttr") &&
+ "must be subclass of TableGen 'EnumAttr' class");
+}
+
+tblgen::EnumAttr::EnumAttr(const llvm::DefInit *init)
+ : EnumAttr(init->getDef()) {}
+
+StringRef tblgen::EnumAttr::getEnumClassName() const {
+ return def->getValueAsString("className");
+}
+
+std::vector<tblgen::EnumAttrCase> tblgen::EnumAttr::getAllCases() const {
+ const auto *inits = def->getValueAsListInit("enumerants");
+
+ std::vector<tblgen::EnumAttrCase> cases;
+ cases.reserve(inits->size());
+
+ for (const llvm::Init *init : *inits) {
+ cases.push_back(tblgen::EnumAttrCase(cast<llvm::DefInit>(init)));
+ }
+
+ return cases;
+}
using mlir::tblgen::Operator;
bool tblgen::DagLeaf::isUnspecified() const {
- return !def || isa<llvm::UnsetInit>(def);
+ return dyn_cast_or_null<llvm::UnsetInit>(def);
}
bool tblgen::DagLeaf::isOperandMatcher() const {
- if (!def || !isa<llvm::DefInit>(def))
- return false;
// Operand matchers specify a type constraint.
- return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("TypeConstraint");
+ return isSubClassOf("TypeConstraint");
}
bool tblgen::DagLeaf::isAttrMatcher() const {
- if (!def || !isa<llvm::DefInit>(def))
- return false;
// Attribute matchers specify an attribute constraint.
- return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("AttrConstraint");
+ return isSubClassOf("AttrConstraint");
}
bool tblgen::DagLeaf::isAttrTransformer() const {
- if (!def || !isa<llvm::DefInit>(def))
- return false;
- return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("tAttr");
+ return isSubClassOf("tAttr");
}
bool tblgen::DagLeaf::isConstantAttr() const {
- if (!def || !isa<llvm::DefInit>(def))
- return false;
- return cast<llvm::DefInit>(def)->getDef()->isSubClassOf("ConstantAttr");
+ return isSubClassOf("ConstantAttr");
+}
+
+bool tblgen::DagLeaf::isEnumAttrCase() const {
+ return isSubClassOf("EnumAttrCase");
}
tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
return ConstantAttr(cast<llvm::DefInit>(def));
}
+tblgen::EnumAttrCase tblgen::DagLeaf::getAsEnumAttrCase() const {
+ assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case");
+ return EnumAttrCase(cast<llvm::DefInit>(def));
+}
+
std::string tblgen::DagLeaf::getConditionTemplate() const {
return getAsConstraint().getConditionTemplate();
}
.str();
}
+bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
+ if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
+ return defInit->getDef()->isSubClassOf(superclass);
+ return false;
+}
+
bool tblgen::DagNode::isAttrTransformer() const {
auto op = node->getOperator();
if (!op || !isa<llvm::DefInit>(op))
--- /dev/null
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s --check-prefix=PAT
+
+include "mlir/IR/OpBase.td"
+
+def NS_SomeEnum_A : EnumAttrCase<"A">;
+def NS_SomeEnum_B : EnumAttrCase<"B">;
+def NS_SomeEnum_C : EnumAttrCase<"C">;
+
+def NS_SomeEnum : EnumAttr<
+ "SomeEnum", "some enum",
+ [NS_SomeEnum_A, NS_SomeEnum_B, NS_SomeEnum_C]>;
+
+def NS_OpA : Op<"op_a_with_enum_attr", []> {
+ let arguments = (ins NS_SomeEnum:$attr);
+}
+
+// DEF-LABEL: StringRef OpA::attr()
+// DEF-NEXT: auto attr = this->getAttr("attr").dyn_cast_or_null<StringAttr>();
+// DEF-NEXT: return attr.getValue();
+
+// DEF-LABEL: OpA::verify()
+// DEF: if (!(((this->getAttr("attr").cast<StringAttr>().getValue() == "A")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "B")) || ((this->getAttr("attr").cast<StringAttr>().getValue() == "C"))))
+// DEF-SAME: return emitOpError("attribute 'attr' failed to satisfy some enum attribute constraints");
+
+def NS_OpB : Op<"op_b_with_enum_attr", []> {
+ let arguments = (ins NS_SomeEnum:$attr);
+}
+
+def : Pat<(NS_OpA NS_SomeEnum_A:$attr), (NS_OpB NS_SomeEnum_B)>;
+
+// PAT-LABEL: struct GeneratedConvert0
+// PAT: PatternMatchResult match
+// PAT: if (!((op0->getAttrOfType<StringAttr>("attr").cast<StringAttr>().getValue() == "A"))) return matchFailure();
+// PAT: void rewrite
+// PAT: auto vOpB0 = rewriter.create<NS::OpB>(loc,
+// PAT-NEXT: rewriter.getStringAttr("B")
+// PAT-NEXT: );
// result value name.
std::string emitOpCreate(DagNode tree, int resultIndex, int depth);
- // Returns the string value of constant attribute as an argument.
- std::string handleConstantAttr(ConstantAttr constAttr);
+ // Returns the C++ expression to construct a constant attribute of the given
+ // `value` for the given attribute kind `attr`.
+ std::string handleConstantAttr(Attribute attr, StringRef value);
// Returns the C++ expression to build an argument from the given DAG `leaf`.
// `patArgName` is used to bound the argument to the source pattern.
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), nextValueId(0),
os(os) {}
-std::string PatternEmitter::handleConstantAttr(ConstantAttr constAttr) {
- auto attr = constAttr.getAttribute();
-
+std::string PatternEmitter::handleConstantAttr(Attribute attr,
+ StringRef value) {
if (!attr.isConstBuildable())
PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() +
" does not have the 'constBuilderCall' field");
// TODO(jpienaar): Verify the constants here
return formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter",
- constAttr.getConstantValue());
+ value);
}
static Twine resultName(const StringRef &name) { return Twine("res_") + name; }
std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
llvm::StringRef argName) {
if (leaf.isConstantAttr()) {
- return handleConstantAttr(leaf.getAsConstantAttr());
+ auto constAttr = leaf.getAsConstantAttr();
+ return handleConstantAttr(constAttr.getAttribute(),
+ constAttr.getConstantValue());
+ }
+ if (leaf.isEnumAttrCase()) {
+ auto enumCase = leaf.getAsEnumAttrCase();
+ return handleConstantAttr(enumCase, enumCase.getSymbol());
}
pattern.ensureArgBoundInSourcePattern(argName);
std::string result = boundArgNameInRewrite(argName).str();
auto leaf = tree.getArgAsLeaf(i);
// The argument in the result DAG pattern.
auto patArgName = tree.getArgName(i);
- if (leaf.isConstantAttr()) {
+ if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
auto argument = resultOp.getArg(i);
if (!argument.is<NamedAttribute *>())