[mlir][matchers] Add m_Op(StringRef) and m_Attr matchers
authorDevajith V S <devajithvs@gmail.com>
Tue, 11 Apr 2023 20:55:59 +0000 (13:55 -0700)
committerJacques Pienaar <jpienaar@google.com>
Tue, 11 Apr 2023 21:16:14 +0000 (14:16 -0700)
This patch introduces support for m_Op with a StringRef argument and m_Attr matchers. These matchers will be very useful for mlir-query that is being developed currently.

Submitting this patch separately to reduce the final patch size and make it easier to upstream mlir-query.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D147262

mlir/include/mlir/IR/Matchers.h
mlir/test/IR/test-matchers.mlir
mlir/test/lib/IR/TestMatchers.cpp

index 374f05a..4dbc623 100644 (file)
@@ -52,6 +52,22 @@ struct constant_op_matcher {
   bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
 };
 
+/// The matcher that matches operations that have the specified op name.
+struct NameOpMatcher {
+  NameOpMatcher(StringRef name) : name(name) {}
+  bool match(Operation *op) { return op->getName().getStringRef() == name; }
+
+  StringRef name;
+};
+
+/// The matcher that matches operations that have the specified attribute name.
+struct AttrOpMatcher {
+  AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
+  bool match(Operation *op) { return op->hasAttr(attrName); }
+
+  StringRef attrName;
+};
+
 /// The matcher that matches operations that have the `ConstantLike` trait, and
 /// binds the folded attribute value.
 template <typename AttrT>
@@ -83,6 +99,29 @@ struct constant_op_binder {
   }
 };
 
+/// The matcher that matches operations that have the specified attribute
+/// name, and binds the attribute value.
+template <typename AttrT>
+struct AttrOpBinder {
+  /// Creates a matcher instance that binds the attribute value to
+  /// bind_value if match succeeds.
+  AttrOpBinder(StringRef attrName, AttrT *bindValue)
+      : attrName(attrName), bindValue(bindValue) {}
+  /// Creates a matcher instance that doesn't bind if match succeeds.
+  AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {}
+
+  bool match(Operation *op) {
+    if (auto attr = op->getAttrOfType<AttrT>(attrName)) {
+      if (bindValue)
+        *bindValue = attr;
+      return true;
+    }
+    return false;
+  }
+  StringRef attrName;
+  AttrT *bindValue;
+};
+
 /// The matcher that matches a constant scalar / vector splat / tensor splat
 /// float operation and binds the constant float value.
 struct constant_float_op_binder {
@@ -249,6 +288,16 @@ inline detail::constant_op_matcher m_Constant() {
   return detail::constant_op_matcher();
 }
 
+/// Matches a named attribute operation.
+inline detail::AttrOpMatcher m_Attr(StringRef attrName) {
+  return detail::AttrOpMatcher(attrName);
+}
+
+/// Matches a named operation.
+inline detail::NameOpMatcher m_Op(StringRef opName) {
+  return detail::NameOpMatcher(opName);
+}
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template <typename AttrT>
@@ -256,6 +305,13 @@ inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
   return detail::constant_op_binder<AttrT>(bind_value);
 }
 
+/// Matches a named attribute operation and writes the value to bind_value.
+template <typename AttrT>
+inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
+                                          AttrT *bindValue) {
+  return detail::AttrOpBinder<AttrT>(attrName, bindValue);
+}
+
 /// Matches a constant scalar / vector splat / tensor splat float (both positive
 /// and negative) zero.
 inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {
index 87c7bf9..31f1b6d 100644 (file)
@@ -41,3 +41,14 @@ func.func @test2(%a: f32) -> f32 {
 // CHECK-LABEL: test2
 //       CHECK:   Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
 //       CHECK:   Pattern add(add(a, constant), a) matched
+
+func.func @test3(%a: f32) -> f32 {
+  %0 = "test.name"() {value = 1.0 : f32} : () -> f32
+  %1 = arith.addf %a, %0: f32
+  %2 = arith.mulf %a, %1 fastmath<fast>: f32
+  return %2: f32
+}
+
+// CHECK-LABEL: test3
+//       CHECK:   Pattern mul(*, add(*, m_Op("test.name"))) matched
+//       CHECK:   Pattern m_Attr("fastmath") matched and bound value to: fast
index 4f87517..d075d7a 100644 (file)
@@ -148,6 +148,21 @@ void test2(FunctionOpInterface f) {
     llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
+void test3(FunctionOpInterface f) {
+  arith::FastMathFlagsAttr fastMathAttr;
+  auto p = m_Op<arith::MulFOp>(m_Any(),
+                               m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
+  auto p1 = m_Attr("fastmath", &fastMathAttr);
+
+  // Last operation that is not the terminator.
+  Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
+  if (p.match(lastOp))
+    llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
+  if (p1.match(lastOp))
+    llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
+                 << fastMathAttr.getValue() << "\n";
+}
+
 void TestMatchers::runOnOperation() {
   auto f = getOperation();
   llvm::outs() << f.getName() << "\n";
@@ -155,6 +170,8 @@ void TestMatchers::runOnOperation() {
     test1(f);
   if (f.getName() == "test2")
     test2(f);
+  if (f.getName() == "test3")
+    test3(f);
 }
 
 namespace mlir {