From 5bc9cc1332aa042b68fb5efa9fb50eaaf2d54f79 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 28 Jun 2021 13:42:16 -0700 Subject: [PATCH] [drr] Enable specifying range in NativeCodeCall replacement. This enables creating a replacement rule where range of positional replacements need not be spelled out, or are not known (e.g., enable having a rewrite that forward all operands to a call generically). Differential Revision: https://reviews.llvm.org/D104955 --- mlir/docs/DeclarativeRewrites.md | 94 ++++++++++++++++-------------- mlir/include/mlir/TableGen/Format.h | 25 ++++++-- mlir/lib/TableGen/Format.cpp | 25 +++++++- mlir/test/mlir-tblgen/rewriter-indexing.td | 5 ++ 4 files changed, 98 insertions(+), 51 deletions(-) diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md index 27ae161..5815035 100644 --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -11,8 +11,8 @@ compiler build time. This manual explains in detail all of the available mechanisms for defining rewrite rules in such a declarative manner. It aims to be a specification instead of a tutorial. Please refer to -[Quickstart tutorial to adding MLIR graph -rewrite](Tutorials/QuickstartRewrites.md) for the latter. +[Quickstart tutorial to adding MLIR graph rewrite](Tutorials/QuickstartRewrites.md) +for the latter. Given that declarative rewrite rules depend on op definition specification, this manual assumes knowledge of the [ODS](OpDefinitions.md) doc. @@ -51,8 +51,8 @@ features: * Matching multi-result ops in nested patterns. * Matching and generating variadic operand/result ops in nested patterns. * Packing and unpacking variadic operands/results during generation. -* [`NativeCodeCall`](#nativecodecall-transforming-the-generated-op) - returning more than one results. +* [`NativeCodeCall`](#nativecodecall-transforming-the-generated-op) returning + more than one results. 
## Rule Definition @@ -93,9 +93,9 @@ Each pattern is specified as a TableGen `dag` object with the syntax of [directives](#rewrite-directives). `argN` is for matching (if used in source pattern) or generating (if used in result pattern) the `N`-th argument for `operator`. If the `operator` is some MLIR operation, it means the `N`-th -argument as specified in the `arguments` list of the op's definition. -Therefore, we say op argument specification in pattern is **position-based**: -the position where they appear matters. +argument as specified in the `arguments` list of the op's definition. Therefore, +we say op argument specification in pattern is **position-based**: the position +where they appear matters. `argN` can be a `dag` object itself, thus we can have nested `dag` tree to model the def-use relationship between ops. @@ -245,15 +245,15 @@ the pattern by following the exact same order as the ODS `arguments` definition. Otherwise, a custom `build()` method that matches the argument list is required. Right now all ODS-generated `build()` methods require specifying the result -type(s), unless the op has known traits like `SameOperandsAndResultType` that -we can use to auto-generate a `build()` method with result type deduction. -When generating an op to replace the result of the matched root op, we can use -the matched root op's result type when calling the ODS-generated builder. -Otherwise (e.g., generating an [auxiliary op](#supporting-auxiliary-ops) or -generating an op with a nested result pattern), DRR will not be able to deduce -the result type(s). The pattern author will need to define a custom builder -that has result type deduction ability via `OpBuilder` in ODS. For example, -in the following pattern +type(s), unless the op has known traits like `SameOperandsAndResultType` that we +can use to auto-generate a `build()` method with result type deduction. 
When +generating an op to replace the result of the matched root op, we can use the +matched root op's result type when calling the ODS-generated builder. Otherwise +(e.g., generating an [auxiliary op](#supporting-auxiliary-ops) or generating an +op with a nested result pattern), DRR will not be able to deduce the result +type(s). The pattern author will need to define a custom builder that has result +type deduction ability via `OpBuilder` in ODS. For example, in the following +pattern ```tablegen def : Pat<(AOp $input, $attr), (COp (AOp $input, $attr) $attr)>; @@ -295,8 +295,8 @@ to replace the matched `AOp`. In the result pattern, we can bind to the result(s) of a newly built op by attaching symbols to the op. (But we **cannot** bind to op arguments given that -they are referencing previously bound symbols.) This is useful for reusing -newly created results where suitable. For example, +they are referencing previously bound symbols.) This is useful for reusing newly +created results where suitable. For example, ```tablegen def DOp : Op<"d_op"> { @@ -373,18 +373,18 @@ And make sure the generated C++ code from the above pattern has access to the definition of the C++ helper function. In the above example, we are using a string to specialize the `NativeCodeCall` -template. The string can be an arbitrary C++ expression that evaluates into -some C++ object expected at the `NativeCodeCall` site (here it would be -expecting an array attribute). Typically the string should be a function call. +template. The string can be an arbitrary C++ expression that evaluates into some +C++ object expected at the `NativeCodeCall` site (here it would be expecting an +array attribute). Typically the string should be a function call. Note that currently `NativeCodeCall` must return no more than one value or attribute. This might change in the future. ##### `NativeCodeCall` placeholders -In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N`. 
The former -is called _special placeholder_, while the latter is called _positional -placeholder_. +In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`. +The first is called a _special placeholder_, while the latter two are called a +_positional placeholder_ and a _positional range placeholder_, respectively. `NativeCodeCall` right now only supports three special placeholders: `$_builder`, `$_loc`, and `$_self`: @@ -423,6 +423,11 @@ the `NativeCodeCall` use site. For example, if we define `SomeCall : NativeCodeCall<"someFn($1, $2, $0)">` and use it like `(SomeCall $in0, $in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2, $in0)`. +Positional range placeholders will be substituted by multiple `dag` object +parameters at the `NativeCodeCall` use site. For example, if we define +`SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0, +$in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2)`. + ##### Customizing entire op building `NativeCodeCall` is not only limited to transforming arguments for building an @@ -490,8 +495,8 @@ matched op. Multi-result ops bring extra complexity to declarative rewrite rules. We use TableGen `dag` objects to represent ops in patterns; there is no native way to -indicate that an op generates multiple results. The approach adopted is based -on **naming convention**: a `__N` suffix is added to a symbol to indicate the +indicate that an op generates multiple results. The approach adopted is based on +**naming convention**: a `__N` suffix is added to a symbol to indicate the `N`-th result. #### `__N` suffix @@ -541,12 +546,12 @@ The above example also shows how to replace a matched multi-result op. To replace an `N`-result op, the result patterns must generate at least `N` declared values (see [Declared vs. actual value](#declared-vs-actual-value) for -definition).
If there are more than `N` declared values generated, only the -last `N` declared values will be used to replace the matched op. Note that -because of the existence of multi-result op, one result pattern **may** generate -multiple declared values. So it means we do not necessarily need `N` result -patterns to replace an `N`-result op. For example, to replace an op with three -results, you can have +definition). If there are more than `N` declared values generated, only the last +`N` declared values will be used to replace the matched op. Note that because of +the existence of multi-result op, one result pattern **may** generate multiple +declared values. So it means we do not necessarily need `N` result patterns to +replace an `N`-result op. For example, to replace an op with three results, you +can have ```tablegen // ThreeResultOp/TwoResultOp/OneResultOp generates three/two/one result(s), @@ -590,8 +595,8 @@ regarding an op's values. * _Actual operand/result/value_: an operand/result/value of an op instance at runtime -The above terms are needed because ops can have multiple results, and some of the -results can also be variadic. For example, +The above terms are needed because ops can have multiple results, and some of +the results can also be variadic. For example, ```tablegen def MultiVariadicOp : Op<"multi_variadic_op"> { @@ -611,8 +616,8 @@ def MultiVariadicOp : Op<"multi_variadic_op"> { We say the above op has 3 declared operands and 3 declared results. But at runtime, an instance can have 3 values corresponding to `$input2` and 2 values -correspond to `$output2`; we say it has 5 actual operands and 4 actual -results. A variadic operand/result is a considered as a declared value that can +correspond to `$output2`; we say it has 5 actual operands and 4 actual results. +A variadic operand/result is considered as a declared value that can correspond to multiple actual values.
[TODO] @@ -651,10 +656,10 @@ You can ### Adjusting benefits -The benefit of a `Pattern` is an integer value indicating the benefit of matching -the pattern. It determines the priorities of patterns inside the pattern rewrite -driver. A pattern with a higher benefit is applied before one with a lower -benefit. +The benefit of a `Pattern` is an integer value indicating the benefit of +matching the pattern. It determines the priorities of patterns inside the +pattern rewrite driver. A pattern with a higher benefit is applied before one +with a lower benefit. In DRR, a rule is set to have a benefit of the number of ops in the source pattern. This is based on the heuristics and assumptions that: @@ -662,7 +667,6 @@ pattern. This is based on the heuristics and assumptions that: * Larger matches are more beneficial than smaller ones. * If a smaller one is applied first the larger one may not apply anymore. - The fourth parameter to `Pattern` (and `Pat`) allows to manually tweak a pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value. @@ -696,8 +700,8 @@ def : Pat<(LocSrc1Op:$src1 (LocSrc2Op:$src2 ...), (LocDst1Op (LocDst2Op ..., (location $src2)), (location "outer"))>; ``` -In the above pattern, the generated `LocDst2Op` will use the matched location -of `LocSrc2Op` while the root `LocDst1Op` node will used the named location +In the above pattern, the generated `LocDst2Op` will use the matched location of +`LocSrc2Op` while the root `LocDst1Op` node will use the named location `outer`. ### `replaceWithValue` @@ -724,8 +728,8 @@ The above pattern removes the `Foo` and replaces all uses of `Foo` with ### Run `mlir-tblgen` to see the generated content -TableGen syntax sometimes can be obscure; reading the generated content can be -a very helpful way to understand and debug issues. To build `mlir-tblgen`, run +TableGen syntax sometimes can be obscure; reading the generated content can be a +very helpful way to understand and debug issues.
To build `mlir-tblgen`, run `cmake --build . --target mlir-tblgen` in your build directory and find the `mlir-tblgen` binary in the `bin/` subdirectory. All the supported generators can be found via `mlir-tblgen --help`. diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h index 441e05c..3120f6e 100644 --- a/mlir/include/mlir/TableGen/Format.h +++ b/mlir/include/mlir/TableGen/Format.h @@ -88,22 +88,33 @@ private: /// Struct representing a replacement segment for the formatted string. It can /// be a segment of the formatting template (for `Literal`) or a replacement -/// parameter (for `PositionalPH` and `SpecialPH`). +/// parameter (for `PositionalPH`, `PositionalRangePH` and `SpecialPH`). struct FmtReplacement { - enum class Type { Empty, Literal, PositionalPH, SpecialPH }; + enum class Type { + Empty, + Literal, + PositionalPH, + PositionalRangePH, + SpecialPH + }; FmtReplacement() = default; explicit FmtReplacement(StringRef literal) : type(Type::Literal), spec(literal) {} FmtReplacement(StringRef spec, size_t index) : type(Type::PositionalPH), spec(spec), index(index) {} + FmtReplacement(StringRef spec, size_t index, size_t end) + : type(Type::PositionalRangePH), spec(spec), index(index), end(end) {} FmtReplacement(StringRef spec, FmtContext::PHKind placeholder) : type(Type::SpecialPH), spec(spec), placeholder(placeholder) {} Type type = Type::Empty; StringRef spec; size_t index = 0; + size_t end = kUnset; FmtContext::PHKind placeholder = FmtContext::PHKind::None; + + static constexpr size_t kUnset = -1; }; class FmtObjectBase { @@ -121,7 +132,7 @@ protected: // std::vector. struct CreateAdapters { template - std::vector operator()(Ts &... items) { + std::vector operator()(Ts &...items) { return std::vector{&items...}; } }; @@ -205,7 +216,8 @@ private: /// /// There are two categories of placeholders accepted, both led by a '$' sign: /// -/// 1. 
Positional placeholder: $[0-9]+ +/// 1.a Positional placeholder: $[0-9]+ +/// 1.b Positional range placeholder: $[0-9]+... /// 2. Special placeholder: $[a-zA-Z_][a-zA-Z0-9_]* /// /// Replacement parameters for positional placeholders are supplied as the @@ -214,6 +226,9 @@ private: /// can use the positional placeholders in any order and repeat any times, for /// example, "$2 $1 $1 $0" is accepted. /// +/// Replace parameters for positional range placeholders are supplied as if +/// positional placeholders were specified with commas separating them. +/// /// Replacement parameters for special placeholders are supplied using the `ctx` /// format context. /// @@ -237,7 +252,7 @@ private: /// 2. This utility does not support format layout because it is rarely needed /// in C++ code generation. template -inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&... vals) +inline auto tgfmt(StringRef fmt, const FmtContext *ctx, Ts &&...vals) -> FmtObject(vals))...))> { using ParamTuple = decltype(std::make_tuple( diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp index 1083451..4a0bbdf 100644 --- a/mlir/lib/TableGen/Format.cpp +++ b/mlir/lib/TableGen/Format.cpp @@ -97,7 +97,8 @@ FmtObjectBase::splitFmtSegment(StringRef fmt) { // First try to see if it's a positional placeholder, and then handle special // placeholders. - size_t end = fmt.find_if_not([](char c) { return std::isdigit(c); }, 1); + size_t end = + fmt.find_if_not([](char c) { return std::isdigit(c); }, /*From=*/1); if (end != 1) { // We have a positional placeholder. Parse the index. size_t index = 0; @@ -105,6 +106,14 @@ FmtObjectBase::splitFmtSegment(StringRef fmt) { llvm_unreachable("invalid replacement sequence index"); } + // Check if this is the part of a range specification. + if (fmt.substr(end, 3) == "...") { + // Currently only ranges without upper bound are supported. 
+ return { + FmtReplacement{fmt.substr(0, end + 3), index, FmtReplacement::kUnset}, + fmt.substr(end + 3)}; + } + if (end == StringRef::npos) { // All the remaining characters are part of the positional placeholder. return {FmtReplacement{fmt, index}, StringRef()}; @@ -164,6 +173,20 @@ void FmtObjectBase::format(raw_ostream &s) const { continue; } + if (repl.type == FmtReplacement::Type::PositionalRangePH) { + if (repl.index >= adapters.size()) { + s << repl.spec << kMarkerForNoSubst; + continue; + } + auto range = llvm::makeArrayRef(adapters); + range = range.drop_front(repl.index); + if (repl.end != FmtReplacement::kUnset) + range = range.drop_back(adapters.size() - repl.end); + llvm::interleaveComma(range, s, + [&](auto &x) { x->format(s, /*Options=*/""); }); + continue; + } + assert(repl.type == FmtReplacement::Type::PositionalPH); if (repl.index >= adapters.size()) { diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td index cbdeff9c..f4f055e 100644 --- a/mlir/test/mlir-tblgen/rewriter-indexing.td +++ b/mlir/test/mlir-tblgen/rewriter-indexing.td @@ -85,3 +85,8 @@ def NativeBuilder : // CHECK: nativeCall(rewriter, odsLoc, (*v1.begin()), (*v2.begin()), (*v3.begin()), (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin())) def test4 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10), (NativeBuilder $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>; + +// CHECK: struct test5 : public ::mlir::RewritePattern { +// CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin())) +def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10), + (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>; -- 2.7.4