[TableGen] Support creating multi-result ops in result patterns
authorLei Zhang <antiagainst@google.com>
Thu, 4 Jul 2019 00:17:33 +0000 (17:17 -0700)
committerjpienaar <jpienaar@google.com>
Thu, 4 Jul 2019 01:17:49 +0000 (18:17 -0700)
This CL introduces a new syntax for creating multi-result ops and access their
results in result patterns. Specifically, if a multi-result op is unbound or
bound to a name without a trailing `__N` suffix, it will act as a value pack
and expand to all its values. If a multi-result op is bound to a symbol with
`__N` suffix, only the N-th result will be extracted and used.

PiperOrigin-RevId: 256465208

mlir/include/mlir/IR/OpBase.td
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/pattern-bound-symbol.td
mlir/test/mlir-tblgen/pattern-multi-result-op.td
mlir/test/mlir-tblgen/pattern.mlir
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 40ef9e4915bf093a2e9d6caf66b07c9a59ac8ce6..26a21bd18dbc462ab5da043f78e7916b1584fb68 100644 (file)
@@ -1257,6 +1257,8 @@ def addBenefit;
 // The `node` are normally MLIR ops, but it can also be one of the directives
 // listed later in this section.
 //
+// ## Symbol binding
+//
 // In the source pattern, `argN` can be used to specify matchers (e.g., using
 // type/attribute type constraints, etc.) and bound to a name for later use.
 // We can also bound names to op instances to reference them later in
@@ -1280,6 +1282,33 @@ def addBenefit;
 // build `OneResultOp2`. `$op1` is bound to `OneResultOp1` and used to
 // check whether the result's shape is static. `$op2` is bound to
 // `OneResultOp2` and used to build `OneResultOp3`.
+//
+// ## Multi-result op
+//
+// To create multi-result ops in result pattern, you can use a syntax similar
+// to uni-result op, and it will act as a value pack for all results:
+//
+// ```
+// def : Pattern<(ThreeResultOp ...),
+//               [(TwoResultOp ...), (OneResultOp ...)]>;
+// ```
+//
+// Then `TwoResultOp` will replace the first two values of `ThreeResultOp`.
+//
+// You can also use `$<name>__N` to explicitly access the N-th reusult.
+// ```
+// def : Pattern<(FiveResultOp ...),
+//               [(TwoResultOp1:$res1__1 ...), (replaceWithValue $res1__0),
+//                (TwoResultOp2:$res2 ...), (replaceWithValue $res2__1)]>;
+// ```
+//
+// Then the values generated by `FiveResultOp` will be replaced by
+//
+// * `FiveResultOp`#0: `TwoResultOp1`#1
+// * `FiveResultOp`#1: `TwoResultOp1`#0
+// * `FiveResultOp`#2: `TwoResultOp2`#0
+// * `FiveResultOp`#3: `TwoResultOp2`#1
+// * `FiveResultOp`#4: `TwoResultOp2`#1
 class Pattern<dag source, list<dag> results, list<dag> preds = [],
   dag benefitAdded = (addBenefit 0)> {
   dag sourcePattern = source;
index ba9b6429ec2d6fab6f5861cadaf20b232cdbf717..f5bba8c9cb8a5071da08837e999f20514689b10d 100644 (file)
@@ -231,6 +231,80 @@ def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
 def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
 def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;
 
+//===----------------------------------------------------------------------===//
+// Test Patterns (Multi-result Ops)
+//===----------------------------------------------------------------------===//
+
+def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
+def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
+def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
+def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
+
+def MultiResultOpEnum: I64EnumAttr<
+  "Multi-result op kinds", "", [
+    MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
+    MultiResultOpKind4
+  ]>;
+
+def ThreeResultOp : TEST_Op<"three_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def AnotherThreeResultOp : TEST_Op<"another_three_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1, F32:$result2, F32:$result3);
+}
+
+def TwoResultOp : TEST_Op<"two_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1, F32:$result2);
+
+  let builders = [
+    OpBuilder<
+      "Builder *builder, OperationState *state, IntegerAttr kind",
+      [{
+        auto i32 = builder->getIntegerType(32);
+        auto f32 = builder->getF32Type();
+        state->types.assign({i32, f32});
+        state->addAttribute("kind", kind);
+      }]>
+  ];
+}
+
+def AnotherTwoResultOp : TEST_Op<"another_two_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs F32:$result1, F32:$result2);
+}
+
+def OneResultOp : TEST_Op<"one_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs F32:$result1);
+}
+
+def AnotherOneResultOp : TEST_Op<"another_one_result"> {
+  let arguments = (ins MultiResultOpEnum:$kind);
+  let results = (outs I32:$result1);
+}
+
+// Test using multi-result op as a whole
+def : Pat<(ThreeResultOp MultiResultOpKind1),
+          (AnotherThreeResultOp MultiResultOpKind1)>;
+
+// Test using multi-result op as a whole for partial replacement
+def : Pattern<(ThreeResultOp MultiResultOpKind2),
+              [(TwoResultOp MultiResultOpKind2),
+               (OneResultOp MultiResultOpKind2)]>;
+def : Pattern<(ThreeResultOp MultiResultOpKind3),
+              [(AnotherOneResultOp MultiResultOpKind3),
+               (AnotherTwoResultOp MultiResultOpKind3)]>;
+
+// Test using results separately in a multi-result op
+def : Pattern<(ThreeResultOp MultiResultOpKind4),
+              [(TwoResultOp:$res1__0 MultiResultOpKind4),
+               (OneResultOp MultiResultOpKind4),
+               (TwoResultOp:$res2__1 MultiResultOpKind4)]>;
+
 //===----------------------------------------------------------------------===//
 // Test op regions
 //===----------------------------------------------------------------------===//
index 61add86bcbfc98ecde0df052bc178056293a1882..9e22fff92910c455131b5d03cd5677109ea3034f 100644 (file)
@@ -65,7 +65,7 @@ def : Pattern<(OpA:$res_a $operand, $attr),
 // CHECK:   auto res_c = rewriter.create<OpC>(
 // CHECK:     /*operand=*/res_b
 // CHECK:   auto vOpD0 = rewriter.create<OpD>(
-// CHECK:     /*input1=*/res_b,
-// CHECK:     /*input2=*/res_c,
+// CHECK:     /*input1=*/res_b.getOperation()->getResult(0),
+// CHECK:     /*input2=*/res_c.getOperation()->getResult(0),
 // CHECK:     /*input3=*/s.res_a->getResult(0),
 // CHECK:     /*attr=*/s.attr
index ffae5915ef46a91df5975d694e1819bcf5136b99..063d48026f95c840af6ed926d6405a33d1ed1cfe 100644 (file)
@@ -24,20 +24,6 @@ def OneResultOp : NS_Op<"one_result_op", []> {
   let results = (outs I32:$r1);
 }
 
-def a : Pattern<(ThreeResultOp $input), [
-        (OneResultOp $input),
-        (OneResultOp $input),
-        (OneResultOp $input)
-      ]>;
-
-// CHECK-LABEL: struct a
-
-// CHECK: void rewrite(
-// CHECK:      auto vOneResultOp0 = rewriter.create<OneResultOp>(
-// CHECK:      auto vOneResultOp1 = rewriter.create<OneResultOp>(
-// CHECK:      auto vOneResultOp2 = rewriter.create<OneResultOp>(
-// CHECK:      rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp2});
-
 def b : Pattern<(ThreeResultOp $input), [
         (OneResultOp (OneResultOp:$interm $input)),
         (OneResultOp $interm),
@@ -57,7 +43,7 @@ def b : Pattern<(ThreeResultOp $input), [
 // CHECK-NEXT:     /*input=*/interm
 // CHECK:        auto vOneResultOp3 = rewriter.create<OneResultOp>(
 // CHECK-NEXT:     /*input=*/vOneResultOp2
-// CHECK:        rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp3});
+// CHECK:        rewriter.replaceOp(op, {vOneResultOp0.getOperation()->getResult(0), vOneResultOp1.getOperation()->getResult(0), vOneResultOp3.getOperation()->getResult(0)});
 
 // Test more result patterns than needed for replacement
 // ---
@@ -80,4 +66,4 @@ def c : Pattern<(TwoResultOp $input), [
 // CHECK:      auto vOneResultOp0 = rewriter.create<OneResultOp>(
 // CHECK-NEXT:   /*input=*/interm
 // CHECK:      auto vOneResultOp1 = rewriter.create<OneResultOp>(
-// CHECK:      rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1});
+// CHECK:      rewriter.replaceOp(op, {vOneResultOp0.getOperation()->getResult(0), vOneResultOp1.getOperation()->getResult(0)});
index 557c340fd63c2de131b5c164072ff168084de665..15c0b9af023c69c1983d92d05d80687bd63f51cc 100644 (file)
@@ -49,3 +49,43 @@ func @verifyI64EnumAttr() -> i32 {
   %0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32
   return %0 : i32
 }
+
+//===----------------------------------------------------------------------===//
+// Test Multi-result Ops
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @useMultiResultOpToReplaceWhole
+func @useMultiResultOpToReplaceWhole() -> (i32, f32, f32) {
+  // CHECK: %0:3 = "test.another_three_result"()
+  // CHECK: return %0#0, %0#1, %0#2
+  %0:3 = "test.three_result"() {kind = 1} : () -> (i32, f32, f32)
+  return %0#0, %0#1, %0#2 : i32, f32, f32
+}
+
+// CHECK-LABEL: @useMultiResultOpToReplacePartial1
+func @useMultiResultOpToReplacePartial1() -> (i32, f32, f32) {
+  // CHECK: %0:2 = "test.two_result"()
+  // CHECK: %1 = "test.one_result"()
+  // CHECK: return %0#0, %0#1, %1
+  %0:3 = "test.three_result"() {kind = 2} : () -> (i32, f32, f32)
+  return %0#0, %0#1, %0#2 : i32, f32, f32
+}
+
+// CHECK-LABEL: @useMultiResultOpToReplacePartial2
+func @useMultiResultOpToReplacePartial2() -> (i32, f32, f32) {
+  // CHECK: %0 = "test.another_one_result"()
+  // CHECK: %1:2 = "test.another_two_result"()
+  // CHECK: return %0, %1#0, %1#1
+  %0:3 = "test.three_result"() {kind = 3} : () -> (i32, f32, f32)
+  return %0#0, %0#1, %0#2 : i32, f32, f32
+}
+
+// CHECK-LABEL: @useMultiResultOpResultsSeparately
+func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) {
+  // CHECK: %0:2 = "test.two_result"()
+  // CHECK: %1 = "test.one_result"()
+  // CHECK: %2:2 = "test.two_result"()
+  // CHECK: return %0#0, %1, %2#1
+  %0:3 = "test.three_result"() {kind = 4} : () -> (i32, f32, f32)
+  return %0#0, %0#1, %0#2 : i32, f32, f32
+}
index 2419c31b1701766a4079bd8089571e4ea61d9825..a89e903039f750fcd63bc010a9b52ceb6dcd7a12 100644 (file)
@@ -60,6 +60,55 @@ static Twine getBoundSymbol(const StringRef &symbol) {
   return Twine("s.") + symbol;
 }
 
+// Gets the dynamic value pack's name by removing the index suffix from
+// `symbol`. Returns `symbol` itself if it does not contain an index.
+//
+// We can use `name__<index>` to access the `<index>`-th value in the dynamic
+// value pack bound to `name`. `name` is typically the results of an
+// multi-result op.
+static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
+  StringRef name, indexStr;
+  unsigned idx = 0;
+  std::tie(name, indexStr) = symbol.rsplit("__");
+  if (indexStr.consumeInteger(10, idx)) {
+    // The second part is not an index.
+    return symbol;
+  }
+  if (index)
+    *index = idx;
+  return name;
+}
+
+// Formats all values from a dynamic value pack `symbol` according to the given
+// `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
+// and `{1}` as a placeholder for the value index, which will be offsetted by
+// `offset`. The `symbol` value pack has a total of `count` values.
+//
+// This extracts one value from the pack if `symbol` contains an index,
+// otherwise it extracts all values sequentially and returns them as a
+// comma-separated list.
+static std::string formtValuePack(const char *fmt, StringRef symbol,
+                                  unsigned count, unsigned offset) {
+  auto getNthValue = [fmt, offset](StringRef results,
+                                   unsigned index) -> std::string {
+    return formatv(fmt, results, index + offset);
+  };
+
+  unsigned index = 0;
+  StringRef name = getValuePackName(symbol, &index);
+  if (name != symbol) {
+    // The symbol contains an index.
+    return getNthValue(name, index);
+  }
+
+  // The symbol does not contain an index. Treat the symbol as a whole.
+  SmallVector<std::string, 4> values;
+  values.reserve(count);
+  for (unsigned i = 0; i < count; ++i)
+    values.emplace_back(getNthValue(symbol, i));
+  return llvm::join(values, ", ");
+}
+
 //===----------------------------------------------------------------------===//
 // PatternSymbolResolver
 //===----------------------------------------------------------------------===//
@@ -78,6 +127,11 @@ namespace {
 // `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
 // `$op2` is bound to `ResOp1`.
 //
+// If a symbol binds to a multi-result op and it does not have the `__N`
+// suffix, the symbol is expanded to the whole value pack generated by the
+// multi-result op. If the symbol has a `__N` suffix, then it will expand to
+// only the N-th result.
+//
 // This class keeps track of such symbols and translates them into their bound
 // values.
 //
@@ -90,9 +144,9 @@ public:
   PatternSymbolResolver(const StringMap<Argument> &srcArgs,
                         const StringSet<> &srcOperations);
 
-  // Marks the given `symbol` as bound.  Returns false if the `symbol` is
-  // already bound.
-  bool add(StringRef symbol);
+  // Marks the given `symbol` as bound to a value pack with `numValues` and
+  // returns true on success. Returns false if the `symbol` is already bound.
+  bool add(StringRef symbol, int numValues);
 
   // Queries the substitution for the given `symbol`.
   std::string query(StringRef symbol) const;
@@ -103,7 +157,8 @@ private:
   // Symbols bound to ops (for their results) in source pattern.
   const StringSet<> &sourceOps;
   // Symbols bound to ops (for their results) in result patterns.
-  StringSet<> resultOps;
+  // Key: symbol; value: number of values inside the pack
+  StringMap<int> resultOps;
 };
 } // end anonymous namespace
 
@@ -111,15 +166,18 @@ PatternSymbolResolver::PatternSymbolResolver(const StringMap<Argument> &srcArgs,
                                              const StringSet<> &srcOperations)
     : sourceArguments(srcArgs), sourceOps(srcOperations) {}
 
-bool PatternSymbolResolver::add(StringRef symbol) {
-  return resultOps.insert(symbol).second;
+bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
+  StringRef name = getValuePackName(symbol);
+  return resultOps.try_emplace(name, numValues).second;
 }
 
 std::string PatternSymbolResolver::query(StringRef symbol) const {
   {
-    auto it = resultOps.find(symbol);
+    StringRef name = getValuePackName(symbol);
+    auto it = resultOps.find(name);
     if (it != resultOps.end())
-      return it->getKey();
+      return formtValuePack("{0}.getOperation()->getResult({1})", symbol,
+                            it->second, 0);
   }
   {
     auto it = sourceArguments.find(symbol);
@@ -527,10 +585,6 @@ void PatternEmitter::emitRewriteMethod() {
   int numExpectedResults = rootOp.getNumResults();
   int numProvidedResults = pattern.getNumResults();
 
-  if (numProvidedResults < numExpectedResults)
-    PrintFatalError(
-        loc, "no enough result patterns to replace root op in source pattern");
-
   os << R"(
   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
                PatternRewriter &rewriter) const override {
@@ -591,7 +645,13 @@ std::string PatternEmitter::handleRewritePattern(DagNode resultTree,
   if (resultTree.isReplaceWithValue())
     return handleReplaceWithValue(resultTree);
 
-  return emitOpCreate(resultTree, resultIndex, depth);
+  // Create the op and get the local variable for it.
+  auto results = emitOpCreate(resultTree, resultIndex, depth);
+  // We need to get all the values out of this local variable if we've created a
+  // multi-result op.
+  const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
+  return formtValuePack("{0}.getOperation()->getResult({1})", results,
+                        numResults, 0);
 }
 
 std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
@@ -606,10 +666,7 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
     PrintFatalError(loc, "cannot bind symbol to verifyUnusedValue");
   }
 
-  auto name = tree.getArgName(0);
-  pattern.ensureBoundInSourcePattern(name);
-
-  return getBoundSymbol(name).str();
+  return resolveSymbol(tree.getArgName(0));
 }
 
 void PatternEmitter::handleVerifyUnusedValue(DagNode tree, int index) {
@@ -666,7 +723,7 @@ void PatternEmitter::addSymbol(DagNode node) {
   // Skip empty-named symbols, which happen for unbound ops in result patterns.
   if (symbol.empty())
     return;
-  if (!symbolResolver.add(symbol))
+  if (!symbolResolver.add(symbol, pattern.getDialectOp(node).getNumResults()))
     PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
 }
 
@@ -682,11 +739,6 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
   Operator &resultOp = tree.getDialectOp(opMap);
   auto numOpArgs = resultOp.getNumArgs();
 
-  if (resultOp.getNumResults() > 1) {
-    PrintFatalError(
-        loc, formatv("generating multiple-result op '{0}' is unsupported now",
-                     resultOp.getOperationName()));
-  }
   if (resultOp.isVariadic()) {
     PrintFatalError(loc, formatv("generating op '{0}' with variadic "
                                  "operands/results is unsupported now",
@@ -722,26 +774,37 @@ std::string PatternEmitter::emitOpCreate(DagNode tree, int resultIndex,
   std::string resultValue = tree.getOpName();
   if (resultValue.empty())
     resultValue = getUniqueValueName(&resultOp);
+  // Strip the index to get the name for the value pack. This will be used to
+  // name the local variable for the op.
+  StringRef valuePackName = getValuePackName(resultValue);
 
   // Then we build the new op corresponding to this DAG node.
 
+  // Right now we don't have general type inference in MLIR. Except a few
+  // special cases listed below, we need to supply types for all results
+  // when building an op.
   bool isSameOperandsAndResultType =
       resultOp.hasTrait("SameOperandsAndResultType");
   bool isBroadcastable = resultOp.hasTrait("BroadcastableTwoOperandsOneResult");
   bool useFirstAttr = resultOp.hasTrait("FirstAttrDerivedResultType");
+  bool usePartialResults = valuePackName != resultValue;
 
   if (isSameOperandsAndResultType || isBroadcastable || useFirstAttr ||
-      depth > 0) {
-    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", resultValue,
-                            resultOp.getQualCppClassName());
+      usePartialResults || depth > 0) {
+    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
+                            valuePackName, resultOp.getQualCppClassName());
   } else {
     // If depth == 0 we can use the equivalence of the source and target root
     // ops in the pattern to determine the return type.
-    std::string resultType = formatv("op->getResult({0})", resultIndex).str();
 
-    os.indent(4) << formatv(
-        "auto {0} = rewriter.create<{1}>(loc, {2}->getType()", resultValue,
-        resultOp.getQualCppClassName(), resultType);
+    // We need to specify the types for all results.
+    auto resultTypes =
+        formtValuePack("op->getResult({1})->getType()", valuePackName,
+                       resultOp.getNumResults(), resultIndex);
+
+    os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc, {2}",
+                            valuePackName, resultOp.getQualCppClassName(),
+                            resultTypes);
   }
 
   // Create the builder call for the result.