/// The matcher that matches a constant foldable operation that has no side
/// effect, no operands and produces a single result.
-struct constant_op_binder {
- Attribute *bind_value;
+template <typename AttrT> struct constant_op_binder {
+ AttrT *bind_value;
/// Creates a matcher instance that binds the constant attribute value to
/// bind_value if match succeeds.
- constant_op_binder(Attribute *bind_value) : bind_value(bind_value) {}
+ constant_op_binder(AttrT *bind_value) : bind_value(bind_value) {}
bool match(Operation *op) {
if (op->getNumOperands() > 0 || op->getNumResults() != 1)
SmallVector<OpFoldResult, 1> foldedOp;
if (succeeded(op->fold(/*operands=*/llvm::None, foldedOp))) {
- *bind_value = foldedOp.front().dyn_cast<Attribute>();
- return true;
+ if (auto attr = foldedOp.front().dyn_cast<Attribute>()) {
+ if ((*bind_value = attr.dyn_cast<AttrT>()))
+ return true;
+ }
}
return false;
}
bool match(Operation *op) {
Attribute attr;
- if (!constant_op_binder(&attr).match(op))
+ if (!constant_op_binder<Attribute>(&attr).match(op))
return false;
auto type = op->getResult(0)->getType();
/// Matches a value from a constant foldable operation and writes the value to
/// bind_value.
-inline detail::constant_op_binder m_Constant(Attribute *bind_value) {
- return detail::constant_op_binder(bind_value);
+template <typename AttrT>
+inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
+ return detail::constant_op_binder<AttrT>(bind_value);
}
/// Matches a constant scalar / vector splat / tensor splat integer one.
PatternRewriter &rewriter) const override {
auto indirectCall = cast<CallIndirectOp>(op);
- // Check that the callee is a constant operation.
- Attribute callee;
- if (!matchPattern(indirectCall.getCallee(), m_Constant(&callee)))
- return matchFailure();
-
- // Check that the constant callee is a function.
- FunctionAttr calledFn = callee.dyn_cast<FunctionAttr>();
- if (!calledFn)
+ // Check that the callee is a constant callee.
+ FunctionAttr calledFn;
+ if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
return matchFailure();
// Replace with a direct call.