From e0774c008fdcee1d4007ede4fde4cf7ad83cfda8 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 1 Feb 2019 15:40:22 -0800 Subject: [PATCH] [TableGen] Use tblgen::DagLeaf to model DAG arguments This CL added a tblgen::DagLeaf wrapper class with several helper methods for handling DAG arguments. It helps to refactor the rewriter generation logic to be more higher level. This CL also added a tblgen::ConstantAttr wrapper class for constant attributes. PiperOrigin-RevId: 232050683 --- mlir/include/mlir/TableGen/Attribute.h | 18 ++ mlir/include/mlir/TableGen/Pattern.h | 82 ++++++-- mlir/include/mlir/TableGen/Type.h | 1 + mlir/lib/TableGen/Attribute.cpp | 14 ++ mlir/lib/TableGen/Pattern.cpp | 72 ++++++- mlir/test/mlir-tblgen/one-op-one-result.td | 17 +- mlir/tools/mlir-tblgen/RewriterGen.cpp | 232 +++++++++++---------- 7 files changed, 292 insertions(+), 144 deletions(-) diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index b126617f289f..e601fdf22ead 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -111,6 +111,24 @@ public: StringRef getDerivedCodeBody() const; }; +// Wrapper class providing helper methods for accessing MLIR constant attribute +// defined in TableGen. This class should closely reflect what is defined as +// class `ConstantAttr` in TableGen. +class ConstantAttr { +public: + explicit ConstantAttr(const llvm::DefInit *init); + + // Returns the attribute kind. + Attribute getAttribute() const; + + // Returns the constant value. + StringRef getConstantValue() const; + +private: + // The TableGen definition of this constant attribute. + const llvm::Record *def; +}; + } // end namespace tblgen } // end namespace mlir diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 6544316313d7..80a38329a33b 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -23,6 +23,7 @@ #ifndef MLIR_TABLEGEN_PATTERN_H_ #define MLIR_TABLEGEN_PATTERN_H_ +#include "mlir/Support/LLVM.h" #include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/DenseMap.h" @@ -30,9 +31,9 @@ #include "llvm/TableGen/Error.h" namespace llvm { -class Record; -class Init; class DagInit; +class Init; +class Record; class StringRef; } // end namespace llvm @@ -42,19 +43,61 @@ namespace tblgen { // Mapping from TableGen Record to Operator wrapper object using RecordOperatorMap = llvm::DenseMap; -// Wrapper around DAG argument. -struct DagArg { - DagArg(Argument arg, llvm::Init *constraint) - : arg(arg), constraint(constraint) {} +class Pattern; - // Returns true if this DAG argument concerns an operation attribute. - bool isAttr() const; +// Wrapper class providing helper methods for accessing TableGen DAG leaves +// used inside Patterns. This class is lightweight and designed to be used like +// values. +// +// A TableGen DAG construct is of the syntax +// `(operator, arg0, arg1, ...)`. +// +// This class provides getters to retrieve `arg*` as tblgen:: wrapper objects +// for handy helper methods. It only works on `arg*`s that are not nested DAG +// constructs. +class DagLeaf { +public: + explicit DagLeaf(const llvm::Init *def) : def(def) {} - Argument arg; - llvm::Init *constraint; -}; + // Returns true if this DAG leaf is not specified in the pattern. That is, it + // places no further constraints/transforms and just carries over the original + // value. + bool isUnspecified() const; -class Pattern; + // Returns true if this DAG leaf is matching an operand. That is, it specifies + // a type constraint. + bool isOperandMatcher() const; + + // Returns true if this DAG leaf is matching an attribute. That is, it + // specifies an attribute constraint. + bool isAttrMatcher() const; + + // Returns true if this DAG leaf is transforming an attribute. + bool isAttrTransformer() const; + + // Returns true if this DAG leaf is specifying a constant attribute. + bool isConstantAttr() const; + + // Returns this DAG leaf as a type constraint. Asserts if fails. + TypeConstraint getAsTypeConstraint() const; + + // Returns this DAG leaf as an attribute constraint. Asserts if fails. + AttrConstraint getAsAttrConstraint() const; + + // Returns this DAG leaf as an constant attribute. Asserts if fails. + ConstantAttr getAsConstantAttr() const; + + // Returns the matching condition template inside this DAG leaf. Assumes the + // leaf is an operand/attribute matcher and asserts otherwise. + std::string getConditionTemplate() const; + + // Returns the transformation template inside this DAG leaf. Assumes the + // leaf is an attribute matcher and asserts otherwise. + std::string getTransformationTemplate() const; + +private: + const llvm::Init *def; +}; // Wrapper class providing helper methods for accessing TableGen DAG constructs // used inside Patterns. This class is lightweight and designed to be used like @@ -96,10 +139,9 @@ public: // Gets the `index`-th argument as a nested DAG construct if possible. Returns // null DagNode otherwise. DagNode getArgAsNestedDag(unsigned index) const; - // Gets the `index`-th argument as a TableGen DefInit* if possible. Returns - // nullptr otherwise. - // TODO: This method is exposing raw TableGen object and should be changed. - llvm::DefInit *getArgAsDefInit(unsigned index) const; + + // Gets the `index`-th argument as a DAG leaf. + DagLeaf getArgAsLeaf(unsigned index) const; // Returns the specified name of the `index`-th argument. llvm::StringRef getArgName(unsigned index) const; @@ -146,7 +188,7 @@ public: void ensureArgBoundInSourcePattern(llvm::StringRef name) const; // Returns a reference to all the bound arguments in the source pattern. - llvm::StringMap &getSourcePatternBoundArgs(); + llvm::StringMap &getSourcePatternBoundArgs(); // Returns the op that the root node of the source pattern matches. const Operator &getSourceRootOp(); @@ -159,8 +201,10 @@ private: // The TableGen definition of this pattern. const llvm::Record &def; - RecordOperatorMap *recordOpMap; // All operators - llvm::StringMap boundArguments; // All bound arguments + // All operators + RecordOperatorMap *recordOpMap; + // All bound arguments + llvm::StringMap boundArguments; }; } // end namespace tblgen diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h index fd91ff4dc2c8..247e0fc8e4b8 100644 --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -42,6 +42,7 @@ public: explicit TypeConstraint(const llvm::DefInit &init); bool operator==(const TypeConstraint &that) { return def == that.def; } + bool operator!=(const TypeConstraint &that) { return def != that.def; } // Returns the predicate that can be used to check if a type satisfies this // type constraint. diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 1e9c37cac9a4..2b8cda031ef6 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -133,3 +133,17 @@ StringRef tblgen::Attribute::getDerivedCodeBody() const { assert(isDerivedAttr() && "only derived attribute has 'body' field"); return def->getValueAsString("body"); } + +tblgen::ConstantAttr::ConstantAttr(const llvm::DefInit *init) + : def(init->getDef()) { + assert(def->isSubClassOf("ConstantAttr") && + "must be subclass of TableGen 'ConstantAttr' class"); +} + +tblgen::Attribute tblgen::ConstantAttr::getAttribute() const { + return Attribute(def->getValueAsDef("attr")); +} + +StringRef tblgen::ConstantAttr::getConstantValue() const { + return def->getValueAsString("value"); +} diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index f9bb4a4b08c6..5262141e753b 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -28,8 +28,66 @@ using namespace mlir; using mlir::tblgen::Operator; -bool tblgen::DagArg::isAttr() const { - return arg.is(); +bool tblgen::DagLeaf::isUnspecified() const { + return !def || isa(def); +} + +bool tblgen::DagLeaf::isOperandMatcher() const { + if (!def || !isa(def)) + return false; + // Operand matchers specify a type constraint. + return cast(def)->getDef()->isSubClassOf("TypeConstraint"); +} + +bool tblgen::DagLeaf::isAttrMatcher() const { + if (!def || !isa(def)) + return false; + // Attribute matchers specify a type constraint. + return cast(def)->getDef()->isSubClassOf("AttrConstraint"); +} + +bool tblgen::DagLeaf::isAttrTransformer() const { + if (!def || !isa(def)) + return false; + return cast(def)->getDef()->isSubClassOf("tAttr"); +} + +bool tblgen::DagLeaf::isConstantAttr() const { + if (!def || !isa(def)) + return false; + return cast(def)->getDef()->isSubClassOf("ConstantAttr"); +} + +tblgen::TypeConstraint tblgen::DagLeaf::getAsTypeConstraint() const { + assert(isOperandMatcher() && "the DAG leaf must be operand"); + return TypeConstraint(*cast(def)->getDef()); +} + +tblgen::AttrConstraint tblgen::DagLeaf::getAsAttrConstraint() const { + assert(isAttrMatcher() && "the DAG leaf must be attribute"); + return AttrConstraint(cast(def)->getDef()); +} + +tblgen::ConstantAttr tblgen::DagLeaf::getAsConstantAttr() const { + assert(isConstantAttr() && "the DAG leaf must be constant attribute"); + return ConstantAttr(cast(def)); +} + +std::string tblgen::DagLeaf::getConditionTemplate() const { + assert((isOperandMatcher() || isAttrMatcher()) && + "the DAG leaf must be operand/attribute matcher"); + if (isOperandMatcher()) { + return getAsTypeConstraint().getConditionTemplate(); + } + return getAsAttrConstraint().getConditionTemplate(); +} + +std::string tblgen::DagLeaf::getTransformationTemplate() const { + assert(isAttrTransformer() && "the DAG leaf must be attribute transformer"); + return cast(def) + ->getDef() + ->getValueAsString("attrTransform") + .str(); } Operator &tblgen::DagNode::getDialectOp(RecordOperatorMap *mapper) const { @@ -56,8 +114,9 @@ tblgen::DagNode tblgen::DagNode::getArgAsNestedDag(unsigned index) const { return DagNode(dyn_cast_or_null(node->getArg(index))); } -llvm::DefInit *tblgen::DagNode::getArgAsDefInit(unsigned index) const { - return dyn_cast(node->getArg(index)); +tblgen::DagLeaf tblgen::DagNode::getArgAsLeaf(unsigned index) const { + assert(!isNestedDagArg(index)); + return DagLeaf(node->getArg(index)); } StringRef tblgen::DagNode::getArgName(unsigned index) const { @@ -81,7 +140,7 @@ static void collectBoundArguments(const llvm::DagInit *tree, if (name.empty()) continue; - pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i), arg); + pattern->getSourcePatternBoundArgs().try_emplace(name, op.getArg(i)); } } @@ -131,7 +190,8 @@ void tblgen::Pattern::ensureArgBoundInSourcePattern( Twine("referencing unbound variable '") + name + "'"); } -llvm::StringMap &tblgen::Pattern::getSourcePatternBoundArgs() { +llvm::StringMap & +tblgen::Pattern::getSourcePatternBoundArgs() { return boundArguments; } diff --git a/mlir/test/mlir-tblgen/one-op-one-result.td b/mlir/test/mlir-tblgen/one-op-one-result.td index 45056b154edc..3bdd9aa1b96b 100644 --- a/mlir/test/mlir-tblgen/one-op-one-result.td +++ b/mlir/test/mlir-tblgen/one-op-one-result.td @@ -3,24 +3,21 @@ include "mlir/IR/op_base.td" // Create a Type and Attribute. -def YT : BuildableType<"buildYT">; -def Y_Attr : TypeBasedAttr; -def Y_Const_Attr { - Attr attr = Y_Attr; - string value = "attrValue"; -} +def T : BuildableType<"buildT">; +def T_Attr : TypeBasedAttr; +def T_Const_Attr : ConstantAttr; // Define ops to rewrite. -def T1: Type, "T1">; +def U: Type, "U">; def X_AddOp : Op<"x.add"> { - let arguments = (ins T1, T1); + let arguments = (ins U, U); } def Y_AddOp : Op<"y.add"> { - let arguments = (ins T1, T1, Y_Attr:$attrName); + let arguments = (ins U, U, T_Attr:$attrName); } // Define rewrite pattern. -def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, T1:$rhs, Y_Const_Attr:$x)>; +def : Pat<(X_AddOp $lhs, $rhs), (Y_AddOp $lhs, U:$rhs, T_Const_Attr:$x)>; // CHECK: struct GeneratedConvert0 : public RewritePattern // CHECK: RewritePattern("x.add", 1, context) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 6b62da7eb04b..7ca663071d8a 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -39,31 +39,11 @@ using namespace llvm; using namespace mlir; -using mlir::tblgen::Argument; -using mlir::tblgen::Attribute; using mlir::tblgen::DagNode; using mlir::tblgen::NamedAttribute; using mlir::tblgen::Operand; using mlir::tblgen::Operator; -using mlir::tblgen::Pattern; using mlir::tblgen::RecordOperatorMap; -using mlir::tblgen::Type; - -namespace { - -// Wrapper around DAG argument. -struct DagArg { - DagArg(Argument arg, Init *constraintInit) - : arg(arg), constraintInit(constraintInit) {} - bool isAttr(); - - Argument arg; - Init *constraintInit; -}; - -} // end namespace - -bool DagArg::isAttr() { return arg.is(); } namespace { class PatternEmitter { @@ -93,12 +73,19 @@ private: void emitReplaceWithNativeBuilder(DagNode resultTree); // Emits the value of constant attribute to `os`. - void emitAttributeValue(Record *constAttr); + void emitConstantAttr(tblgen::ConstantAttr constAttr); // Emits C++ statements for matching the op constrained by the given DAG // `tree`. void emitOpMatch(DagNode tree, int depth); + // Emits C++ statements for matching the `index`-th argument of the given DAG + // `tree` as an operand. + void emitOperandMatch(DagNode tree, int index, int depth, int indent); + // Emits C++ statements for matching the `index`-th argument of the given DAG + // `tree` as an attribute. + void emitAttributeMatch(DagNode tree, int index, int depth, int indent); + private: // Pattern instantiation location followed by the location of multiclass // prototypes used. This is intended to be used as a whole to @@ -107,14 +94,13 @@ private: // Op's TableGen Record to wrapper object RecordOperatorMap *opMap; // Handy wrapper for pattern being emitted - Pattern pattern; + tblgen::Pattern pattern; raw_ostream &os; }; } // end namespace -void PatternEmitter::emitAttributeValue(Record *constAttr) { - Attribute attr(constAttr->getValueAsDef("attr")); - auto value = constAttr->getValue("value"); +void PatternEmitter::emitConstantAttr(tblgen::ConstantAttr constAttr) { + auto attr = constAttr.getAttribute(); if (!attr.isConstBuildable()) PrintFatalError(loc, "Attribute " + attr.getTableGenDefName() + @@ -122,7 +108,7 @@ void PatternEmitter::emitAttributeValue(Record *constAttr) { // TODO(jpienaar): Verify the constants here os << formatv(attr.getConstBuilderTemplate().str().c_str(), "rewriter", - value->getValue()->getAsUnquotedString()); + constAttr.getConstantValue()); } // Helper function to match patterns. @@ -137,13 +123,17 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { "if (!op{0}->isa<{1}>()) return matchFailure();\n", depth, op.getQualCppClassName()); } - if (tree.getNumArgs() != op.getNumArgs()) - PrintFatalError(loc, Twine("mismatch in number of arguments to op '") + - op.getOperationName() + - "' in pattern and op's definition"); + if (tree.getNumArgs() != op.getNumArgs()) { + PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " + "pattern vs. {2} in definition", + op.getOperationName(), tree.getNumArgs(), + op.getNumArgs())); + } + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); + // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { os.indent(indent) << "{\n"; os.indent(indent + 2) << formatv( @@ -154,50 +144,78 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { continue; } - // Verify arguments. - if (auto defInit = tree.getArgAsDefInit(i)) { - // Verify operands. - if (auto *operand = opArg.dyn_cast()) { - // Skip verification where not needed due to definition of op. - if (operand->type == Type(defInit)) - goto StateCapture; - - if (!defInit->getDef()->isSubClassOf("Type")) - PrintFatalError(loc, "type argument required for operand"); - - auto constraint = tblgen::TypeConstraint(*defInit); - os.indent(indent) - << "if (!(" - << formatv(constraint.getConditionTemplate().c_str(), - formatv("op{0}->getOperand({1})->getType()", depth, i)) - << ")) return matchFailure();\n"; - } - - // TODO(jpienaar): Verify attributes. - if (auto *namedAttr = opArg.dyn_cast()) { - auto constraint = tblgen::AttrConstraint(defInit); - std::string condition = formatv( - constraint.getConditionTemplate().c_str(), - formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, - namedAttr->attr.getStorageType(), namedAttr->getName())); - os.indent(indent) << "if (!(" << condition - << ")) return matchFailure();\n"; - } + // Next handle DAG leaf: operand or attribute + if (auto *operand = opArg.dyn_cast()) { + emitOperandMatch(tree, i, depth, indent); + } else if (auto *namedAttr = opArg.dyn_cast()) { + emitAttributeMatch(tree, i, depth, indent); + } else { + PrintFatalError(loc, "unhandled case when matching op"); } + } +} - StateCapture: - auto name = tree.getArgName(i); - if (name.empty()) - continue; - if (opArg.is()) - os.indent(indent) << "state->" << name << " = op" << depth - << "->getOperand(" << i << ");\n"; - if (auto namedAttr = opArg.dyn_cast()) { - os.indent(indent) << "state->" << name << " = op" << depth - << "->getAttrOfType<" - << namedAttr->attr.getStorageType() << ">(\"" - << namedAttr->getName() << "\");\n"; +void PatternEmitter::emitOperandMatch(DagNode tree, int index, int depth, + int indent) { + Operator &op = tree.getDialectOp(opMap); + auto *operand = op.getArg(index).get(); + auto matcher = tree.getArgAsLeaf(index); + + // If a constraint is specified, we need to generate C++ statements to + // check the constraint. + if (!matcher.isUnspecified()) { + if (!matcher.isOperandMatcher()) { + PrintFatalError( + loc, formatv("the {1}-th argument of op '{0}' should be an operand", + op.getOperationName(), index + 1)); } + + // Only need to verify if the matcher's type is different from the one + // of op definition. + if (static_cast(operand->type) != + matcher.getAsTypeConstraint()) { + os.indent(indent) << "if (!(" + << formatv(matcher.getConditionTemplate().c_str(), + formatv("op{0}->getOperand({1})->getType()", + depth, index)) + << ")) return matchFailure();\n"; + } + } + + // Capture the value + auto name = tree.getArgName(index); + if (!name.empty()) { + os.indent(indent) << "state->" << name << " = op" << depth + << "->getOperand(" << index << ");\n"; + } +} + +void PatternEmitter::emitAttributeMatch(DagNode tree, int index, int depth, + int indent) { + Operator &op = tree.getDialectOp(opMap); + auto *namedAttr = op.getArg(index).get(); + auto matcher = tree.getArgAsLeaf(index); + + if (!matcher.isUnspecified() && !matcher.isAttrMatcher()) { + PrintFatalError( + loc, formatv("the {1}-th argument of op '{0}' should be an attribute", + op.getOperationName(), index + 1)); + } + + // If a constraint is specified, we need to generate C++ statements to + // check the constraint. + std::string condition = + formatv(matcher.getConditionTemplate().c_str(), + formatv("op{0}->getAttrOfType<{1}>(\"{2}\")", depth, + namedAttr->attr.getStorageType(), namedAttr->getName())); + os.indent(indent) << "if (!(" << condition << ")) return matchFailure();\n"; + + // Capture the value + auto name = tree.getArgName(index); + if (!name.empty()) { + os.indent(indent) << "state->" << name << " = op" << depth + << "->getAttrOfType<" << namedAttr->attr.getStorageType() + << ">(\"" << namedAttr->getName() << "\");\n"; } } @@ -234,11 +252,12 @@ void PatternEmitter::emit(StringRef rewriteName) { // Emit matched state. os << " struct MatchedState : public PatternState {\n"; for (const auto &arg : pattern.getSourcePatternBoundArgs()) { - if (auto namedAttr = arg.second.arg.dyn_cast()) { - os.indent(4) << namedAttr->attr.getStorageType() << " " << arg.first() + auto fieldName = arg.first(); + if (auto namedAttr = arg.second.dyn_cast()) { + os.indent(4) << namedAttr->attr.getStorageType() << " " << fieldName << ";\n"; } else { - os.indent(4) << "Value* " << arg.first() << ";\n"; + os.indent(4) << "Value* " << fieldName << ";\n"; } } os << " };\n"; @@ -285,10 +304,10 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { rewriter.replaceOpWithNewOp<{0}>(op, op->getResult(0)->getType())", resultOp.getCppClassName()); if (numOpArgs != resultTree.getNumArgs()) { - PrintFatalError(loc, Twine("mismatch between arguments of resultant op (") + - Twine(numOpArgs) + - ") and arguments provided for rewrite (" + - Twine(resultTree.getNumArgs()) + Twine(')')); + PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + resultOp.getOperationName(), + resultTree.getNumArgs(), numOpArgs)); } // Create the builder call for the result. @@ -312,38 +331,33 @@ void PatternEmitter::emitReplaceOpWithNewOp(DagNode resultTree) { // Start each attribute on its own line. (os << ",\n").indent(6); + auto leaf = resultTree.getArgAsLeaf(i); // The argument in the result DAG pattern. - auto argName = resultTree.getArgName(i); - auto opName = resultOp.getArgName(i); - auto *defInit = resultTree.getArgAsDefInit(i); - auto *value = defInit ? defInit->getDef()->getValue("value") : nullptr; - if (!value) { - pattern.ensureArgBoundInSourcePattern(argName); - auto result = "s." + argName; - os << "/*" << opName << "=*/"; - if (defInit) { - auto transform = defInit->getDef(); - if (transform->isSubClassOf("tAttr")) { - // TODO(jpienaar): move to helper class. - os << formatv( - transform->getValueAsString("attrTransform").str().c_str(), - result); - continue; - } - } - os << result; - continue; + auto patArgName = resultTree.getArgName(i); + // The argument in the op definition. + auto opArgName = resultOp.getArgName(i); + + if (leaf.isUnspecified() || leaf.isOperandMatcher()) { + pattern.ensureArgBoundInSourcePattern(patArgName); + os << formatv("/*{0}=*/s.{1}", opArgName, patArgName); + } else if (leaf.isAttrTransformer()) { + pattern.ensureArgBoundInSourcePattern(patArgName); + std::string result = std::string("s.") + patArgName.str(); + result = formatv(leaf.getTransformationTemplate().c_str(), result); + os << formatv("/*{0}=*/{1}", opArgName, result); + } else if (leaf.isConstantAttr()) { + // TODO(jpienaar): Refactor out into map to avoid recomputing these. + auto argument = resultOp.getArg(i); + if (!argument.is()) + PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); + + if (!patArgName.empty()) + os << "/*" << patArgName << "=*/"; + emitConstantAttr(leaf.getAsConstantAttr()); + // TODO(jpienaar): verify types + } else { + PrintFatalError(loc, "unhandled case when rewriting op"); } - - // TODO(jpienaar): Refactor out into map to avoid recomputing these. - auto argument = resultOp.getArg(i); - if (!argument.is()) - PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); - - if (!argName.empty()) - os << "/*" << argName << "=*/"; - emitAttributeValue(defInit->getDef()); - // TODO(jpienaar): verify types } os << "\n );\n"; } @@ -367,7 +381,7 @@ void PatternEmitter::emitReplaceWithNativeBuilder(DagNode resultTree) { auto name = resultTree.getArgName(i); pattern.ensureArgBoundInSourcePattern(name); const auto &val = boundedValues.find(name); - if (val->second.isAttr() && !printingAttr) { + if (val->second.dyn_cast() && !printingAttr) { os << "}, {"; first = true; printingAttr = true; -- 2.34.1