From 83f5669cee176a1426a2104401990b541bc4720b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 24 May 2019 18:49:45 -0700 Subject: [PATCH] Update the m_Constant matcher to enable matching derived attribute types. -- PiperOrigin-RevId: 249933184 --- mlir/include/mlir/IR/Matchers.h | 19 +++++++++++-------- mlir/lib/StandardOps/Ops.cpp | 11 +++-------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 3e70d7c..d1c2f94 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -60,12 +60,12 @@ struct attr_value_binder { /// The matcher that matches a constant foldable operation that has no side /// effect, no operands and produces a single result. -struct constant_op_binder { - Attribute *bind_value; +template struct constant_op_binder { + AttrT *bind_value; /// Creates a matcher instance that binds the constant attribute value to /// bind_value if match succeeds. - constant_op_binder(Attribute *bind_value) : bind_value(bind_value) {} + constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {} bool match(Operation *op) { if (op->getNumOperands() > 0 || op->getNumResults() != 1) @@ -75,8 +75,10 @@ struct constant_op_binder { SmallVector foldedOp; if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) { - *bind_value = foldedOp.front().dyn_cast(); - return true; + if (auto attr = foldedOp.front().dyn_cast()) { + if ((*bind_value = attr.dyn_cast())) + return true; + } } return false; } @@ -92,7 +94,7 @@ struct constant_int_op_binder { bool match(Operation *op) { Attribute attr; - if (!constant_op_binder(&attr).match(op)) + if (!constant_op_binder(&attr).match(op)) return false; auto type = op->getResult(0)->getType(); @@ -150,8 +152,9 @@ m_ConstantInt(IntegerAttr::ValueType *bind_value) { /// Matches a value from a constant foldable operation and writes the value to /// bind_value. -inline detail::constant_op_binder m_Constant(Attribute *bind_value) { - return detail::constant_op_binder(bind_value); +template +inline detail::constant_op_binder m_Constant(AttrT *bind_value) { + return detail::constant_op_binder(bind_value); } /// Matches a constant scalar / vector splat / tensor splat integer one. diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 29f9b19..508ebfe 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -492,14 +492,9 @@ struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { PatternRewriter &rewriter) const override { auto indirectCall = cast(op); - // Check that the callee is a constant operation. - Attribute callee; - if (!matchPattern(indirectCall.getCallee(), m_Constant(&callee))) - return matchFailure(); - - // Check that the constant callee is a function. - FunctionAttr calledFn = callee.dyn_cast(); - if (!calledFn) + // Check that the callee is a constant callee. + FunctionAttr calledFn; + if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return matchFailure(); // Replace with a direct call. -- 2.7.4