// specifying rewrite rules.
//
// A rewrite rule contains two components: a source pattern and one or more
-// result rules. Each pattern is specified as a (recursive) DAG node (tree)
+// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
// in the form of `(node arg0, arg1, ...)`.
// The `node` are normally MLIR ops, but it can also be one of the directives
// listed later in this section.
// In the source pattern, `arg*` can be used to specify matchers (e.g., using
-// type/attribute types, mAttr, etc.) and bound to a name for later use. In
+// type/attribute types, etc.) and bound to a name for later use. In
// the result pattern, `arg*` can be used to refer to a previously bound name,
// with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
// be nested DAG node.
class Pattern<dag source, list<dag> results, list<dag> preds = [],
dag benefitAdded = (addBenefit 0)> {
dag sourcePattern = source;
+ // Result patterns. Each result pattern is expected to replace one result
+ // of the root op in the source pattern. In the case of more result patterns
+ // than needed to replace the source op, only the last N results generated
+ // by the last N result pattern is used to replace a N-result source op.
+ // So that the beginning result patterns can be used to generate additional
+ // ops to aid building the results used for replacement.
list<dag> resultPatterns = results;
// Multi-entity constraints. Each constraint here involves multiple entities
// matched in source pattern and places further constraints on them as a
let results = (outs I32:$r1, I32:$r2, I32:$r3);
}
+def TwoResultOp : Op<"two_result_op", []> {
+ let arguments = (ins I32:$input);
+ let results = (outs I32:$r1, I32:$r2);
+}
+
def OneResultOp : Op<"one_result_op", []> {
let arguments = (ins I32:$input);
let results = (outs I32:$r1);
// CHECK-NEXT: /*input=*/vOneResultOp2
// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp3});
+// Test more result patterns than needed for replacement
+// ---
+def AdditionalOp : Op<"additional_one_result_op", []> {
+ let arguments = (ins I32:$input);
+ let results = (outs I32:$r1);
+}
+def : Pattern<(TwoResultOp $input), [
+ // Additional op generated to help build the final result but not
+ // directly used to replace the source op
+ (AdditionalOp:$interm $input),
+
+ (OneResultOp $interm),
+ (OneResultOp $input)
+ ]>;
+
+// CHECK-LABEL: struct GeneratedConvert2
+
+// CHECK: auto interm = rewriter.create<AdditionalOp>(
+// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
+// CHECK-NEXT: /*input=*/interm
+// CHECK: auto vOneResultOp1 = rewriter.create<OneResultOp>(
+// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1});
}
def Y_AddOp : Op<"y.add"> {
let arguments = (ins U, U, T_Attr:$attrName);
+ let results = (outs U);
}
def Z_AddOp : Op<"z.add"> {
let arguments = (ins U, U, T_Attr:$attrName1, T_Attr:$attrName2);
+ let results = (outs U);
}
// Define rewrite pattern.
const Operator &rootOp = pattern.getSourceRootOp();
auto rootName = rootOp.getOperationName();
+ if (rootOp.hasVariadicResult())
+ PrintFatalError(
+ loc, "replacing op with variadic results not supported right now");
+
// Emit RewritePattern for Pattern.
os << formatv(R"(struct {0} : public RewritePattern {
{0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
}
void PatternEmitter::emitRewriteMethod() {
- unsigned numResults = pattern.getNumResults();
- if (numResults == 0)
- PrintFatalError(loc, "must provide at least one result pattern");
+ const Operator &rootOp = pattern.getSourceRootOp();
+ int numExpectedResults = rootOp.getNumResults();
+ unsigned numProvidedResults = pattern.getNumResults();
+
+ if (numProvidedResults < numExpectedResults)
+ PrintFatalError(
+ loc, "no enough result patterns to replace root op in source pattern");
os << R"(
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
// Collect the replacement value for each result
llvm::SmallVector<std::string, 2> resultValues;
- for (unsigned i = 0; i < numResults; ++i) {
+ for (unsigned i = 0; i < numProvidedResults; ++i) {
DagNode resultTree = pattern.getResultPattern(i);
resultValues.push_back(handleRewritePattern(resultTree, i, 0));
}
// Emit the final replaceOp() statement
os.indent(4) << "rewriter.replaceOp(op, {";
interleave(
- resultValues, [&](const std::string &name) { os << name; },
- [&]() { os << ", "; });
+ // We only use the last numExpectedResults ones to replace the root op.
+ ArrayRef<std::string>(resultValues).take_back(numExpectedResults),
+ [&](const std::string &name) { os << name; }, [&]() { os << ", "; });
os << "});\n }\n";
}