// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. This build() method uses first operand's type
- // as all result's types.
- void genUseOperandAsResultTypeBuilder();
+ // as all results' types.
+ void genUseOperandAsResultTypeSeparateParamBuilder();
+
+ // Generates the build() method that takes all operands/attributes
+ // collectively as one parameter. This build() method uses first operand's
+ // type as all results' types.
+ void genUseOperandAsResultTypeCollectiveParamBuilder();
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. This build() method uses first attribute's type
m.body() << formatv(" {0}.addTypes(resultTypes);\n", builderOpState);
}
-void OpEmitter::genUseOperandAsResultTypeBuilder() {
+void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
+ // If this op has a variadic result, we cannot generate this builder because
+ // we don't know how many results to create.
+ if (op.getNumVariadicResults() != 0)
+ return;
+
+ int numResults = op.getNumResults();
+
+ // Signature
+ std::string params =
+ std::string("Builder *, OperationState &") + builderOpState +
+ ", ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes";
+ auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
+ auto &body = m.body();
+
+ // Result types
+ SmallVector<std::string, 2> resultTypes(numResults, "operands[0]->getType()");
+ body << " " << builderOpState << ".addTypes({"
+ << llvm::join(resultTypes, ", ") << "});\n\n";
+
+ // Operands
+ body << " " << builderOpState << ".addOperands(operands);\n\n";
+
+ // Attributes
+ body << " " << builderOpState << ".addAttributes(attributes);\n";
+
+ // Create the correct number of regions
+ if (int numRegions = op.getNumRegions()) {
+ for (int i = 0; i < numRegions; ++i)
+ m.body() << " (void)" << builderOpState << ".addRegion();\n";
+ }
+}
+
+void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
std::string paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::None);
}
void OpEmitter::genUseAttrAsResultTypeBuilder() {
- std::string paramList;
- llvm::SmallVector<std::string, 4> resultNames;
- buildParamList(paramList, resultNames, TypeParamKind::None);
-
- auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
- genCodeForAddingArgAndRegionForBuilder(m.body());
-
- auto numResults = op.getNumResults();
- if (numResults == 0)
- return;
+ std::string params =
+ std::string("Builder *, OperationState &") + builderOpState +
+ ", ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes";
+ auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
+ auto &body = m.body();
// Push all result types to the operation state
std::string resultType;
const auto &namedAttr = op.getAttribute(0);
+
+ body << " for (auto attr : attributes) {\n";
+ body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n";
if (namedAttr.attr.isTypeAttr()) {
- resultType = formatv("{0}.getValue()", namedAttr.name);
+ resultType = "attr.second.cast<TypeAttr>().getValue()";
} else {
- resultType = formatv("{0}.getType()", namedAttr.name);
+ resultType = "attr.second.getType()";
}
- m.body() << " " << builderOpState << ".addTypes({" << resultType;
- for (int i = 1; i != numResults; ++i)
- m.body() << ", " << resultType;
- m.body() << "});\n\n";
+ SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
+ body << " " << builderOpState << ".addTypes({"
+ << llvm::join(resultTypes, ", ") << "});\n";
+ body << " }\n";
+
+ // Operands
+ body << " " << builderOpState << ".addOperands(operands);\n\n";
+ // Attributes
+ body << " " << builderOpState << ".addAttributes(attributes);\n";
}
void OpEmitter::genBuilder() {
// use the first operand or attribute's type as all result types
// to facilitate different call patterns.
if (op.getNumVariadicResults() == 0) {
- if (op.hasTrait("OpTrait::SameOperandsAndResultType"))
- genUseOperandAsResultTypeBuilder();
+ if (op.hasTrait("OpTrait::SameOperandsAndResultType")) {
+ genUseOperandAsResultTypeSeparateParamBuilder();
+ genUseOperandAsResultTypeCollectiveParamBuilder();
+ }
if (op.hasTrait("OpTrait::FirstAttrDerivedResultType"))
genUseAttrAsResultTypeBuilder();
}
body << " " << builderOpState << ".addOperands(operands);\n\n";
// Attributes
- body << " for (const auto& pair : attributes)\n"
- << " " << builderOpState
- << ".addAttribute(pair.first, pair.second);\n";
+ body << " " << builderOpState << ".addAttributes(attributes);\n";
// Create the correct number of regions
if (int numRegions = op.getNumRegions()) {
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
-using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
+using llvm::formatv;
+using llvm::Record;
+using llvm::RecordKeeper;
+
#define DEBUG_TYPE "mlir-tblgen-rewritergen"
namespace llvm {
// result value name.
std::string handleOpCreation(DagNode tree, int resultIndex, int depth);
+ using ChildNodeIndexNameMap = DenseMap<unsigned, std::string>;
+
+ // Emits a local variable for each value and attribute to be used for creating
+ // an op.
+ void createSeparateLocalVarsForOpArgs(DagNode node,
+ ChildNodeIndexNameMap &childNodeNames);
+
+ // Emits the concrete arguments used to call a op's builder.
+ void supplyValuesForOpArgs(DagNode node,
+ const ChildNodeIndexNameMap &childNodeNames);
+
+ // Emits the local variables for holding all values as a whole and all named
+ // attributes as a whole to be used for creating an op.
+ void createAggregateLocalVarsForOpArgs(
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames);
+
// Returns the C++ expression to construct a constant attribute of the given
// `value` for the given attribute kind `attr`.
std::string handleConstantAttr(Attribute attr, StringRef value);
PrintFatalError(loc, error);
}
- os.indent(4) << "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n";
os.indent(4) << "auto loc = rewriter.getFusedLoc({";
for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
os.indent(4) << "rewriter.eraseOp(op0);\n";
} else {
// Process replacement result patterns.
- os.indent(4) << "SmallVector<Value *, 4> tblgen_values;";
+ os.indent(4) << "SmallVector<Value *, 4> tblgen_repl_values;\n";
for (int i = replStartIndex; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os.indent(4) << "\n";
// Resolve each symbol for all range use so that we can loop over them.
os << symbolInfoMap.getAllRangeUse(
- val, " for (auto *v : {0}) {{ tblgen_values.push_back(v); }",
+ val, " for (auto *v : {0}) {{ tblgen_repl_values.push_back(v); }",
"\n");
}
os.indent(4) << "\n";
- os.indent(4) << "rewriter.replaceOp(op0, tblgen_values);\n";
+ os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
}
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
// A map to collect all nested DAG child nodes' names, with operand index as
// the key. This includes both bound and unbound child nodes.
- llvm::DenseMap<unsigned, std::string> childNodeNames;
+ ChildNodeIndexNameMap childNodeNames;
// First go through all the child nodes who are nested DAG constructs to
// create ops for them and remember the symbol names for them, so that we can
valuePackName);
os.indent(4) << "{\n";
+ // Right now ODS don't have general type inference support. Except a few
+ // special cases listed below, DRR needs to supply types for all results
+ // when building an op.
+ bool isSameOperandsAndResultType =
+ resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
+ bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
+
+ if (isSameOperandsAndResultType || useFirstAttr) {
+ // We know how to deduce the result type for ops with these traits and we've
+ // generated builders taking aggregrate parameters. Use those builders to
+ // create the ops.
+
+ // First prepare local variables for op arguments used in builder call.
+ createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+ // Then create the op.
+ os.indent(6) << formatv(
+ "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n",
+ valuePackName, resultOp.getQualCppClassName());
+ os.indent(4) << "}\n";
+ return resultValue;
+ }
+
+ bool isBroadcastable =
+ resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
+ bool usePartialResults = valuePackName != resultValue;
+
+ if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
+ // For these cases (broadcastable ops, op results used both as auxiliary
+ // values and replacement values, ops in nested patterns, auxiliary ops), we
+ // still need to supply the result types when building the op. But because
+ // we don't generate a builder automatically with ODS for them, it's the
+ // developer's responsiblity to make sure such a builder (with result type
+ // deduction ability) exists. We go through the separate-parameter builder
+ // here given that it's easier for developers to write compared to
+ // aggregate-parameter builders.
+ createSeparateLocalVarsForOpArgs(tree, childNodeNames);
+ os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
+ resultOp.getQualCppClassName());
+ supplyValuesForOpArgs(tree, childNodeNames);
+ os << "\n );\n";
+ os.indent(4) << "}\n";
+ return resultValue;
+ }
+
+ // If depth == 0 and resultIndex >= 0, it means we are replacing the values
+ // generated from the source pattern root op. Then we can use the source
+ // pattern's value types to determine the value type of the generated op
+ // here.
+
+ // First prepare local variables for op arguments used in builder call.
+ createAggregateLocalVarsForOpArgs(tree, childNodeNames);
+
+ // Then prepare the result types. We need to specify the types for all
+ // results.
+ os.indent(6) << formatv(
+ "SmallVector<Type, 4> tblgen_types; (void)tblgen_types;\n");
+ int numResults = resultOp.getNumResults();
+ if (numResults != 0) {
+ for (int i = 0; i < numResults; ++i)
+ os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) {{"
+ "tblgen_types.push_back(v->getType()); }\n",
+ resultIndex + i);
+ }
+ os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, "
+ "tblgen_values, tblgen_attrs);\n",
+ valuePackName, resultOp.getQualCppClassName());
+ os.indent(4) << "}\n";
+ return resultValue;
+}
+
+void PatternEmitter::createSeparateLocalVarsForOpArgs(
+ DagNode node, ChildNodeIndexNameMap &childNodeNames) {
+ Operator &resultOp = node.getDialectOp(opMap);
+
// Now prepare operands used for building this op:
// * If the operand is non-variadic, we create a `Value*` local variable.
// * If the operand is variadic, we create a `SmallVector<Value*>` local
// We do not need special handling for attributes.
continue;
}
+
std::string varName;
if (operand->isVariadic()) {
varName = formatv("tblgen_values_{0}", valueIndex++);
os.indent(6) << formatv("SmallVector<Value *, 4> {0};\n", varName);
std::string range;
- if (tree.isNestedDagArg(argIndex)) {
+ if (node.isNestedDagArg(argIndex)) {
range = childNodeNames[argIndex];
} else {
- range = tree.getArgName(argIndex);
+ range = node.getArgName(argIndex);
}
// Resolve the symbol for all range use so that we have a uniform way of
// capturing the values.
} else {
varName = formatv("tblgen_value_{0}", valueIndex++);
os.indent(6) << formatv("Value *{0} = ", varName);
- if (tree.isNestedDagArg(argIndex)) {
+ if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
} else {
- DagLeaf leaf = tree.getArgAsLeaf(argIndex);
+ DagLeaf leaf = node.getArgAsLeaf(argIndex);
auto symbol =
- symbolInfoMap.getValueAndRangeUse(tree.getArgName(argIndex));
+ symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
if (leaf.isNativeCodeCall()) {
os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
} else {
// Update to use the newly created local variable for building the op later.
childNodeNames[argIndex] = varName;
}
+}
- // Then we create the builder call.
-
- // Right now we don't have general type inference in MLIR. Except a few
- // special cases listed below, we need to supply types for all results
- // when building an op.
- bool isSameOperandsAndResultType =
- resultOp.hasTrait("OpTrait::SameOperandsAndResultType");
- bool isBroadcastable =
- resultOp.hasTrait("OpTrait::BroadcastableTwoOperandsOneResult");
- bool useFirstAttr = resultOp.hasTrait("OpTrait::FirstAttrDerivedResultType");
- bool usePartialResults = valuePackName != resultValue;
-
- if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
- usePartialResults || depth > 0 || resultIndex < 0) {
- os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
- resultOp.getQualCppClassName());
- } else {
- // If depth == 0 and resultIndex >= 0, it means we are replacing the values
- // generated from the source pattern root op. Then we can use the source
- // pattern's value types to determine the value type of the generated op
- // here.
-
- // We need to specify the types for all results.
- int numResults = resultOp.getNumResults();
- if (numResults != 0) {
- os.indent(6) << "tblgen_types.clear();\n";
- for (int i = 0; i < numResults; ++i) {
- os.indent(6) << formatv("for (auto *v : castedOp0.getODSResults({0})) "
- "tblgen_types.push_back(v->getType());\n",
- resultIndex + i);
- }
- }
-
- os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName,
- resultOp.getQualCppClassName());
- if (numResults != 0)
- os.indent(6) << ", tblgen_types";
- }
-
- // Add arguments for the builder call.
- for (int argIndex = 0; argIndex != numOpArgs; ++argIndex) {
+void PatternEmitter::supplyValuesForOpArgs(
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+ Operator &resultOp = node.getDialectOp(opMap);
+ for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
+ argIndex != numOpArgs; ++argIndex) {
// Start each argment on its own line.
(os << ",\n").indent(8);
Argument opArg = resultOp.getArg(argIndex);
// Handle the case of operand first.
if (auto *operand = opArg.dyn_cast<NamedTypeConstraint *>()) {
- if (!operand->name.empty()) {
+ if (!operand->name.empty())
os << "/*" << operand->name << "=*/";
- }
- os << childNodeNames[argIndex];
- // TODO(jpienaar): verify types
+ os << childNodeNames.lookup(argIndex);
continue;
}
// The argument in the op definition.
auto opArgName = resultOp.getArgName(argIndex);
- if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
+ if (auto subTree = node.getArgAsNestedDag(argIndex)) {
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName,
handleReplaceWithNativeCodeCall(subTree));
} else {
- auto leaf = tree.getArgAsLeaf(argIndex);
+ auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
- auto patArgName = tree.getArgName(argIndex);
+ auto patArgName = node.getArgName(argIndex);
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
if (!opArg.is<NamedAttribute *>())
os << handleOpArgument(leaf, patArgName);
}
}
- os << "\n );\n";
- os.indent(4) << "}\n";
+}
- return resultValue;
+void PatternEmitter::createAggregateLocalVarsForOpArgs(
+ DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
+ Operator &resultOp = node.getDialectOp(opMap);
+
+ os.indent(6) << formatv(
+ "SmallVector<Value *, 4> tblgen_values; (void)tblgen_values;\n");
+ os.indent(6) << formatv(
+ "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n");
+
+ for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
+ if (const auto *attr =
+ resultOp.getArg(argIndex).dyn_cast<NamedAttribute *>()) {
+ const char *addAttrCmd = "if ({1}) {{"
+ " tblgen_attrs.emplace_back(rewriter."
+ "getIdentifier(\"{0}\"), {1}); }\n";
+ // The argument in the op definition.
+ auto opArgName = resultOp.getArgName(argIndex);
+ if (auto subTree = node.getArgAsNestedDag(argIndex)) {
+ if (!subTree.isNativeCodeCall())
+ PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
+ "for creating attribute");
+ os.indent(6) << formatv(addAttrCmd, opArgName,
+ handleReplaceWithNativeCodeCall(subTree));
+ } else {
+ auto leaf = node.getArgAsLeaf(argIndex);
+ // The argument in the result DAG pattern.
+ auto patArgName = node.getArgName(argIndex);
+ os.indent(6) << formatv(addAttrCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
+ }
+ continue;
+ }
+
+ const auto *operand =
+ resultOp.getArg(argIndex).get<NamedTypeConstraint *>();
+ std::string varName;
+ if (operand->isVariadic()) {
+ std::string range;
+ if (node.isNestedDagArg(argIndex)) {
+ range = childNodeNames.lookup(argIndex);
+ } else {
+ range = node.getArgName(argIndex);
+ }
+ // Resolve the symbol for all range use so that we have a uniform way of
+ // capturing the values.
+ range = symbolInfoMap.getValueAndRangeUse(range);
+ os.indent(6) << formatv(
+ "for (auto *v : {0}) tblgen_values.push_back(v);\n", range);
+ } else {
+ os.indent(6) << formatv("tblgen_values.push_back(", varName);
+ if (node.isNestedDagArg(argIndex)) {
+ os << symbolInfoMap.getValueAndRangeUse(
+ childNodeNames.lookup(argIndex));
+ } else {
+ DagLeaf leaf = node.getArgAsLeaf(argIndex);
+ auto symbol =
+ symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex));
+ if (leaf.isNativeCodeCall()) {
+ os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
+ } else {
+ os << symbol;
+ }
+ }
+ os << ");\n";
+ }
+ }
}
static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) {