[mlir] m_Constant()
authorLorenzo Chelini <l.chelini@icloud.com>
Mon, 13 Jan 2020 16:21:04 +0000 (17:21 +0100)
committerAlex Zinenko <zinenko@google.com>
Mon, 13 Jan 2020 16:22:01 +0000 (17:22 +0100)
Summary: Introduce m_Constant() which allows matching a constant operation without forcing the user also to capture the attribute value.

Differential Revision: https://reviews.llvm.org/D72397

mlir/include/mlir/IR/Matchers.h
mlir/lib/IR/Builders.cpp
mlir/test/IR/test-matchers.mlir
mlir/test/lib/IR/TestMatchers.cpp

index d8d3308..170984b 100644 (file)
@@ -56,6 +56,8 @@ template <typename AttrT> struct constant_op_binder {
   /// Creates a matcher instance that binds the constant attribute value to
   /// bind_value if match succeeds.
   constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
+  /// Creates a matcher instance that doesn't bind if match succeeds.
+  constant_op_binder() : bind_value(nullptr) {}
 
   bool match(Operation *op) {
     if (op->getNumOperands() > 0 || op->getNumResults() != 1)
@@ -66,8 +68,11 @@ template <typename AttrT> struct constant_op_binder {
     SmallVector<OpFoldResult, 1> foldedOp;
     if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
       if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
-        if ((*bind_value = attr.dyn_cast<AttrT>()))
+        if (auto attrT = attr.dyn_cast<AttrT>()) {
+          if (bind_value)
+            *bind_value = attrT;
           return true;
+        }
       }
     }
     return false;
@@ -196,6 +201,11 @@ struct RecursivePatternMatcher {
 
 } // end namespace detail
 
+/// Matches a constant foldable operation.
+inline detail::constant_op_binder<Attribute> m_Constant() {
+  return detail::constant_op_binder<Attribute>();
+}
+
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
 template <typename AttrT>
index 0c72abf..5066366 100644 (file)
@@ -342,8 +342,7 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   };
 
   // If this operation is already a constant, there is nothing to do.
-  Attribute unused;
-  if (matchPattern(op, m_Constant(&unused)))
+  if (matchPattern(op, m_Constant()))
     return cleanupFailure();
 
   // Check to see if any operands to the operation is constant and whether
index 7808f25..60d5bcf 100644 (file)
@@ -40,3 +40,4 @@ func @test2(%a: f32) -> f32 {
 
 // CHECK-LABEL: test2
 //       CHECK:   Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
+//       CHECK:   Pattern add(add(a, constant), a) matched
index b62daa8..6061b25 100644 (file)
@@ -126,12 +126,15 @@ void test2(FuncOp f) {
   auto a = m_Val(f.getArgument(0));
   FloatAttr floatAttr;
   auto p = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant(&floatAttr)));
+  auto p1 = m_Op<MulFOp>(a, m_Op<AddFOp>(a, m_Constant()));
   // Last operation that is not the terminator.
   Operation *lastOp = f.getBody().front().back().getPrevNode();
   if (p.match(lastOp))
     llvm::outs()
         << "Pattern add(add(a, constant), a) matched and bound constant to: "
         << floatAttr.getValueAsDouble() << "\n";
+  if (p1.match(lastOp))
+    llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
 }
 
 void TestMatchers::runOnFunction() {