Update the m_Constant matcher to enable matching derived attribute types.
authorRiver Riddle <riverriddle@google.com>
Sat, 25 May 2019 01:49:45 +0000 (18:49 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 03:02:12 +0000 (20:02 -0700)
--

PiperOrigin-RevId: 249933184

mlir/include/mlir/IR/Matchers.h
mlir/lib/StandardOps/Ops.cpp

index 3e70d7c..d1c2f94 100644 (file)
@@ -60,12 +60,12 @@ struct attr_value_binder {
 
 /// The matcher that matches a constant foldable operation that has no side
 /// effect, no operands and produces a single result.
-struct constant_op_binder {
-  Attribute *bind_value;
+template <typename AttrT> struct constant_op_binder {
+  AttrT *bind_value;
 
   /// Creates a matcher instance that binds the constant attribute value to
   /// bind_value if match succeeds.
-  constant_op_binder(Attribute *bind_value) : bind_value(bind_value) {}
+  constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
 
   bool match(Operation *op) {
     if (op->getNumOperands() > 0 || op->getNumResults() != 1)
@@ -75,8 +75,10 @@ struct constant_op_binder {
 
     SmallVector<OpFoldResult, 1> foldedOp;
     if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
-      *bind_value = foldedOp.front().dyn_cast<Attribute>();
-      return true;
+      if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
+        if ((*bind_value = attr.dyn_cast<AttrT>()))
+          return true;
+      }
     }
     return false;
   }
@@ -92,7 +94,7 @@ struct constant_int_op_binder {
 
   bool match(Operation *op) {
     Attribute attr;
-    if (!constant_op_binder(&attr).match(op))
+    if (!constant_op_binder<Attribute>(&attr).match(op))
       return false;
     auto type = op->getResult(0)->getType();
 
@@ -150,8 +152,9 @@ m_ConstantInt(IntegerAttr::ValueType *bind_value) {
 
 /// Matches a value from a constant foldable operation and writes the value to
 /// bind_value.
-inline detail::constant_op_binder m_Constant(Attribute *bind_value) {
-  return detail::constant_op_binder(bind_value);
+template <typename AttrT>
+inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
+  return detail::constant_op_binder<AttrT>(bind_value);
 }
 
 /// Matches a constant scalar / vector splat / tensor splat integer one.
index 29f9b19..508ebfe 100644 (file)
@@ -492,14 +492,9 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
                                      PatternRewriter &rewriter) const override {
     auto indirectCall = cast<CallIndirectOp>(op);
 
-    // Check that the callee is a constant operation.
-    Attribute callee;
-    if (!matchPattern(indirectCall.getCallee(), m_Constant(&callee)))
-      return matchFailure();
-
-    // Check that the constant callee is a function.
-    FunctionAttr calledFn = callee.dyn_cast<FunctionAttr>();
-    if (!calledFn)
+    // Check that the callee is a constant callee.
+    FunctionAttr calledFn;
+    if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
       return matchFailure();
 
     // Replace with a direct call.