From 2f50b6c401fd4d6ff63718ef3b889a79ba32a640 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 25 May 2019 11:03:51 -0700 Subject: [PATCH] Use fused location for rewritten ops in generated rewrites. This does tracks the location by recording all the ops in the source pattern and using the fused location for the transformed op. Track the locations via the rewrite state which is a bit heavy weight, in follow up to change to matchAndRewrite this will be addressed (and need for extra array go away). -- PiperOrigin-RevId: 249986555 --- mlir/test/mlir-tblgen/pattern.td | 8 ++++++++ mlir/tools/mlir-tblgen/RewriterGen.cpp | 20 ++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mlir/test/mlir-tblgen/pattern.td b/mlir/test/mlir-tblgen/pattern.td index 66ff381..d8e92cb 100644 --- a/mlir/test/mlir-tblgen/pattern.td +++ b/mlir/test/mlir-tblgen/pattern.td @@ -19,6 +19,7 @@ def OpB : NS_Op<"op_b", []> { } def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>; +def MyRule2 : Pat<(OpA (OpA $input, $attr), $attr), (OpB $input, $attr)>; // Test rewrite rule naming // --- @@ -27,6 +28,13 @@ def MyRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>; // CHECK-NEXT: {{.*pattern.td.*}} // CHECK: struct MyRule : public RewritePattern +// CHECK-LABEL: struct MyRule2 : public RewritePattern +// CHECK: s.autogeneratedRewritePatternOps[0] = op0; +// CHECK: s.autogeneratedRewritePatternOps[1] = op1; +// CHECK: rewriter.getFusedLoc({ +// CHECK-SAME: s.autogeneratedRewritePatternOps[0]->getLoc() +// CHECK-SAME: s.autogeneratedRewritePatternOps[1]->getLoc() + def : Pat<(OpA $input, $attr), (OpB $input, $attr)>; // Test basic structure generated from Pattern diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 9103cb0..cd93b98 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -166,6 +166,7 @@ private: // Emits C++ statements for matching the `index`-th argument of the given DAG // `tree` as an operand. void emitOperandMatch(DagNode tree, int index, int depth, int indent); + // Emits C++ statements for matching the `index`-th argument of the given DAG // `tree` as an attribute. void emitAttributeMatch(DagNode tree, int index, int depth, int indent); @@ -231,6 +232,9 @@ private: FmtContext matchCtx; // Format contexts containing placeholder substitutations for rewrite(). FmtContext rewriteCtx; + + // Number of op processed. + int opCounter = 0; }; } // end anonymous namespace @@ -289,6 +293,9 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n", depth + 1, depth, i); emitOpMatch(argTree, depth + 1); + os.indent(indent + 2) + << formatv("s.autogeneratedRewritePatternOps[{0}] = op{1};\n", + ++opCounter, depth + 1); os.indent(indent) << "}\n"; continue; } @@ -397,6 +404,7 @@ void PatternEmitter::emitMatchMethod(DagNode tree) { auto ctx = op0->getContext(); (void)ctx; auto state = llvm::make_unique(); auto &s = *state; + s.autogeneratedRewritePatternOps[0] = op0; )"; // The rewrite pattern may specify that certain outputs should be unused in @@ -500,6 +508,10 @@ void PatternEmitter::emit(StringRef rewriteName) { for (const auto &result : pattern.getSourcePatternBoundOps()) { os.indent(4) << "Operation *" << result.getKey() << ";\n"; } + // TODO(jpienaar): Change to matchAndRewrite & capture ops with consistent + // numbering so that it can be reused for fused loc. + os.indent(4) << "Operation* autogeneratedRewritePatternOps[" + << pattern.getSourcePattern().getNumOps() << "];\n"; os << " };\n"; emitMatchMethod(tree); @@ -521,8 +533,12 @@ void PatternEmitter::emitRewriteMethod() { void rewrite(Operation *op, std::unique_ptr state, PatternRewriter &rewriter) const override { auto& s = *static_cast(state.get()); - auto loc = op->getLoc(); (void)loc; -)"; + auto loc = rewriter.getFusedLoc({)"; + for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { + os << (i ? ", " : "") << "s.autogeneratedRewritePatternOps[" << i + << "]->getLoc()"; + } + os << "}); (void)loc;\n"; // Collect the replacement value for each result llvm::SmallVector resultValues; -- 2.7.4