[TableGen] Allow additional result patterns not directly used for replacement
authorLei Zhang <antiagainst@google.com>
Wed, 3 Apr 2019 19:29:14 +0000 (12:29 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 4 Apr 2019 02:20:22 +0000 (19:20 -0700)
    This CL looses the requirement that all result patterns in a rewrite rule must
    replace a result of the root op in the source pattern. Now only the last N
    result pattern-generated ops are used to replace a N-result source op.

    This allows to generate additional ops to aid building up final ops used to
    replace the source op.

--

PiperOrigin-RevId: 241783192

mlir/include/mlir/IR/OpBase.td
mlir/test/mlir-tblgen/pattern-multi-result-op.td
mlir/test/mlir-tblgen/pattern-tAttr.td
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 76b5a75..3c2ad54 100644 (file)
@@ -779,18 +779,24 @@ def addBenefit;
 // specifying rewrite rules.
 //
 // A rewrite rule contains two components: a source pattern and one or more
-// result rules. Each pattern is specified as a (recursive) DAG node (tree)
+// result patterns. Each pattern is specified as a (recursive) DAG node (tree)
 // in the form of `(node arg0, arg1, ...)`.
 // The `node` are normally MLIR ops, but it can also be one of the directives
 // listed later in this section.
 // In the source pattern, `arg*` can be used to specify matchers (e.g., using
-// type/attribute types, mAttr, etc.) and bound to a name for later use. In
+// type/attribute types, etc.) and bound to a name for later use. In
 // the result pattern, `arg*` can be used to refer to a previously bound name,
 // with potential transformations (e.g., using tAttr, etc.). `arg*` can itself
 // be nested DAG node.
 class Pattern<dag source, list<dag> results, list<dag> preds = [],
   dag benefitAdded = (addBenefit 0)> {
   dag sourcePattern = source;
+  // Result patterns. Each result pattern is expected to replace one result
+  // of the root op in the source pattern. In the case of more result patterns
+  // than needed to replace the source op, only the last N results generated
+  // by the last N result pattern is used to replace a N-result source op.
+  // So that the beginning result patterns can be used to generate additional
+  // ops to aid building the results used for replacement.
   list<dag> resultPatterns = results;
   // Multi-entity constraints. Each constraint here involves multiple entities
   // matched in source pattern and places further constraints on them as a
index bcbdb85..c57cd9a 100644 (file)
@@ -7,6 +7,11 @@ def ThreeResultOp : Op<"three_result_op", []> {
   let results = (outs I32:$r1, I32:$r2, I32:$r3);
 }
 
+def TwoResultOp : Op<"two_result_op", []> {
+  let arguments = (ins I32:$input);
+  let results = (outs I32:$r1, I32:$r2);
+}
+
 def OneResultOp : Op<"one_result_op", []> {
   let arguments = (ins I32:$input);
   let results = (outs I32:$r1);
@@ -47,3 +52,25 @@ def : Pattern<(ThreeResultOp $input), [
 // CHECK-NEXT:     /*input=*/vOneResultOp2
 // CHECK:        rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1, vOneResultOp3});
 
+// Test more result patterns than needed for replacement
+// ---
+def AdditionalOp : Op<"additional_one_result_op", []> {
+  let arguments = (ins I32:$input);
+  let results = (outs I32:$r1);
+}
+def : Pattern<(TwoResultOp $input), [
+        // Additional op generated to help build the final result but not
+        // directly used to replace the source op
+        (AdditionalOp:$interm $input),
+
+        (OneResultOp $interm),
+        (OneResultOp $input)
+      ]>;
+
+// CHECK-LABEL: struct GeneratedConvert2
+
+// CHECK:      auto interm = rewriter.create<AdditionalOp>(
+// CHECK:      auto vOneResultOp0 = rewriter.create<OneResultOp>(
+// CHECK-NEXT:   /*input=*/interm
+// CHECK:      auto vOneResultOp1 = rewriter.create<OneResultOp>(
+// CHECK:      rewriter.replaceOp(op, {vOneResultOp0, vOneResultOp1});
index 93aaae2..39fa479 100644 (file)
@@ -15,9 +15,11 @@ def X_AddOp : Op<"x.add"> {
 }
 def Y_AddOp : Op<"y.add"> {
   let arguments = (ins U, U, T_Attr:$attrName);
+  let results = (outs U);
 }
 def Z_AddOp : Op<"z.add"> {
   let arguments = (ins U, U, T_Attr:$attrName1, T_Attr:$attrName2);
+  let results = (outs U);
 }
 
 // Define rewrite pattern.
index d9d5f74..d0e4008 100644 (file)
@@ -343,6 +343,10 @@ void PatternEmitter::emit(StringRef rewriteName) {
   const Operator &rootOp = pattern.getSourceRootOp();
   auto rootName = rootOp.getOperationName();
 
+  if (rootOp.hasVariadicResult())
+    PrintFatalError(
+        loc, "replacing op with variadic results not supported right now");
+
   // Emit RewritePattern for Pattern.
   os << formatv(R"(struct {0} : public RewritePattern {
   {0}(MLIRContext *context) : RewritePattern("{1}", {2}, context) {{})",
@@ -369,9 +373,13 @@ void PatternEmitter::emit(StringRef rewriteName) {
 }
 
 void PatternEmitter::emitRewriteMethod() {
-  unsigned numResults = pattern.getNumResults();
-  if (numResults == 0)
-    PrintFatalError(loc, "must provide at least one result pattern");
+  const Operator &rootOp = pattern.getSourceRootOp();
+  int numExpectedResults = rootOp.getNumResults();
+  unsigned numProvidedResults = pattern.getNumResults();
+
+  if (numProvidedResults < numExpectedResults)
+    PrintFatalError(
+        loc, "no enough result patterns to replace root op in source pattern");
 
   os << R"(
   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
@@ -382,7 +390,7 @@ void PatternEmitter::emitRewriteMethod() {
 
   // Collect the replacement value for each result
   llvm::SmallVector<std::string, 2> resultValues;
-  for (unsigned i = 0; i < numResults; ++i) {
+  for (unsigned i = 0; i < numProvidedResults; ++i) {
     DagNode resultTree = pattern.getResultPattern(i);
     resultValues.push_back(handleRewritePattern(resultTree, i, 0));
   }
@@ -390,8 +398,9 @@ void PatternEmitter::emitRewriteMethod() {
   // Emit the final replaceOp() statement
   os.indent(4) << "rewriter.replaceOp(op, {";
   interleave(
-      resultValues, [&](const std::string &name) { os << name; },
-      [&]() { os << ", "; });
+      // We only use the last numExpectedResults ones to replace the root op.
+      ArrayRef<std::string>(resultValues).take_back(numExpectedResults),
+      [&](const std::string &name) { os << name; }, [&]() { os << ", "; });
   os << "});\n  }\n";
 }