This CL adds support for building matchers recursively.
The following matchers are provided:
1. `m_any()` can match any value
2. `m_val(Value *)` binds to a value and must match it
3. `RecursivePatternMatcher<OpType, Matchers...>` n-arity pattern that matches `OpType` and whose operands must be matched exactly by `Matchers...`.
This allows building expression templates for patterns, declaratively, in a very natural fashion.
For example pattern `p9` defined as follows:
```
auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
auto mul_of_anyadd = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
mul_of_muladd, m_Op<MulFOp>()),
m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
```
Successfully matches `%6` in:
```
%0 = addf %a, %b: f32
%1 = addf %a, %c: f32 // matched
%2 = addf %c, %b: f32
%3 = mulf %a, %2: f32 // matched
%4 = mulf %3, %1: f32 // matched
%5 = mulf %4, %4: f32 // matched
%6 = mulf %5, %5: f32 // matched
```
Note that 0-ary matchers can be used as leaves in place of n-ary matchers. This alleviates from passing explicit `m_any()` leaves.
In the future, we may add extra patterns to specify that operands may be matched in any order.
PiperOrigin-RevId:
284469446
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
-#include <type_traits>
namespace mlir {
bool match(Operation *op) { return isa<OpClass>(op); }
};
-} // end namespace detail
+/// Trait to check whether T provides a 'match' method with type
+/// `OperationOrValue`.
+template <typename T, typename OperationOrValue>
+using has_operation_or_value_matcher_t =
+ decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
-/// Entry point for matching a pattern over a Value.
-template <typename Pattern>
-inline bool matchPattern(Value *value, const Pattern &pattern) {
- // TODO: handle other cases
- if (auto *op = value->getDefiningOp())
- return const_cast<Pattern &>(pattern).match(op);
+/// Statically switch to a Value matcher.
+template <typename MatcherClass>
+typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
+ MatcherClass, Value *>::value,
+ bool>
+matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
+ return matcher.match(op->getOperand(idx));
+}
+
+/// Statically switch to an Operation matcher.
+template <typename MatcherClass>
+typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
+ MatcherClass, Operation *>::value,
+ bool>
+matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
+ if (auto defOp = op->getOperand(idx)->getDefiningOp())
+ return matcher.match(defOp);
return false;
}
-/// Entry point for matching a pattern over an Operation.
-template <typename Pattern>
-inline bool matchPattern(Operation *op, const Pattern &pattern) {
- return const_cast<Pattern &>(pattern).match(op);
+/// Terminal matcher, always returns true.
+struct AnyValueMatcher {
+ bool match(Value *op) const { return true; }
+};
+
+/// Binds to a specific value and matches it.
+struct PatternMatcherValue {
+ PatternMatcherValue(Value *val) : value(val) {}
+ bool match(Value *val) const { return val == value; }
+ Value *value;
+};
+
+template <typename TupleT, class CallbackT, std::size_t... Is>
+constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
+ std::index_sequence<Is...>) {
+ (void)std::initializer_list<int>{
+ 0,
+ (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)),
+ 0)...};
}
-/// Matches a constant holding a scalar/vector/tensor integer (splat) and
-/// writes the integer value to bind_value.
-inline detail::constant_int_op_binder
-m_ConstantInt(IntegerAttr::ValueType *bind_value) {
- return detail::constant_int_op_binder(bind_value);
+template <typename... Tys, typename CallbackT>
+constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
+ detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
+ std::make_index_sequence<sizeof...(Tys)>{});
}
+/// RecursivePatternMatcher that composes.
+template <typename OpType, typename... OperandMatchers>
+struct RecursivePatternMatcher {
+ RecursivePatternMatcher(OperandMatchers... matchers)
+ : operandMatchers(matchers...) {}
+ bool match(Operation *op) {
+ if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
+ return false;
+ bool res = true;
+ enumerate(operandMatchers, [&](size_t index, auto &matcher) {
+ res &= matchOperandOrValueAtIndex(op, index, matcher);
+ });
+ return res;
+ }
+ std::tuple<OperandMatchers...> operandMatchers;
+};
+
+} // end namespace detail
+
/// Matches a value from a constant foldable operation and writes the value to
/// bind_value.
template <typename AttrT>
return detail::constant_int_not_value_matcher<0>();
}
+/// Entry point for matching a pattern over a Value.
+template <typename Pattern>
+inline bool matchPattern(Value *value, const Pattern &pattern) {
+ // TODO: handle other cases
+ if (auto *op = value->getDefiningOp())
+ return const_cast<Pattern &>(pattern).match(op);
+ return false;
+}
+
+/// Entry point for matching a pattern over an Operation.
+template <typename Pattern>
+inline bool matchPattern(Operation *op, const Pattern &pattern) {
+ return const_cast<Pattern &>(pattern).match(op);
+}
+
+/// Matches a constant holding a scalar/vector/tensor integer (splat) and
+/// writes the integer value to bind_value.
+inline detail::constant_int_op_binder
+m_ConstantInt(IntegerAttr::ValueType *bind_value) {
+ return detail::constant_int_op_binder(bind_value);
+}
+
+template <typename OpType, typename... Matchers>
+auto m_Op(Matchers... matchers) {
+ return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
+}
+
+namespace matchers {
+inline auto m_any() { return detail::AnyValueMatcher(); }
+inline auto m_val(Value *v) { return detail::PatternMatcherValue(v); }
+} // namespace matchers
+
} // end namespace mlir
#endif // MLIR_MATCHERS_H
--- /dev/null
+// RUN: mlir-opt %s -disable-pass-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s
+
+func @test1(%a: f32, %b: f32, %c: f32) {
+ %0 = addf %a, %b: f32
+ %1 = addf %a, %c: f32
+ %2 = addf %c, %b: f32
+ %3 = mulf %a, %2: f32
+ %4 = mulf %3, %1: f32
+ %5 = mulf %4, %4: f32
+ %6 = mulf %5, %5: f32
+ return
+}
+
+// CHECK-LABEL: test1
+// CHECK: Pattern add(*) matched 3 times
+// CHECK: Pattern mul(*) matched 4 times
+// CHECK: Pattern add(add(*), *) matched 0 times
+// CHECK: Pattern add(*, add(*)) matched 0 times
+// CHECK: Pattern mul(add(*), *) matched 0 times
+// CHECK: Pattern mul(*, add(*)) matched 2 times
+// CHECK: Pattern mul(mul(*), *) matched 3 times
+// CHECK: Pattern mul(mul(*), mul(*)) matched 2 times
+// CHECK: Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched 1 times
+// CHECK: Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, add(*)), mul(*, add(*)))) matched 1 times
+// CHECK: Pattern add(a, b) matched 1 times
+// CHECK: Pattern add(a, c) matched 1 times
+// CHECK: Pattern add(b, a) matched 0 times
+// CHECK: Pattern add(c, a) matched 0 times
+// CHECK: Pattern mul(a, add(c, b)) matched 1 times
+// CHECK: Pattern mul(a, add(b, c)) matched 0 times
+// CHECK: Pattern mul(mul(a, *), add(a, c)) matched 1 times
+// CHECK: Pattern mul(mul(a, *), add(c, b)) matched 0 times
add_llvm_library(MLIRTestIR
TestFunc.cpp
+ TestMatchers.cpp
TestSymbolUses.cpp
ADDITIONAL_HEADER_DIRS
--- /dev/null
+//===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a test pass for verifying matchers.
+struct TestMatchers : public ModulePass<TestMatchers> {
+ void runOnModule() override;
+};
+} // end anonymous namespace
+
+// This could be done better but is not worth the variadic template trouble.
+template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) {
+ unsigned count = 0;
+ f.walk([&count, &matcher](Operation *op) {
+ if (matcher.match(op)) {
+ // llvm::outs() << "matched " << *op << "\n";
+ ++count;
+ }
+ });
+ return count;
+}
+
+static void test1(FuncOp f) {
+ using mlir::matchers::m_any;
+ using mlir::matchers::m_val;
+
+ assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
+ auto a = m_val(f.getArgument(0));
+ auto b = m_val(f.getArgument(1));
+ auto c = m_val(f.getArgument(2));
+ (void)a;
+ (void)b;
+ (void)c;
+
+ llvm::outs() << f.getName();
+
+ auto p0 = m_Op<AddFOp>(); // using 0-arity matcher
+ llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
+ << " times\n";
+
+ auto p1 = m_Op<MulFOp>(); // using 0-arity matcher
+ llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
+ << " times\n";
+
+ auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_any());
+ llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
+ << " times\n";
+
+ auto p3 = m_Op<AddFOp>(m_any(), m_Op<AddFOp>());
+ llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
+ << " times\n";
+
+ auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_any());
+ llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
+ << " times\n";
+
+ auto p5 = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
+ llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
+ << " times\n";
+
+ auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_any());
+ llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
+ << " times\n";
+
+ auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
+ llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
+ << " times\n";
+
+ auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
+ auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul);
+ llvm::outs()
+ << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
+ << countMatches(f, p8) << " times\n";
+
+ // clang-format off
+ auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
+ auto mul_of_anyadd = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
+ auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
+ mul_of_muladd, m_Op<MulFOp>()),
+ m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
+ // clang-format on
+ llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
+ "add(*)), mul(*, add(*)))) matched "
+ << countMatches(f, p9) << " times\n";
+
+ auto p10 = m_Op<AddFOp>(a, b);
+ llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
+ << " times\n";
+
+ auto p11 = m_Op<AddFOp>(a, c);
+ llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
+ << " times\n";
+
+ auto p12 = m_Op<AddFOp>(b, a);
+ llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
+ << " times\n";
+
+ auto p13 = m_Op<AddFOp>(c, a);
+ llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
+ << " times\n";
+
+ auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b));
+ llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
+ << " times\n";
+
+ auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c));
+ llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
+ << " times\n";
+
+ auto mul_of_aany = m_Op<MulFOp>(a, m_any());
+ auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c));
+ llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
+ << countMatches(f, p16) << " times\n";
+
+ auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b));
+ llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
+ << countMatches(f, p17) << " times\n";
+}
+
+void TestMatchers::runOnModule() {
+ auto m = getModule();
+ for (auto f : m.getOps<FuncOp>()) {
+ if (f.getName() == "test1")
+ test1(f);
+ }
+}
+
+static PassRegistration<TestMatchers> pass("test-matchers",
+ "Test C++ pattern matchers.");