// The `node` are normally MLIR ops, but it can also be one of the directives
// listed later in this section.
//
+// ## Symbol binding
+//
// In the source pattern, `argN` can be used to specify matchers (e.g., using
// type/attribute type constraints, etc.) and bound to a name for later use.
// We can also bound names to op instances to reference them later in
// build `OneResultOp2`. `$op1` is bound to `OneResultOp1` and used to
// check whether the result's shape is static. `$op2` is bound to
// `OneResultOp2` and used to build `OneResultOp3`.
+//
+// ## Multi-result op
+//
+// To create multi-result ops in result pattern, you can use a syntax similar
+// to uni-result op, and it will act as a value pack for all results:
+//
+// ```
+// def : Pattern<(ThreeResultOp ...),
+// [(TwoResultOp ...), (OneResultOp ...)]>;
+// ```
+//
+// Then `TwoResultOp` will replace the first two values of `ThreeResultOp`.
+//
+// You can also use `$<name>__N` to explicitly access the N-th reusult.
+// ```
+// def : Pattern<(FiveResultOp ...),
+// [(TwoResultOp1:$res1__1 ...), (replaceWithValue $res1__0),
+// (TwoResultOp2:$res2 ...), (replaceWithValue $res2__1)]>;
+// ```
+//
+// Then the values generated by `FiveResultOp` will be replaced by
+//
+// * `FiveResultOp`#0: `TwoResultOp1`#1
+// * `FiveResultOp`#1: `TwoResultOp1`#0
+// * `FiveResultOp`#2: `TwoResultOp2`#0
+// * `FiveResultOp`#3: `TwoResultOp2`#1
+// * `FiveResultOp`#4: `TwoResultOp2`#1
class Pattern<dag source, list<dag> results, list<dag> preds = [],
dag benefitAdded = (addBenefit 0)> {
dag sourcePattern = source;
def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
+//===----------------------------------------------------------------------===//
+// Test Patterns (Multi-result Ops)
+//===----------------------------------------------------------------------===//
+
+def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
+def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
+def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
+def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
+
+def MultiResultOpEnum: I64EnumAttr<
+ "Multi-result op kinds", "", [
+ MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
+ MultiResultOpKind4
+ ]>;
+
+def ThreeResultOp : TEST_Op<"three_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def AnotherThreeResultOp : TEST_Op<"another_three_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def TwoResultOp : TEST_Op<"two_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1, F32:$result2);
+
+ let builders = [
+ OpBuilder<
+ "Builder *builder, OperationState *state, IntegerAttr kind",
+ [{
+ auto i32 = builder->getIntegerType(32);
+ auto f32 = builder->getF32Type();
+ state->types.assign({i32, f32});
+ state->addAttribute("kind", kind);
+ }]>
+ ];
+}
+
+def AnotherTwoResultOp : TEST_Op<"another_two_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs F32:$result1, F32:$result2);
+}
+
+def OneResultOp : TEST_Op<"one_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs F32:$result1);
+}
+
+def AnotherOneResultOp : TEST_Op<"another_one_result"> {
+ let arguments = (ins MultiResultOpEnum:$kind);
+ let results = (outs I32:$result1);
+}
+
+// Test using multi-result op as a whole
+def : Pat<(ThreeResultOp MultiResultOpKind1),
+ (AnotherThreeResultOp MultiResultOpKind1)>;
+
+// Test using multi-result op as a whole for partial replacement
+def : Pattern<(ThreeResultOp MultiResultOpKind2),
+ [(TwoResultOp MultiResultOpKind2),
+ (OneResultOp MultiResultOpKind2)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind3),
+ [(AnotherOneResultOp MultiResultOpKind3),
+ (AnotherTwoResultOp MultiResultOpKind3)]>;
+
+// Test using results separately in a multi-result op
+def : Pattern<(ThreeResultOp MultiResultOpKind4),
+ [(TwoResultOp:$res1__0 MultiResultOpKind4),
+ (OneResultOp MultiResultOpKind4),
+ (TwoResultOp:$res2__1 MultiResultOpKind4)]>;
+
//===----------------------------------------------------------------------===//
// Test op regions
//===----------------------------------------------------------------------===//
// CHECK: auto res_c = rewriter.create<OpC>(
// CHECK: /*operand=*/res_b
// CHECK: auto vOpD0 = rewriter.create<OpD>(
-// CHECK: /*input1=*/res_b,
-// CHECK: /*input2=*/res_c,
+// CHECK: /*input1=*/res_b.getOperation()->getResult(0),
+// CHECK: /*input2=*/res_c.getOperation()->getResult(0),
// CHECK: /*input3=*/s.res_a->getResult(0),
// CHECK: /*attr=*/s.attr
let results = (outs I32:$r1);
}
-def a : Pattern<(ThreeResultOp $input), [
- (OneResultOp $input),
- (OneResultOp $input),
- (OneResultOp $input)
- ]>;
-
-// CHECK-LABEL: struct a
-
-// CHECK: void rewrite(
-// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
-// CHECK: auto vOneResultOp1 = rewriter.create<OneResultOp>(
-// CHECK: auto vOneResultOp2 = rewriter.create<OneResultOp>(
-// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp2});
-
def b : Pattern<(ThreeResultOp $input), [
(OneResultOp (OneResultOp:$interm $input)),
(OneResultOp $interm),
// CHECK-NEXT: /*input=*/interm
// CHECK: auto vOneResultOp3 = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/vOneResultOp2
-// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp3});
+// CHECK: rewriter.replaceOp(op, {vOneResultOp0.getOperation()->getResult(0), vOneResultOp1.getOperation()->getResult(0), vOneResultOp3.getOperation()->getResult(0)});
// Test more result patterns than needed for replacement
// ---
// CHECK: auto vOneResultOp0 = rewriter.create<OneResultOp>(
// CHECK-NEXT: /*input=*/interm
// CHECK: auto vOneResultOp1 = rewriter.create<OneResultOp>(
-// CHECK: rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1});
+// CHECK: rewriter.replaceOp(op, {vOneResultOp0.getOperation()->getResult(0), vOneResultOp1.getOperation()->getResult(0)});
%0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32
return %0 : i32
}
+
+//===----------------------------------------------------------------------===//
+// Test Multi-result Ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @useMultiResultOpToReplaceWhole
+func @useMultiResultOpToReplaceWhole() -> (i32, f32, f32) {
+ // CHECK: %0:3 = "test.another_three_result"()
+ // CHECK: return %0#0, %0#1, %0#2
+ %0:3 = "test.three_result"() {kind = 1} : () -> (i32, f32, f32)
+ return %0#0, %0#1, %0#2 : i32, f32, f32
+}
+
+// CHECK-LABEL: @useMultiResultOpToReplacePartial1
+func @useMultiResultOpToReplacePartial1() -> (i32, f32, f32) {
+ // CHECK: %0:2 = "test.two_result"()
+ // CHECK: %1 = "test.one_result"()
+ // CHECK: return %0#0, %0#1, %1
+ %0:3 = "test.three_result"() {kind = 2} : () -> (i32, f32, f32)
+ return %0#0, %0#1, %0#2 : i32, f32, f32
+}
+
+// CHECK-LABEL: @useMultiResultOpToReplacePartial2
+func @useMultiResultOpToReplacePartial2() -> (i32, f32, f32) {
+ // CHECK: %0 = "test.another_one_result"()
+ // CHECK: %1:2 = "test.another_two_result"()
+ // CHECK: return %0, %1#0, %1#1
+ %0:3 = "test.three_result"() {kind = 3} : () -> (i32, f32, f32)
+ return %0#0, %0#1, %0#2 : i32, f32, f32
+}
+
+// CHECK-LABEL: @useMultiResultOpResultsSeparately
+func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) {
+ // CHECK: %0:2 = "test.two_result"()
+ // CHECK: %1 = "test.one_result"()
+ // CHECK: %2:2 = "test.two_result"()
+ // CHECK: return %0#0, %1, %2#1
+ %0:3 = "test.three_result"() {kind = 4} : () -> (i32, f32, f32)
+ return %0#0, %0#1, %0#2 : i32, f32, f32
+}
return Twine("s.") + symbol;
}
+// Gets the dynamic value pack's name by removing the index suffix from
+// `symbol`. Returns `symbol` itself if it does not contain an index.
+//
+// We can use `name__<index>` to access the `<index>`-th value in the dynamic
+// value pack bound to `name`. `name` is typically the results of an
+// multi-result op.
+static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
+ StringRef name, indexStr;
+ unsigned idx = 0;
+ std::tie(name, indexStr) = symbol.rsplit("__");
+ if (indexStr.consumeInteger(10, idx)) {
+ // The second part is not an index.
+ return symbol;
+ }
+ if (index)
+ *index = idx;
+ return name;
+}
+
+// Formats all values from a dynamic value pack `symbol` according to the given
+// `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
+// and `{1}` as a placeholder for the value index, which will be offsetted by
+// `offset`. The `symbol` value pack has a total of `count` values.
+//
+// This extracts one value from the pack if `symbol` contains an index,
+// otherwise it extracts all values sequentially and returns them as a
+// comma-separated list.
+static std::string formtValuePack(const char *fmt, StringRef symbol,
+ unsigned count, unsigned offset) {
+ auto getNthValue = [fmt, offset](StringRef results,
+ unsigned index) -> std::string {
+ return formatv(fmt, results, index + offset);
+ };
+
+ unsigned index = 0;
+ StringRef name = getValuePackName(symbol, &index);
+ if (name != symbol) {
+ // The symbol contains an index.
+ return getNthValue(name, index);
+ }
+
+ // The symbol does not contain an index. Treat the symbol as a whole.
+ SmallVector<std::string, 4> values;
+ values.reserve(count);
+ for (unsigned i = 0; i < count; ++i)
+ values.emplace_back(getNthValue(symbol, i));
+ return llvm::join(values, ", ");
+}
+
//===----------------------------------------------------------------------===//
// PatternSymbolResolver
//===----------------------------------------------------------------------===//
// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
// `$op2` is bound to `ResOp1`.
//
+// If a symbol binds to a multi-result op and it does not have the `__N`
+// suffix, the symbol is expanded to the whole value pack generated by the
+// multi-result op. If the symbol has a `__N` suffix, then it will expand to
+// only the N-th result.
+//
// This class keeps track of such symbols and translates them into their bound
// values.
//
PatternSymbolResolver(const StringMap<Argument> &srcArgs,
const StringSet<> &srcOperations);
- // Marks the given `symbol` as bound. Returns false if the `symbol` is
- // already bound.
- bool add(StringRef symbol);
+ // Marks the given `symbol` as bound to a value pack with `numValues` and
+ // returns true on success. Returns false if the `symbol` is already bound.
+ bool add(StringRef symbol, int numValues);
// Queries the substitution for the given `symbol`.
std::string query(StringRef symbol) const;
// Symbols bound to ops (for their results) in source pattern.
const StringSet<> &sourceOps;
// Symbols bound to ops (for their results) in result patterns.
- StringSet<> resultOps;
+ // Key: symbol; value: number of values inside the pack
+ StringMap<int> resultOps;
};
} // end anonymous namespace
const StringSet<> &srcOperations)
: sourceArguments(srcArgs), sourceOps(srcOperations) {}
-bool PatternSymbolResolver::add(StringRef symbol) {
- return resultOps.insert(symbol).second;
+bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
+ StringRef name = getValuePackName(symbol);
+ return resultOps.try_emplace(name, numValues).second;
}
std::string PatternSymbolResolver::query(StringRef symbol) const {
{
- auto it = resultOps.find(symbol);
+ StringRef name = getValuePackName(symbol);
+ auto it = resultOps.find(name);
if (it != resultOps.end())
- return it->getKey();
+ return formtValuePack("{0}.getOperation()->getResult({1})", symbol,
+ it->second, 0);
}
{
auto it = sourceArguments.find(symbol);
int numExpectedResults = rootOp.getNumResults();
int numProvidedResults = pattern.getNumResults();
- if (numProvidedResults < numExpectedResults)
- PrintFatalError(
- loc, "no enough result patterns to replace root op in source pattern");
-
os << R"(
void rewrite(Operation *op, std::unique_ptr<PatternState> state,
PatternRewriter &rewriter) const override {
if (resultTree.isReplaceWithValue())
return handleReplaceWithValue(resultTree);
- return emitOpCreate(resultTree, resultIndex, depth);
+ // Create the op and get the local variable for it.
+ auto results = emitOpCreate(resultTree, resultIndex, depth);
+ // We need to get all the values out of this local variable if we've created a
+ // multi-result op.
+ const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
+ return formtValuePack("{0}.getOperation()->getResult({1})", results,
+ numResults, 0);
}
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
}
- auto name = tree.getArgName(0);
- pattern.ensureBoundInSourcePattern(name);
-
- return getBoundSymbol(name).str();
+ return resolveSymbol(tree.getArgName(0));
}
void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
// Skip empty-named symbols, which happen for unbound ops in result patterns.
if (symbol.empty())
return;
- if (!symbolResolver.add(symbol))
+ if (!symbolResolver.add(symbol, pattern.getDialectOp(node).getNumResults()))
PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
}
Operator &resultOp = tree.getDialectOp(opMap);
auto numOpArgs = resultOp.getNumArgs();
- if (resultOp.getNumResults() > 1) {
- PrintFatalError(
- loc, formatv("generating multiple-result op '{0}' is unsupported now",
- resultOp.getOperationName()));
- }
if (resultOp.isVariadic()) {
PrintFatalError(loc, formatv("generating op '{0}' with variadic "
"operands/results is unsupported now",
std::string resultValue = tree.getOpName();
if (resultValue.empty())
resultValue = getUniqueValueName(&resultOp);
+ // Strip the index to get the name for the value pack. This will be used to
+ // name the local variable for the op.
+ StringRef valuePackName = getValuePackName(resultValue);
// Then we build the new op corresponding to this DAG node.
+ // Right now we don't have general type inference in MLIR. Except a few
+ // special cases listed below, we need to supply types for all results
+ // when building an op.
bool isSameOperandsAndResultType =
resultOp.hasTrait("SameOperandsAndResultType");
bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
+ bool usePartialResults = valuePackName != resultValue;
if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
- depth > 0) {
- os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
- resultOp.getQualCppClassName());
+ usePartialResults || depth > 0) {
+ os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
+ valuePackName, resultOp.getQualCppClassName());
} else {
// If depth == 0 we can use the equivalence of the source and target root
// ops in the pattern to determine the return type.
- std::string resultType = formatv("op->getResult({0})", resultIndex).str();
- os.indent(4) << formatv(
- "auto {0} = rewriter.create<{1}>(loc, {2}->getType()", resultValue,
- resultOp.getQualCppClassName(), resultType);
+ // We need to specify the types for all results.
+ auto resultTypes =
+ formtValuePack("op->getResult({1})->getType()", valuePackName,
+ resultOp.getNumResults(), resultIndex);
+
+ os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc, {2}",
+ valuePackName, resultOp.getQualCppClassName(),
+ resultTypes);
}
// Create the builder call for the result.