Add a layer of recursive matchers that compose.
authorNicolas Vasilache <ntv@google.com>
Mon, 9 Dec 2019 02:09:07 +0000 (18:09 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Dec 2019 02:09:40 +0000 (18:09 -0800)
This CL adds support for building matchers recursively.
The following matchers are provided:

1. `m_any()` can match any value
2. `m_val(Value *)` binds to a value and must match it
3. `RecursivePatternMatcher<OpType, Matchers...>` n-arity pattern that matches `OpType` and whose operands must be matched exactly by `Matchers...`.

This allows building expression templates for patterns, declaratively, in a very natural fashion.
For example pattern `p9` defined as follows:
```
  auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
  auto mul_of_anyadd = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
  auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
                     mul_of_muladd, m_Op<MulFOp>()),
                   m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
```

Successfully matches `%6` in:
```
  %0 = addf %a, %b: f32
  %1 = addf %a, %c: f32 // matched
  %2 = addf %c, %b: f32
  %3 = mulf %a, %2: f32 // matched
  %4 = mulf %3, %1: f32 // matched
  %5 = mulf %4, %4: f32 // matched
  %6 = mulf %5, %5: f32 // matched
```

Note that 0-ary matchers can be used as leaves in place of n-ary matchers. This alleviates from passing explicit `m_any()` leaves.

In the future, we may add extra patterns to specify that operands may be matched in any order.

PiperOrigin-RevId: 284469446

mlir/include/mlir/IR/Matchers.h
mlir/test/IR/test-matchers.mlir [new file with mode: 0644]
mlir/test/lib/IR/CMakeLists.txt
mlir/test/lib/IR/TestMatchers.cpp [new file with mode: 0644]

index 1a1869b..99a33b6 100644 (file)
@@ -26,7 +26,6 @@
 
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/StandardTypes.h"
-#include <type_traits>
 
 namespace mlir {
 
@@ -134,30 +133,78 @@ template <typename OpClass> struct op_matcher {
   bool match(Operation *op) { return isa<OpClass>(op); }
 };
 
-} // end namespace detail
+/// Trait to check whether T provides a 'match' method with type
+/// `OperationOrValue`.
+template <typename T, typename OperationOrValue>
+using has_operation_or_value_matcher_t =
+    decltype(std::declval<T>().match(std::declval<OperationOrValue>()));
 
-/// Entry point for matching a pattern over a Value.
-template <typename Pattern>
-inline bool matchPattern(Value *value, const Pattern &pattern) {
-  // TODO: handle other cases
-  if (auto *op = value->getDefiningOp())
-    return const_cast<Pattern &>(pattern).match(op);
+/// Statically switch to a Value matcher.
+template <typename MatcherClass>
+typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
+                                      MatcherClass, Value *>::value,
+                          bool>
+matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
+  return matcher.match(op->getOperand(idx));
+}
+
+/// Statically switch to an Operation matcher.
+template <typename MatcherClass>
+typename std::enable_if_t<is_detected<detail::has_operation_or_value_matcher_t,
+                                      MatcherClass, Operation *>::value,
+                          bool>
+matchOperandOrValueAtIndex(Operation *op, unsigned idx, MatcherClass &matcher) {
+  if (auto defOp = op->getOperand(idx)->getDefiningOp())
+    return matcher.match(defOp);
   return false;
 }
 
-/// Entry point for matching a pattern over an Operation.
-template <typename Pattern>
-inline bool matchPattern(Operation *op, const Pattern &pattern) {
-  return const_cast<Pattern &>(pattern).match(op);
+/// Terminal matcher, always returns true.
+struct AnyValueMatcher {
+  bool match(Value *op) const { return true; }
+};
+
+/// Binds to a specific value and matches it.
+struct PatternMatcherValue {
+  PatternMatcherValue(Value *val) : value(val) {}
+  bool match(Value *val) const { return val == value; }
+  Value *value;
+};
+
+template <typename TupleT, class CallbackT, std::size_t... Is>
+constexpr void enumerateImpl(TupleT &&tuple, CallbackT &&callback,
+                             std::index_sequence<Is...>) {
+  (void)std::initializer_list<int>{
+      0,
+      (callback(std::integral_constant<std::size_t, Is>{}, std::get<Is>(tuple)),
+       0)...};
 }
 
-/// Matches a constant holding a scalar/vector/tensor integer (splat) and
-/// writes the integer value to bind_value.
-inline detail::constant_int_op_binder
-m_ConstantInt(IntegerAttr::ValueType *bind_value) {
-  return detail::constant_int_op_binder(bind_value);
+template <typename... Tys, typename CallbackT>
+constexpr void enumerate(std::tuple<Tys...> &tuple, CallbackT &&callback) {
+  detail::enumerateImpl(tuple, std::forward<CallbackT>(callback),
+                        std::make_index_sequence<sizeof...(Tys)>{});
 }
 
+/// RecursivePatternMatcher that composes.
+template <typename OpType, typename... OperandMatchers>
+struct RecursivePatternMatcher {
+  RecursivePatternMatcher(OperandMatchers... matchers)
+      : operandMatchers(matchers...) {}
+  bool match(Operation *op) {
+    if (!isa<OpType>(op) || op->getNumOperands() != sizeof...(OperandMatchers))
+      return false;
+    bool res = true;
+    enumerate(operandMatchers, [&](size_t index, auto &matcher) {
+      res &= matchOperandOrValueAtIndex(op, index, matcher);
+    });
+    return res;
+  }
+  std::tuple<OperandMatchers...> operandMatchers;
+};
+
+} // end namespace detail
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template <typename AttrT>
@@ -186,6 +233,38 @@ inline detail::constant_int_not_value_matcher<0> m_NonZero() {
   return detail::constant_int_not_value_matcher<0>();
 }
 
+/// Entry point for matching a pattern over a Value.
+template <typename Pattern>
+inline bool matchPattern(Value *value, const Pattern &pattern) {
+  // TODO: handle other cases
+  if (auto *op = value->getDefiningOp())
+    return const_cast<Pattern &>(pattern).match(op);
+  return false;
+}
+
+/// Entry point for matching a pattern over an Operation.
+template <typename Pattern>
+inline bool matchPattern(Operation *op, const Pattern &pattern) {
+  return const_cast<Pattern &>(pattern).match(op);
+}
+
+/// Matches a constant holding a scalar/vector/tensor integer (splat) and
+/// writes the integer value to bind_value.
+inline detail::constant_int_op_binder
+m_ConstantInt(IntegerAttr::ValueType *bind_value) {
+  return detail::constant_int_op_binder(bind_value);
+}
+
+template <typename OpType, typename... Matchers>
+auto m_Op(Matchers... matchers) {
+  return detail::RecursivePatternMatcher<OpType, Matchers...>(matchers...);
+}
+
+namespace matchers {
+inline auto m_any() { return detail::AnyValueMatcher(); }
+inline auto m_val(Value *v) { return detail::PatternMatcherValue(v); }
+} // namespace matchers
+
 } // end namespace mlir
 
 #endif // MLIR_MATCHERS_H
diff --git a/mlir/test/IR/test-matchers.mlir b/mlir/test/IR/test-matchers.mlir
new file mode 100644 (file)
index 0000000..c428e78
--- /dev/null
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -disable-pass-threading=true -test-matchers -o /dev/null 2>&1 | FileCheck %s
+
+func @test1(%a: f32, %b: f32, %c: f32) {
+  %0 = addf %a, %b: f32
+  %1 = addf %a, %c: f32
+  %2 = addf %c, %b: f32
+  %3 = mulf %a, %2: f32
+  %4 = mulf %3, %1: f32
+  %5 = mulf %4, %4: f32
+  %6 = mulf %5, %5: f32
+  return
+}
+
+// CHECK-LABEL: test1
+//       CHECK:   Pattern add(*) matched 3 times
+//       CHECK:   Pattern mul(*) matched 4 times
+//       CHECK:   Pattern add(add(*), *) matched 0 times
+//       CHECK:   Pattern add(*, add(*)) matched 0 times
+//       CHECK:   Pattern mul(add(*), *) matched 0 times
+//       CHECK:   Pattern mul(*, add(*)) matched 2 times
+//       CHECK:   Pattern mul(mul(*), *) matched 3 times
+//       CHECK:   Pattern mul(mul(*), mul(*)) matched 2 times
+//       CHECK:   Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched 1 times
+//       CHECK:   Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, add(*)), mul(*, add(*)))) matched 1 times
+//       CHECK:   Pattern add(a, b) matched 1 times
+//       CHECK:   Pattern add(a, c) matched 1 times
+//       CHECK:   Pattern add(b, a) matched 0 times
+//       CHECK:   Pattern add(c, a) matched 0 times
+//       CHECK:   Pattern mul(a, add(c, b)) matched 1 times
+//       CHECK:   Pattern mul(a, add(b, c)) matched 0 times
+//       CHECK:   Pattern mul(mul(a, *), add(a, c)) matched 1 times
+//       CHECK:   Pattern mul(mul(a, *), add(c, b)) matched 0 times
index 439d3a4..4ac6a91 100644 (file)
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRTestIR
   TestFunc.cpp
+  TestMatchers.cpp
   TestSymbolUses.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/lib/IR/TestMatchers.cpp b/mlir/test/lib/IR/TestMatchers.cpp
new file mode 100644 (file)
index 0000000..c0b92a8
--- /dev/null
@@ -0,0 +1,150 @@
+//===- TestMatchers.cpp - Pass to test matchers ---------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a test pass for verifying matchers.
+struct TestMatchers : public ModulePass<TestMatchers> {
+  void runOnModule() override;
+};
+} // end anonymous namespace
+
+// This could be done better but is not worth the variadic template trouble.
+template <typename Matcher> unsigned countMatches(FuncOp f, Matcher &matcher) {
+  unsigned count = 0;
+  f.walk([&count, &matcher](Operation *op) {
+    if (matcher.match(op)) {
+      // llvm::outs() << "matched " << *op << "\n";
+      ++count;
+    }
+  });
+  return count;
+}
+
+static void test1(FuncOp f) {
+  using mlir::matchers::m_any;
+  using mlir::matchers::m_val;
+
+  assert(f.getNumArguments() == 3 && "matcher test funcs must have 3 args");
+  auto a = m_val(f.getArgument(0));
+  auto b = m_val(f.getArgument(1));
+  auto c = m_val(f.getArgument(2));
+  (void)a;
+  (void)b;
+  (void)c;
+
+  llvm::outs() << f.getName();
+
+  auto p0 = m_Op<AddFOp>(); // using 0-arity matcher
+  llvm::outs() << "Pattern add(*) matched " << countMatches(f, p0)
+               << " times\n";
+
+  auto p1 = m_Op<MulFOp>(); // using 0-arity matcher
+  llvm::outs() << "Pattern mul(*) matched " << countMatches(f, p1)
+               << " times\n";
+
+  auto p2 = m_Op<AddFOp>(m_Op<AddFOp>(), m_any());
+  llvm::outs() << "Pattern add(add(*), *) matched " << countMatches(f, p2)
+               << " times\n";
+
+  auto p3 = m_Op<AddFOp>(m_any(), m_Op<AddFOp>());
+  llvm::outs() << "Pattern add(*, add(*)) matched " << countMatches(f, p3)
+               << " times\n";
+
+  auto p4 = m_Op<MulFOp>(m_Op<AddFOp>(), m_any());
+  llvm::outs() << "Pattern mul(add(*), *) matched " << countMatches(f, p4)
+               << " times\n";
+
+  auto p5 = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
+  llvm::outs() << "Pattern mul(*, add(*)) matched " << countMatches(f, p5)
+               << " times\n";
+
+  auto p6 = m_Op<MulFOp>(m_Op<MulFOp>(), m_any());
+  llvm::outs() << "Pattern mul(mul(*), *) matched " << countMatches(f, p6)
+               << " times\n";
+
+  auto p7 = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
+  llvm::outs() << "Pattern mul(mul(*), mul(*)) matched " << countMatches(f, p7)
+               << " times\n";
+
+  auto mul_of_mulmul = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<MulFOp>());
+  auto p8 = m_Op<MulFOp>(mul_of_mulmul, mul_of_mulmul);
+  llvm::outs()
+      << "Pattern mul(mul(mul(*), mul(*)), mul(mul(*), mul(*))) matched "
+      << countMatches(f, p8) << " times\n";
+
+  // clang-format off
+  auto mul_of_muladd = m_Op<MulFOp>(m_Op<MulFOp>(), m_Op<AddFOp>());
+  auto mul_of_anyadd = m_Op<MulFOp>(m_any(), m_Op<AddFOp>());
+  auto p9 = m_Op<MulFOp>(m_Op<MulFOp>(
+                     mul_of_muladd, m_Op<MulFOp>()),
+                   m_Op<MulFOp>(mul_of_anyadd, mul_of_anyadd));
+  // clang-format on
+  llvm::outs() << "Pattern mul(mul(mul(mul(*), add(*)), mul(*)), mul(mul(*, "
+                  "add(*)), mul(*, add(*)))) matched "
+               << countMatches(f, p9) << " times\n";
+
+  auto p10 = m_Op<AddFOp>(a, b);
+  llvm::outs() << "Pattern add(a, b) matched " << countMatches(f, p10)
+               << " times\n";
+
+  auto p11 = m_Op<AddFOp>(a, c);
+  llvm::outs() << "Pattern add(a, c) matched " << countMatches(f, p11)
+               << " times\n";
+
+  auto p12 = m_Op<AddFOp>(b, a);
+  llvm::outs() << "Pattern add(b, a) matched " << countMatches(f, p12)
+               << " times\n";
+
+  auto p13 = m_Op<AddFOp>(c, a);
+  llvm::outs() << "Pattern add(c, a) matched " << countMatches(f, p13)
+               << " times\n";
+
+  auto p14 = m_Op<MulFOp>(a, m_Op<AddFOp>(c, b));
+  llvm::outs() << "Pattern mul(a, add(c, b)) matched " << countMatches(f, p14)
+               << " times\n";
+
+  auto p15 = m_Op<MulFOp>(a, m_Op<AddFOp>(b, c));
+  llvm::outs() << "Pattern mul(a, add(b, c)) matched " << countMatches(f, p15)
+               << " times\n";
+
+  auto mul_of_aany = m_Op<MulFOp>(a, m_any());
+  auto p16 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(a, c));
+  llvm::outs() << "Pattern mul(mul(a, *), add(a, c)) matched "
+               << countMatches(f, p16) << " times\n";
+
+  auto p17 = m_Op<MulFOp>(mul_of_aany, m_Op<AddFOp>(c, b));
+  llvm::outs() << "Pattern mul(mul(a, *), add(c, b)) matched "
+               << countMatches(f, p17) << " times\n";
+}
+
+void TestMatchers::runOnModule() {
+  auto m = getModule();
+  for (auto f : m.getOps<FuncOp>()) {
+    if (f.getName() == "test1")
+      test1(f);
+  }
+}
+
+static PassRegistration<TestMatchers> pass("test-matchers",
+                                           "Test C++ pattern matchers.");