Use fused location for rewritten ops in generated rewrites.
authorJacques Pienaar <jpienaar@google.com>
Sat, 25 May 2019 18:03:51 +0000 (11:03 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:03:12 +0000 (20:03 -0700)
    This does tracks the location by recording all the ops in the source pattern and using the fused location for the transformed op. Track the locations via the rewrite state which is a bit heavy weight, in follow up to change to matchAndRewrite this will be addressed (and need for extra array go away).

--

PiperOrigin-RevId: 249986555

mlir/test/mlir-tblgen/pattern.td
mlir/tools/mlir-tblgen/RewriterGen.cpp

index 66ff381..d8e92cb 100644 (file)
@@ -19,6 +19,7 @@ def OpB : NS_Op<"op_b", []> {
 }
 
 def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
+def MyRule2 : Pat<(OpA (OpA $input, $attr), $attr), (OpB $input, $attr)>;
 
 // Test rewrite rule naming
 // ---
@@ -27,6 +28,13 @@ def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
 // CHECK-NEXT: {{.*pattern.td.*}}
 // CHECK: struct MyRule : public RewritePattern
 
+// CHECK-LABEL: struct MyRule2 : public RewritePattern
+// CHECK: s.autogeneratedRewritePatternOps[0] = op0;
+// CHECK: s.autogeneratedRewritePatternOps[1] = op1;
+// CHECK: rewriter.getFusedLoc({
+// CHECK-SAME: s.autogeneratedRewritePatternOps[0]->getLoc()
+// CHECK-SAME: s.autogeneratedRewritePatternOps[1]->getLoc()
+
 def : Pat<(OpA $input, $attr), (OpB $input, $attr)>;
 
 // Test basic structure generated from Pattern
index 9103cb0..cd93b98 100644 (file)
@@ -166,6 +166,7 @@ private:
   // Emits C++ statements for matching the `index`-th argument of the given DAG
   // `tree` as an operand.
   void emitOperandMatch(DagNode tree, int index, int depth, int indent);
+
   // Emits C++ statements for matching the `index`-th argument of the given DAG
   // `tree` as an attribute.
   void emitAttributeMatch(DagNode tree, int index, int depth, int indent);
@@ -231,6 +232,9 @@ private:
   FmtContext matchCtx;
   // Format contexts containing placeholder substitutations for rewrite().
   FmtContext rewriteCtx;
+
+  // Number of op processed.
+  int opCounter = 0;
 };
 } // end anonymous namespace
 
@@ -289,6 +293,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
           << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
                      depth + 1, depth, i);
       emitOpMatch(argTree, depth + 1);
+      os.indent(indent + 2)
+          << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n",
+                     ++opCounter, depth + 1);
       os.indent(indent) << "}\n";
       continue;
     }
@@ -397,6 +404,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) {
     auto ctx = op0->getContext(); (void)ctx;
     auto state = llvm::make_unique<MatchedState>();
     auto &s = *state;
+    s.autogeneratedRewritePatternOps[0] = op0;
 )";
 
   // The rewrite pattern may specify that certain outputs should be unused in
@@ -500,6 +508,10 @@ void PatternEmitter::emit(StringRef rewriteName) {
   for (const auto &result : pattern.getSourcePatternBoundOps()) {
     os.indent(4) << "Operation *" << result.getKey() << ";\n";
   }
+  // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent
+  // numbering so that it can be reused for fused loc.
+  os.indent(4) << "Operation* autogeneratedRewritePatternOps["
+               << pattern.getSourcePattern().getNumOps() << "];\n";
   os << "  };\n";
 
   emitMatchMethod(tree);
@@ -521,8 +533,12 @@ void PatternEmitter::emitRewriteMethod() {
   void rewrite(Operation *op, std::unique_ptr<PatternState> state,
                PatternRewriter &rewriter) const override {
     auto& s = *static_cast<MatchedState *>(state.get());
-    auto loc = op->getLoc(); (void)loc;
-)";
+    auto loc = rewriter.getFusedLoc({)";
+  for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) {
+    os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i
+       << "]->getLoc()";
+  }
+  os << "}); (void)loc;\n";
 
   // Collect the replacement value for each result
   llvm::SmallVector<std::string, 2> resultValues;