From 3f7439b28063c284975b49ebdc9c5645cedae7a0 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 7 Apr 2020 07:44:19 -0700 Subject: [PATCH] [mlir][DRR] Add location directive Summary: Add directive to indicate the location to give to op being created. This directive is optional and if unused the location will still be the fused location of all source operations. Currently this directive only works with other op locations, reusing an existing op location or a fusion of op locations. But doesn't yet support supplying metadata for the FusedLoc. Based off initial revision by antiagainst@ and effectively mirrors GlobalIsel debug_locations directive. Differential Revision: https://reviews.llvm.org/D77649 --- mlir/docs/DeclarativeRewrites.md | 48 +++++++++++++++++- mlir/include/mlir/IR/OpBase.td | 7 ++- mlir/include/mlir/TableGen/Pattern.h | 3 ++ mlir/lib/TableGen/Pattern.cpp | 16 +++++- mlir/test/lib/Dialect/Test/TestOps.td | 26 ++++++++++ mlir/test/mlir-tblgen/pattern.mlir | 17 ++++++- mlir/tools/mlir-tblgen/RewriterGen.cpp | 91 +++++++++++++++++++++++++--------- 7 files changed, 179 insertions(+), 29 deletions(-) diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md index 82b97e0..ca759f4 100644 --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -657,9 +657,53 @@ pattern. This is based on the heuristics and assumptions that: The fourth parameter to `Pattern` (and `Pat`) allows to manually tweak a pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value. -## Special directives +## Rewrite directives -[TODO] +### `location` + +By default the C++ pattern expanded from a DRR pattern uses the fused location +of all source ops as the location for all generated ops. This is not always the +best location mapping relationship. For such cases, DRR provides the `location` +directive to provide finer control. + +`location` is of the following syntax: + +```tablgen +(location $symbol0, $symbol1, ...) +``` + +where all `$symbol` should be bound previously in the pattern. + +`location` must be used as the last argument to an op creation. For example, + +```tablegen +def : Pat<(LocSrc1Op:$src1 (LocSrc2Op:$src2 ...), + (LocDst1Op (LocDst2Op ..., (location $src2)))>; +``` + +In the above pattern, the generated `LocDst2Op` will use the matched location +of `LocSrc2Op` while the root `LocDst1Op` node will still se the fused location +of all source Ops. + +### `replaceWithValue` + +The `replaceWithValue` directive is used to eliminate a matched op by replacing +all of it uses with a captured value. It is of the following syntax: + +```tablegen +(replaceWithValue $symbol) +``` + +where `$symbol` should be a symbol bound previously in the pattern. + +For example, + +```tablegen +def : Pat<(Foo $input), (replaceWithValue $input)>; +``` + +The above pattern removes the `Foo` and replaces all uses of `Foo` with +`$input`. ## Debugging Tips diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 25f062b..09cea1b 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2179,9 +2179,14 @@ class NativeCodeCall { } //===----------------------------------------------------------------------===// -// Common directives +// Rewrite directives //===----------------------------------------------------------------------===// +// Directive used in result pattern to specify the location of the generated +// op. This directive must be used as the last argument to the op creation +// DAG construct. The arguments to location must be previously captured symbol. +def location; + // Directive used in result pattern to indicate that no new op are generated, // so to replace the matched DAG with an existing SSA value. def replaceWithValue; diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 0ed4133..e7fa48d 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -159,6 +159,9 @@ public: // value. bool isReplaceWithValue() const; + // Returns whether this DAG represents the location of an op creation. + bool isLocationDirective() const; + // Returns true if this DAG node is wrapping native code call. bool isNativeCodeCall() const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 6d4b03b..5b54708 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -103,7 +103,7 @@ bool tblgen::DagNode::isNativeCodeCall() const { } bool tblgen::DagNode::isOperation() const { - return !(isNativeCodeCall() || isReplaceWithValue()); + return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); } llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { @@ -159,6 +159,11 @@ bool tblgen::DagNode::isReplaceWithValue() const { return dagOpDef->getName() == "replaceWithValue"; } +bool tblgen::DagNode::isLocationDirective() const { + auto *dagOpDef = cast(node->getOperator())->getDef(); + return dagOpDef->getName() == "location"; +} + void tblgen::DagNode::print(raw_ostream &os) const { if (node) node->print(os); @@ -533,7 +538,14 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, auto numOpArgs = op.getNumArgs(); auto numTreeArgs = tree.getNumArgs(); - if (numOpArgs != numTreeArgs) { + // The pattern might have the last argument specifying the location. + bool hasLocDirective = false; + if (numTreeArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) + hasLocDirective = lastArg.isLocationDirective(); + } + + if (numOpArgs != numTreeArgs - hasLocDirective) { auto err = formatv("op '{0}' argument number mismatch: " "{1} in pattern vs. {2} in definition", op.getOperationName(), numTreeArgs, numOpArgs); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 0619609..8859d50 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -502,6 +502,20 @@ def StringAttrPrettyNameOp } //===----------------------------------------------------------------------===// +// Test Locations +//===----------------------------------------------------------------------===// + +def TestLocationSrcOp : TEST_Op<"loc_src"> { + let arguments = (ins I32:$input); + let results = (outs I32:$output); +} + +def TestLocationDstOp : TEST_Op<"loc_dst", [SameOperandsAndResultType]> { + let arguments = (ins I32:$input); + let results = (outs I32:$output); +} + +//===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===// @@ -996,6 +1010,18 @@ def : Pat<(OneI32ResultOp), ConstantAttr)>; //===----------------------------------------------------------------------===// +// Test Patterns (Location) + +// Test that we can specify locations for generated ops. +def : Pat<(TestLocationSrcOp:$res1 + (TestLocationSrcOp:$res2 + (TestLocationSrcOp:$res3 $input))), + (TestLocationDstOp + (TestLocationDstOp + (TestLocationDstOp $input, (location $res1))), + (location $res2, $res3))>; + +//===----------------------------------------------------------------------===// // Test Legalization //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 67ea2fd..a96d90f 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s +// RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s --dump-input-on-failure // CHECK-LABEL: verifyFusedLocs func @verifyFusedLocs(%arg0 : i32) -> i32 { @@ -10,6 +10,21 @@ func @verifyFusedLocs(%arg0 : i32) -> i32 { return %result : i32 } +// CHECK-LABEL: verifyDesignatedLoc +func @verifyDesignatedLoc(%arg0 : i32) -> i32 { + %0 = "test.loc_src"(%arg0) : (i32) -> i32 loc("loc3") + %1 = "test.loc_src"(%0) : (i32) -> i32 loc("loc2") + %2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1") + + // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1") + // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused[ + // CHECK-SAME: "loc1" + // CHECK-SAME: "loc3" + // CHECK-SAME: "loc2" + // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused["loc2", "loc3"]) + return %1 : i32 +} + // CHECK-LABEL: verifyZeroResult func @verifyZeroResult(%arg0 : i32) { // CHECK: "test.op_i"(%arg0) : (i32) -> () diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index a484316..73a8525 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -109,9 +109,11 @@ private: // calling native C++ code. std::string handleReplaceWithNativeCodeCall(DagNode resultTree); - // Returns the C++ expression referencing the old value serving as the - // replacement. - std::string handleReplaceWithValue(DagNode tree); + // Returns the symbol of the old value serving as the replacement. + StringRef handleReplaceWithValue(DagNode tree); + + // Returns the symbol of the value whose location to use. + std::string handleUseLocationOf(DagNode tree); // Emits the C++ statement to build a new op out of the given DAG `tree` and // returns the variable name that this op is assigned to. If the root op in @@ -580,11 +582,11 @@ void PatternEmitter::emitRewriteLogic() { PrintFatalError(loc, error); } - os.indent(4) << "auto loc = rewriter.getFusedLoc({"; + os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({"; for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; } - os << "}); (void)loc;\n"; + os << "}); (void)odsLoc;\n"; // Process auxiliary result patterns. for (int i = 0; i < replStartIndex; ++i) { @@ -640,15 +642,19 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree, LLVM_DEBUG(resultTree.print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); + if (resultTree.isLocationDirective()) { + PrintFatalError(loc, + "location directive can only be used with op creation"); + } + if (resultTree.isNativeCodeCall()) { auto symbol = handleReplaceWithNativeCodeCall(resultTree); symbolInfoMap.bindValue(symbol); return symbol; } - if (resultTree.isReplaceWithValue()) { - return handleReplaceWithValue(resultTree); - } + if (resultTree.isReplaceWithValue()) + return handleReplaceWithValue(resultTree).str(); // Normal op creation. auto symbol = handleOpCreation(resultTree, resultIndex, depth); @@ -660,7 +666,7 @@ std::string PatternEmitter::handleResultPattern(DagNode resultTree, return symbol; } -std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { +StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { assert(tree.isReplaceWithValue()); if (tree.getNumArgs() != 1) { @@ -672,7 +678,30 @@ std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); } - return std::string(tree.getArgName(0)); + return tree.getArgName(0); +} + +std::string PatternEmitter::handleUseLocationOf(DagNode tree) { + assert(tree.isLocationDirective()); + auto lookUpArgLoc = [this, &tree](int idx) { + const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; + return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); + }; + + if (tree.getNumArgs() != 1) { + std::string ret; + llvm::raw_string_ostream os(ret); + os << "rewriter.getFusedLoc({"; + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) + os << (i ? ", " : "") << lookUpArgLoc(i); + os << "})"; + return os.str(); + } + + if (!tree.getSymbol().empty()) + PrintFatalError(loc, "cannot bind symbol to location"); + + return lookUpArgLoc(0); } std::string PatternEmitter::handleOpArgument(DagLeaf leaf, @@ -753,14 +782,28 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, Operator &resultOp = tree.getDialectOp(opMap); auto numOpArgs = resultOp.getNumArgs(); + auto numPatArgs = tree.getNumArgs(); + + // Get the location for this operation if explicitly provided. + std::string locToUse; + if (numPatArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) + if (lastArg.isLocationDirective()) + locToUse = handleUseLocationOf(lastArg); + } - if (numOpArgs != tree.getNumArgs()) { - PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - resultOp.getOperationName(), tree.getNumArgs(), - numOpArgs)); + auto inPattern = numPatArgs - !locToUse.empty(); + if (numOpArgs != inPattern) { + PrintFatalError(loc, + formatv("resultant op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + resultOp.getOperationName(), inPattern, numOpArgs)); } + // If no explicit location is given, use the default, all fused, location. + if (locToUse.empty()) + locToUse = "odsLoc"; + // A map to collect all nested DAG child nodes' names, with operand index as // the key. This includes both bound and unbound child nodes. ChildNodeIndexNameMap childNodeNames; @@ -769,9 +812,8 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // create ops for them and remember the symbol names for them, so that we can // use the results in the current node. This happens in a recursive manner. for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { - if (auto child = tree.getArgAsNestedDag(i)) { + if (auto child = tree.getArgAsNestedDag(i)) childNodeNames[i] = handleResultPattern(child, i, depth + 1); - } } // The name of the local variable holding this op. @@ -811,10 +853,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // First prepare local variables for op arguments used in builder call. createAggregateLocalVarsForOpArgs(tree, childNodeNames); + // Then create the op. os.indent(6) << formatv( - "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n", - valuePackName, resultOp.getQualCppClassName()); + "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n", + valuePackName, resultOp.getQualCppClassName(), locToUse); os.indent(4) << "}\n"; return resultValue; } @@ -831,8 +874,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // here given that it's easier for developers to write compared to // aggregate-parameter builders. createSeparateLocalVarsForOpArgs(tree, childNodeNames); - os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, - resultOp.getQualCppClassName()); + + os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, + resultOp.getQualCppClassName(), locToUse); supplyValuesForOpArgs(tree, childNodeNames); os << "\n );\n"; os.indent(4) << "}\n"; @@ -858,9 +902,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, "tblgen_types.push_back(v.getType()); }\n", resultIndex + i); } - os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, " + os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " "tblgen_values, tblgen_attrs);\n", - valuePackName, resultOp.getQualCppClassName()); + valuePackName, resultOp.getQualCppClassName(), + locToUse); os.indent(4) << "}\n"; return resultValue; } -- 2.7.4