pattern. This will signal to the pattern driver that recursive application of
this pattern may happen, and the pattern is equipped to safely handle it.
+### Debug Names and Labels
+
+To aid in debugging, patterns may specify: a debug name (via `setDebugName`),
+which should correspond to an identifier that uniquely identifies the specific
+pattern; and a set of debug labels (via `addDebugLabels`), which correspond to
+identifiers that uniquely identify groups of patterns. This information is used
+by various utilities to aid in the debugging of pattern rewrites, e.g. in debug
+logs, to provide pattern filtering, etc. A simple code example is shown below:
+
+```c++
+class MyPattern : public RewritePattern {
+public:
+ /// Inherit constructors from RewritePattern.
+ using RewritePattern::RewritePattern;
+
+ void initialize() {
+ setDebugName("MyPattern");
+ addDebugLabels("MyRewritePass");
+ }
+
+ // ...
+};
+
+void populateMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) {
+ // Debug labels may also be attached to patterns during insertion. This allows
+ // for easily attaching common labels to groups of patterns.
+ patterns.addWithLabel<MyPattern, ...>("MyRewritePatterns", ctx);
+}
+```
+
### Initialization
Several pieces of pattern state require explicit initialization by the pattern,
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize-canonicalize-operations) in MLIR.
+
+## Debugging
+
+### Pattern Filtering
+
+To simplify test case definition and reduction, the `FrozenRewritePatternSet`
+class provides built-in support for filtering which patterns should be provided
+to the pattern driver for application. Filtering behavior is specified by
+providing a `disabledPatterns` and `enabledPatterns` list when constructing the
+`FrozenRewritePatternSet`. The `disabledPatterns` list should contain a set of
+debug names or labels for patterns that are disabled during pattern application,
+i.e. which patterns should be filtered out. The `enabledPatterns` list should
+contain a set of debug names or labels for patterns that are enabled during
+pattern application, patterns that do not satisfy this constraint are filtered
+out. Note that patterns specified by the `disabledPatterns` list will be
+filtered out even if they match criteria in the `enabledPatterns` list. An
+example is shown below:
+
+```c++
+void MyPass::initialize(MLIRContext *context) {
+ // No patterns are explicitly disabled.
+ SmallVector<std::string> disabledPatterns;
+ // Enable only patterns with a debug name or label of `MyRewritePatterns`.
+ SmallVector<std::string> enabledPatterns(1, "MyRewritePatterns");
+
+ RewritePatternSet rewritePatterns(context);
+ // ...
+ frozenPatterns = FrozenRewritePatternSet(rewritePatterns, disabledPatterns,
+ enabledPatterns);
+}
+```
+
+### Common Pass Utilities
+
+Passes that utilize rewrite patterns should aim to provide a common set of
+options and toggles to simplify the debugging experience when switching between
+different passes/projects/etc. To aid in this endeavor, MLIR provides a common
+set of utilities that can be easily included when defining a custom pass. These
+are defined in `mlir/RewritePassUtil.td`; an example usage is shown below:
+
+```tablegen
+def MyRewritePass : Pass<"..."> {
+ let summary = "...";
+ let constructor = "createMyRewritePass()";
+
+ // Inherit the common pattern rewrite options from `RewritePassUtils`.
+ let options = RewritePassUtils.options;
+}
+```
+
+#### Rewrite Pass Options
+
+This section documents common pass options that are useful for controlling the
+behavior of rewrite pattern application.
+
+##### Pattern Filtering
+
+Two common pattern filtering options are exposed, `disable-patterns` and
+`enable-patterns`, matching the behavior of the `disabledPatterns` and
+`enabledPatterns` lists described in the [Pattern Filtering](#pattern-filtering)
+section above. A snippet of the tablegen definition of these options is shown
+below:
+
+```tablegen
+ListOption<"disabledPatterns", "disable-patterns", "std::string",
+ "Labels of patterns that should be filtered out during application",
+ "llvm::cl::MiscFlags::CommaSeparated">,
+ListOption<"enabledPatterns", "enable-patterns", "std::string",
+ "Labels of patterns that should be used during application, all "
+ "other patterns are filtered out",
+ "llvm::cl::MiscFlags::CommaSeparated">,
+```
+
+These options may be used to provide filtering behavior when constructing any
+`FrozenRewritePatternSet`s within the pass:
+
+```c++
+void MyRewritePass::initialize(MLIRContext *context) {
+ RewritePatternSet rewritePatterns(context);
+ // ...
+
+ // When constructing the `FrozenRewritePatternSet`, we provide the filter
+ // list options.
+ frozenPatterns = FrozenRewritePatternSet(rewritePatterns, disabledPatterns,
+ enabledPatterns);
+}
+```
return contextAndHasBoundedRecursion.getPointer();
}
- /// Return readable pattern name. Should only be used for debugging purposes.
- /// Can be empty.
+ /// Return a readable name for this pattern. This name should only be used for
+ /// debugging purposes, and may be empty.
StringRef getDebugName() const { return debugName; }
- /// Set readable pattern name. Should only be used for debugging purposes.
+ /// Set the human readable debug name used for this pattern. This name will
+ /// only be used for debugging purposes.
void setDebugName(StringRef name) { debugName = name; }
+ /// Return the set of debug labels attached to this pattern.
+ ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
+
+ /// Add the provided debug labels to this pattern.
+ void addDebugLabels(ArrayRef<StringRef> labels) {
+ debugLabels.append(labels.begin(), labels.end());
+ }
+ void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
+
protected:
/// This class acts as a special tag that makes the desire to match "any"
/// operation type explicit. This helps to avoid unnecessary usages of this
/// an op with this pattern.
SmallVector<OperationName, 2> generatedOps;
- /// Readable pattern name. Can be empty.
+ /// A readable name for this pattern. May be empty.
StringRef debugName;
+
+ /// The set of debug labels attached to this pattern.
+ SmallVector<StringRef, 0> debugLabels;
};
//===----------------------------------------------------------------------===//
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{0, (addImpl<Ts>(arg, args...), 0)...};
+ (void)std::initializer_list<int>{
+ 0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
+ return *this;
+ }
+ /// An overload of the above `add` method that allows for attaching a set
+ /// of debug labels to the attached patterns. This is useful for labeling
+ /// groups of patterns that may be shared between multiple different
+ /// passes/users.
+ template <typename... Ts, typename ConstructorArg,
+ typename... ConstructorArgs,
+ typename = std::enable_if_t<sizeof...(Ts) != 0>>
+ RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels,
+ ConstructorArg &&arg,
+ ConstructorArgs &&... args) {
+ // The following expands a call to emplace_back for each of the pattern
+ // types 'Ts'. This magic is necessary due to a limitation in the places
+ // that a parameter pack can be expanded in c++11.
+ // FIXME: In c++17 this can be simplified by using 'fold expressions'.
+ (void)std::initializer_list<int>{
+ 0, (addImpl<Ts>(debugLabels, arg, args...), 0)...};
return *this;
}
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{0, (addImpl<Ts>(arg, args...), 0)...};
+ (void)std::initializer_list<int>{
+ 0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
return *this;
}
/// chaining insertions.
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
- addImpl(Args &&... args) {
- nativePatterns.emplace_back(
- RewritePattern::create<T>(std::forward<Args>(args)...));
+ addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
+ std::unique_ptr<T> pattern =
+ RewritePattern::create<T>(std::forward<Args>(args)...);
+ pattern->addDebugLabels(debugLabels);
+ nativePatterns.emplace_back(std::move(pattern));
}
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
- addImpl(Args &&... args) {
+ addImpl(ArrayRef<StringRef> debugLabels, Args &&... args) {
+ // TODO: Add the provided labels to the PDL pattern when PDL supports
+ // labels.
pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
}
return *this;
}
- MutableArrayRef<DataType> operator->() const { return &*this; }
+ /// Allow accessing the data held by this option.
+ MutableArrayRef<DataType> operator*() {
+ return static_cast<std::vector<DataType> &>(*this);
+ }
+ ArrayRef<DataType> operator*() const {
+ return static_cast<const std::vector<DataType> &>(*this);
+ }
private:
/// Return the main option instance.
/// Print the name and value of this option to the given stream.
void print(raw_ostream &os) final {
+ // Don't print the list if empty. An empty option value can be treated as
+ // an element of the list in certain cases (e.g. ListOption<std::string>).
+ if ((**this).empty())
+ return;
+
os << this->ArgStr << '=';
auto printElementFn = [&](const DataType &value) {
printValue(os, this->getParser(), value);
using OpSpecificNativePatternListT =
DenseMap<OperationName, std::vector<RewritePattern *>>;
- /// Freeze the patterns held in `patterns`, and take ownership.
FrozenRewritePatternSet();
- FrozenRewritePatternSet(RewritePatternSet &&patterns);
FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default;
FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default;
FrozenRewritePatternSet &
operator=(FrozenRewritePatternSet &&patterns) = default;
~FrozenRewritePatternSet();
+ /// Freeze the patterns held in `patterns`, and take ownership.
+ /// `disabledPatternLabels` is a set of labels used to filter out input
+ /// patterns with a label in this set. `enabledPatternLabels` is a set of
+ /// labels used to filter out input patterns that do not have one of the
+ /// lables in this set.
+ FrozenRewritePatternSet(
+ RewritePatternSet &&patterns,
+ ArrayRef<std::string> disabledPatternLabels = llvm::None,
+ ArrayRef<std::string> enabledPatternLabels = llvm::None);
+
/// Return the op specific native patterns held by this list.
const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const {
return impl->nativeOpSpecificPatternMap;
--- /dev/null
+//===-- PassUtil.td - Utilities for rewrite passes ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains several utilities for passes that utilize rewrite
+// patterns.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_REWRITE_PASSUTIL_TD_
+#define MLIR_REWRITE_PASSUTIL_TD_
+
+include "mlir/Pass/PassBase.td"
+
+def RewritePassUtils {
+ // A set of options commonly options used for pattern rewrites.
+ list<Option> options = [
+ // These two options provide filtering for which patterns are applied. These
+ // should be passed directly to the FrozenRewritePatternSet when it is
+ // created.
+ ListOption<"disabledPatterns", "disable-patterns", "std::string",
+ "Labels of patterns that should be filtered out during"
+ " application",
+ "llvm::cl::MiscFlags::CommaSeparated">,
+ ListOption<"enabledPatterns", "enable-patterns", "std::string",
+ "Labels of patterns that should be used during"
+ " application, all other patterns are filtered out",
+ "llvm::cl::MiscFlags::CommaSeparated">,
+ ];
+}
+
+#endif // MLIR_REWRITE_PASSUTIL_TD_
#define MLIR_TRANSFORMS_PASSES
include "mlir/Pass/PassBase.td"
+include "mlir/Rewrite/PassUtil.td"
def AffineLoopFusion : FunctionPass<"affine-loop-fusion"> {
let summary = "Fuse affine loop nests";
Option<"maxIterations", "max-iterations", "unsigned",
/*default=*/"10",
"Seed the worklist in general top-down order">
- ];
+ ] # RewritePassUtils.options;
}
def CSE : Pass<"cse"> {
FrozenRewritePatternSet::FrozenRewritePatternSet()
: impl(std::make_shared<Impl>()) {}
-FrozenRewritePatternSet::FrozenRewritePatternSet(RewritePatternSet &&patterns)
+FrozenRewritePatternSet::FrozenRewritePatternSet(
+ RewritePatternSet &&patterns, ArrayRef<std::string> disabledPatternLabels,
+ ArrayRef<std::string> enabledPatternLabels)
: impl(std::make_shared<Impl>()) {
+ DenseSet<StringRef> disabledPatterns, enabledPatterns;
+ disabledPatterns.insert(disabledPatternLabels.begin(),
+ disabledPatternLabels.end());
+ enabledPatterns.insert(enabledPatternLabels.begin(),
+ enabledPatternLabels.end());
+
// Functor used to walk all of the operations registered in the context. This
// is useful for patterns that get applied to multiple operations, such as
// interface and trait based patterns.
};
for (std::unique_ptr<RewritePattern> &pat : patterns.getNativePatterns()) {
+ // Don't add patterns that haven't been enabled by the user.
+ if (!enabledPatterns.empty()) {
+ auto isEnabledFn = [&](StringRef label) {
+ return enabledPatterns.count(label);
+ };
+ if (!isEnabledFn(pat->getDebugName()) &&
+ llvm::none_of(pat->getDebugLabels(), isEnabledFn))
+ continue;
+ }
+ // Don't add patterns that have been disabled by the user.
+ if (!disabledPatterns.empty()) {
+ auto isDisabledFn = [&](StringRef label) {
+ return disabledPatterns.count(label);
+ };
+ if (isDisabledFn(pat->getDebugName()) ||
+ llvm::any_of(pat->getDebugLabels(), isDisabledFn))
+ continue;
+ }
+
if (Optional<OperationName> rootName = pat->getRootKind()) {
impl->nativeOpSpecificPatternMap[*rootName].push_back(pat.get());
impl->nativeOpSpecificPatternList.push_back(std::move(pat));
dialect->getCanonicalizationPatterns(owningPatterns);
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
- patterns = std::move(owningPatterns);
+
+ patterns = FrozenRewritePatternSet(std::move(owningPatterns),
+ disabledPatterns, enabledPatterns);
return success();
}
void runOnOperation() override {
// CHECK_1: test-options-pass{list=1,2,3,4,5 string=some_value string-list=a,b,c,d}
// CHECK_2: test-options-pass{list=1 string= string-list=a,b}
-// CHECK_3: module(func(test-options-pass{list=3 string= string-list=}), func(test-options-pass{list=1,2,3,4 string= string-list=}))
+// CHECK_3: module(func(test-options-pass{list=3 string= }), func(test-options-pass{list=1,2,3,4 string= }))
--- /dev/null
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s --check-prefix=NO_FILTER
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize{enable-patterns=TestRemoveOpWithInnerOps})' | FileCheck %s --check-prefix=FILTER_ENABLE
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize{disable-patterns=TestRemoveOpWithInnerOps})' | FileCheck %s --check-prefix=FILTER_DISABLE
+
+// NO_FILTER-LABEL: func @remove_op_with_inner_ops_pattern
+// NO_FILTER-NEXT: return
+// FILTER_ENABLE-LABEL: func @remove_op_with_inner_ops_pattern
+// FILTER_ENABLE-NEXT: return
+// FILTER_DISABLE-LABEL: func @remove_op_with_inner_ops_pattern
+// FILTER_DISABLE-NEXT: "test.op_with_region_pattern"()
+func @remove_op_with_inner_ops_pattern() {
+ "test.op_with_region_pattern"() ({
+ "test.op_with_region_terminator"() : () -> ()
+ }) : () -> ()
+ return
+}
: public OpRewritePattern<TestOpWithRegionPattern> {
using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
+ void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
+
LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
PatternRewriter &rewriter) const override {
rewriter.eraseOp(op);