//
//===----------------------------------------------------------------------===//
-#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/GenInfo.h"
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an operand.
- void emitOperandMatch(DagNode tree, int argIndex, int depth);
+ void emitOperandMatch(DagNode tree, int argIndex, int depth, int indent);
// Emits C++ statements for matching the `argIndex`-th argument of the given
// DAG `tree` as an attribute.
- void emitAttributeMatch(DagNode tree, int argIndex, int depth);
+ void emitAttributeMatch(DagNode tree, int argIndex, int depth, int indent);
// Emits C++ for checking a match with a corresponding match failure
// diagnostic.
// The next unused ID for newly created values.
unsigned nextValueId;
- raw_indented_ostream os;
+ raw_ostream &os;
// Format contexts containing placeholder substitutions.
FmtContext fmtCtx;
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
- os << formatv("if (!castedOp{0})\n return failure();\n", depth);
+ os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n",
+ depth);
}
if (tree.getNumArgs() != op.getNumArgs()) {
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
- os << formatv("{0} = castedOp{1};\n", name, depth);
+ os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
PrintFatalError(loc, error);
}
}
- os << "{\n";
+ os.indent(indent) << "{\n";
- os.indent() << formatv(
+ os.indent(indent + 2) << formatv(
"auto *op{0} = "
"(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
depth + 1, depth, i);
emitOpMatch(argTree, depth + 1);
- os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
- os.unindent() << "}\n";
+ os.indent(indent + 2)
+ << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1);
+ os.indent(indent) << "}\n";
continue;
}
// Next handle DAG leaf: operand or attribute
if (opArg.is<NamedTypeConstraint *>()) {
- emitOperandMatch(tree, i, depth);
+ emitOperandMatch(tree, i, depth, indent);
} else if (opArg.is<NamedAttribute *>()) {
- emitAttributeMatch(tree, i, depth);
+ emitAttributeMatch(tree, i, depth, indent);
} else {
PrintFatalError(loc, "unhandled case when matching op");
}
<< '\n');
}
-void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth,
+ int indent) {
Operator &op = tree.getDialectOp(opMap);
auto *operand = op.getArg(argIndex).get<NamedTypeConstraint *>();
auto matcher = tree.getArgAsLeaf(argIndex);
op.arg_begin(), op.arg_begin() + argIndex,
[](const Argument &arg) { return arg.is<NamedAttribute *>(); });
- os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", name, depth,
- argIndex - numPrevAttrs);
+ os.indent(indent) << formatv("{0} = castedOp{1}.getODSOperands({2});\n",
+ name, depth, argIndex - numPrevAttrs);
}
}
-void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) {
+void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth,
+ int indent) {
Operator &op = tree.getDialectOp(opMap);
auto *namedAttr = op.getArg(argIndex).get<NamedAttribute *>();
const auto &attr = namedAttr->attr;
- os << "{\n";
- os.indent() << formatv(
- "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); "
+ os.indent(indent) << "{\n";
+ indent += 2;
+ os.indent(indent) << formatv(
+ "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\");"
"(void)tblgen_attr;\n",
depth, attr.getStorageType(), namedAttr->name);
// TODO: This should use getter method to avoid duplication.
if (attr.hasDefaultValue()) {
- os << "if (!tblgen_attr) tblgen_attr = "
- << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx,
- attr.getDefaultValue()))
- << ";\n";
+ os.indent(indent) << "if (!tblgen_attr) tblgen_attr = "
+ << std::string(tgfmt(attr.getConstBuilderTemplate(),
+ &fmtCtx, attr.getDefaultValue()))
+ << ";\n";
} else if (attr.isOptional()) {
// For a missing attribute that is optional according to definition, we
// should just capture a mlir::Attribute() to signal the missing state.
auto name = tree.getArgName(argIndex);
// `$_` is a special symbol to ignore op argument matching.
if (!name.empty() && name != "_") {
- os << formatv("{0} = tblgen_attr;\n", name);
+ os.indent(indent) << formatv("{0} = tblgen_attr;\n", name);
}
- os.unindent() << "}\n";
+ indent -= 2;
+ os.indent(indent) << "}\n";
}
void PatternEmitter::emitMatchCheck(
int depth, const FmtObjectBase &matchFmt,
const llvm::formatv_object_base &failureFmt) {
- os << "if (!(" << matchFmt.str() << "))";
- os.scope("{\n", "\n}\n").os
- << "return rewriter.notifyMatchFailure(op" << depth
- << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureFmt.str()
- << ";\n});";
+ // {0} The match depth (used to get the operation that failed to match).
+ // {1} The format for the match string.
+ // {2} The format for the failure string.
+ const char *matchStr = R"(
+ if (!({1})) {
+ return rewriter.notifyMatchFailure(op{0}, [&](::mlir::Diagnostic &diag) {
+ diag << {2};
+ });
+ })";
+ os << llvm::formatv(matchStr, depth, matchFmt.str(), failureFmt.str())
+ << "\n";
}
void PatternEmitter::emitMatchLogic(DagNode tree) {
// Emit RewritePattern for Pattern.
auto locs = pattern.getLocation();
- os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n",
+ os << formatv("/* Generated from:\n\t{0:$[ instantiating\n\t]}\n*/\n",
make_range(locs.rbegin(), locs.rend()));
os << formatv(R"(struct {0} : public ::mlir::RewritePattern {
{0}(::mlir::MLIRContext *context)
os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n";
// Emit matchAndRewrite() function.
- {
- auto classScope = os.scope();
- os.reindent(R"(
- ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0,
- ::mlir::PatternRewriter &rewriter) const override {)")
- << '\n';
- {
- auto functionScope = os.scope();
-
- // Register all symbols bound in the source pattern.
- pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
-
- LLVM_DEBUG(llvm::dbgs()
- << "start creating local variables for capturing matches\n");
- os << "// Variables for capturing values and attributes used while "
- "creating ops\n";
- // Create local variables for storing the arguments and results bound
- // to symbols.
- for (const auto &symbolInfoPair : symbolInfoMap) {
- StringRef symbol = symbolInfoPair.getKey();
- auto &info = symbolInfoPair.getValue();
- os << info.getVarDecl(symbol);
- }
- // TODO: capture ops with consistent numbering so that it can be
- // reused for fused loc.
- os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
- pattern.getSourcePattern().getNumOps());
- LLVM_DEBUG(llvm::dbgs()
- << "done creating local variables for capturing matches\n");
-
- os << "// Match\n";
- os << "tblgen_ops[0] = op0;\n";
- emitMatchLogic(sourceTree);
-
- os << "\n// Rewrite\n";
- emitRewriteLogic();
-
- os << "return success();\n";
- }
- os << "};\n";
+ os << R"(
+ ::mlir::LogicalResult
+ matchAndRewrite(::mlir::Operation *op0,
+ ::mlir::PatternRewriter &rewriter) const override {
+)";
+
+ // Register all symbols bound in the source pattern.
+ pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
+
+ LLVM_DEBUG(
+ llvm::dbgs() << "start creating local variables for capturing matches\n");
+ os.indent(4) << "// Variables for capturing values and attributes used for "
+ "creating ops\n";
+ // Create local variables for storing the arguments and results bound
+ // to symbols.
+ for (const auto &symbolInfoPair : symbolInfoMap) {
+ StringRef symbol = symbolInfoPair.getKey();
+ auto &info = symbolInfoPair.getValue();
+ os.indent(4) << info.getVarDecl(symbol);
}
- os << "};\n\n";
+ // TODO: capture ops with consistent numbering so that it can be
+ // reused for fused loc.
+ os.indent(4) << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n",
+ pattern.getSourcePattern().getNumOps());
+ LLVM_DEBUG(
+ llvm::dbgs() << "done creating local variables for capturing matches\n");
+
+ os.indent(4) << "// Match\n";
+ os.indent(4) << "tblgen_ops[0] = op0;\n";
+ emitMatchLogic(sourceTree);
+ os << "\n";
+
+ os.indent(4) << "// Rewrite\n";
+ emitRewriteLogic();
+
+ os.indent(4) << "return success();\n";
+ os << " };\n";
+ os << "};\n";
}
void PatternEmitter::emitRewriteLogic() {
PrintFatalError(loc, error);
}
- os << "auto odsLoc = rewriter.getFusedLoc({";
+ os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({";
for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()";
}
// we are handling auxiliary patterns so we want the side effect even if
// NativeCodeCall is not replacing matched root op's results.
if (resultTree.isNativeCodeCall())
- os << val << ";\n";
+ os.indent(4) << val << ";\n";
}
if (numExpectedResults == 0) {
assert(replStartIndex >= numResultPatterns &&
"invalid auxiliary vs. replacement pattern division!");
// No result to replace. Just erase the op.
- os << "rewriter.eraseOp(op0);\n";
+ os.indent(4) << "rewriter.eraseOp(op0);\n";
} else {
// Process replacement result patterns.
- os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
+ os.indent(4)
+ << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n";
for (int i = replStartIndex; i < numResultPatterns; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
- os << "\n";
+ os.indent(4) << "\n";
// Resolve each symbol for all range use so that we can loop over them.
// We need an explicit cast to `SmallVector` to capture the cases where
// `{0}` resolves to an `Operation::result_range` as well as cases that
// TODO: Revisit the need for materializing a vector.
os << symbolInfoMap.getAllRangeUse(
val,
- "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
- " tblgen_repl_values.push_back(v);\n}\n",
+ " for (auto v : ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{ "
+ "tblgen_repl_values.push_back(v); }",
"\n");
}
- os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
+ os.indent(4) << "\n";
+ os.indent(4) << "rewriter.replaceOp(op0, tblgen_repl_values);\n";
}
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
}
// Create the local variable for this op.
- os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(),
- valuePackName);
+ os.indent(4) << formatv("{0} {1};\n", resultOp.getQualCppClassName(),
+ valuePackName);
+ os.indent(4) << "{\n";
// Right now ODS don't have general type inference support. Except a few
// special cases listed below, DRR needs to supply types for all results
createAggregateLocalVarsForOpArgs(tree, childNodeNames);
// Then create the op.
- os.scope("", "\n}\n").os << formatv(
- "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);",
+ os.indent(6) << formatv(
+ "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n",
valuePackName, resultOp.getQualCppClassName(), locToUse);
+ os.indent(4) << "}\n";
return resultValue;
}
// aggregate-parameter builders.
createSeparateLocalVarsForOpArgs(tree, childNodeNames);
- os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
- resultOp.getQualCppClassName(), locToUse);
+ os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName,
+ resultOp.getQualCppClassName(), locToUse);
supplyValuesForOpArgs(tree, childNodeNames);
- os << "\n );\n}\n";
+ os << "\n );\n";
+ os.indent(4) << "}\n";
return resultValue;
}
// Then prepare the result types. We need to specify the types for all
// results.
- os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
- "(void)tblgen_types;\n");
+ os.indent(6) << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; "
+ "(void)tblgen_types;\n");
int numResults = resultOp.getNumResults();
if (numResults != 0) {
for (int i = 0; i < numResults; ++i)
- os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n"
- " tblgen_types.push_back(v.getType());\n}\n",
- resultIndex + i);
+ os.indent(6) << formatv("for (auto v : castedOp0.getODSResults({0})) {{"
+ "tblgen_types.push_back(v.getType()); }\n",
+ resultIndex + i);
}
- os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
- "tblgen_values, tblgen_attrs);\n",
- valuePackName, resultOp.getQualCppClassName(), locToUse);
- os.unindent() << "}\n";
+ os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, "
+ "tblgen_values, tblgen_attrs);\n",
+ valuePackName, resultOp.getQualCppClassName(),
+ locToUse);
+ os.indent(4) << "}\n";
return resultValue;
}
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
const auto *operand =
resultOp.getArg(argIndex).dyn_cast<NamedTypeConstraint *>();
- // We do not need special handling for attributes.
- if (!operand)
+ if (!operand) {
+ // We do not need special handling for attributes.
continue;
+ }
- raw_indented_ostream::DelimitedScope scope(os);
std::string varName;
if (operand->isVariadic()) {
varName = std::string(formatv("tblgen_values_{0}", valueIndex++));
- os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName);
+ os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n",
+ varName);
std::string range;
if (node.isNestedDagArg(argIndex)) {
range = childNodeNames[argIndex];
// Resolve the symbol for all range use so that we have a uniform way of
// capturing the values.
range = symbolInfoMap.getValueAndRangeUse(range);
- os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range,
- varName);
+ os.indent(6) << formatv("for (auto v : {0}) {1}.push_back(v);\n", range,
+ varName);
} else {
varName = std::string(formatv("tblgen_value_{0}", valueIndex++));
- os << formatv("::mlir::Value {0} = ", varName);
+ os.indent(6) << formatv("::mlir::Value {0} = ", varName);
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]);
} else {
for (int argIndex = 0, numOpArgs = resultOp.getNumArgs();
argIndex != numOpArgs; ++argIndex) {
// Start each argument on its own line.
- os << ",\n ";
+ (os << ",\n").indent(8);
Argument opArg = resultOp.getArg(argIndex);
// Handle the case of operand first.
DagNode node, const ChildNodeIndexNameMap &childNodeNames) {
Operator &resultOp = node.getDialectOp(opMap);
- auto scope = os.scope();
- os << formatv("::mlir::SmallVector<::mlir::Value, 4> "
- "tblgen_values; (void)tblgen_values;\n");
- os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
- "tblgen_attrs; (void)tblgen_attrs;\n");
+ os.indent(6) << formatv("::mlir::SmallVector<::mlir::Value, 4> "
+ "tblgen_values; (void)tblgen_values;\n");
+ os.indent(6) << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> "
+ "tblgen_attrs; (void)tblgen_attrs;\n");
const char *addAttrCmd =
- "if (auto tmpAttr = {1}) {\n"
- " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), "
- "tmpAttr);\n}\n";
+ "if (auto tmpAttr = {1}) "
+ "tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n";
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) {
if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
// The argument in the op definition.
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
- os << formatv(addAttrCmd, opArgName,
- handleReplaceWithNativeCodeCall(subTree));
+ os.indent(6) << formatv(addAttrCmd, opArgName,
+ handleReplaceWithNativeCodeCall(subTree));
} else {
auto leaf = node.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
auto patArgName = node.getArgName(argIndex);
- os << formatv(addAttrCmd, opArgName,
- handleOpArgument(leaf, patArgName));
+ os.indent(6) << formatv(addAttrCmd, opArgName,
+ handleOpArgument(leaf, patArgName));
}
continue;
}
// Resolve the symbol for all range use so that we have a uniform way of
// capturing the values.
range = symbolInfoMap.getValueAndRangeUse(range);
- os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n",
- range);
+ os.indent(6) << formatv(
+ "for (auto v : {0}) tblgen_values.push_back(v);\n", range);
} else {
- os << formatv("tblgen_values.push_back(", varName);
+ os.indent(6) << formatv("tblgen_values.push_back(", varName);
if (node.isNestedDagArg(argIndex)) {
os << symbolInfoMap.getValueAndRangeUse(
childNodeNames.lookup(argIndex));