return getAttrOfType<FunctionAttr>("callee").getValue();
}
+ /// Get the argument operands to the called function.
+ llvm::iterator_range<const_operand_iterator> getArgOperands() const {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ llvm::iterator_range<operand_iterator> getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ const_operand_iterator arg_operand_begin() const { return operand_begin(); }
+ const_operand_iterator arg_operand_end() const { return operand_end(); }
+
+ operand_iterator arg_operand_begin() { return operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
const Value *getCallee() const { return getOperand(0); }
Value *getCallee() { return getOperand(0); }
+ /// Get the argument operands to the called function.
+ llvm::iterator_range<const_operand_iterator> getArgOperands() const {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+ llvm::iterator_range<operand_iterator> getArgOperands() {
+ return {arg_operand_begin(), arg_operand_end()};
+ }
+
+ const_operand_iterator arg_operand_begin() const { return ++operand_begin(); }
+ const_operand_iterator arg_operand_end() const { return operand_end(); }
+
+ operand_iterator arg_operand_begin() { return ++operand_begin(); }
+ operand_iterator arg_operand_end() { return operand_end(); }
+
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
+ static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+ MLIRContext *context);
protected:
friend class OperationInst;
//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
+namespace {
+/// Fold indirect calls that have a constant function as the callee operand.
+struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
+ SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
+ : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
+
+ PatternMatchResult match(OperationInst *op) const override {
+ auto indirectCall = op->cast<CallIndirectOp>();
+
+ // Check that the callee is a constant operation.
+ Value *callee = indirectCall->getCallee();
+ OperationInst *calleeInst = callee->getDefiningInst();
+ if (!calleeInst || !calleeInst->isa<ConstantOp>())
+ return matchFailure();
+
+ // Check that the constant callee is a function.
+ if (calleeInst->cast<ConstantOp>()->getValue().isa<FunctionAttr>())
+ return matchSuccess();
+ return matchFailure();
+ }
+ void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
+ auto indirectCall = op->cast<CallIndirectOp>();
+ auto calleeOp =
+ indirectCall->getCallee()->getDefiningInst()->cast<ConstantOp>();
+
+ // Replace with a direct call.
+ Function *calledFn = calleeOp->getValue().cast<FunctionAttr>().getValue();
+ SmallVector<Value *, 8> callOperands(indirectCall->getArgOperands());
+ rewriter.replaceOpWithNewOp<CallOp>(op, calledFn, callOperands);
+ }
+};
+} // end anonymous namespace.
void CallIndirectOp::build(Builder *builder, OperationState *result,
Value *callee, ArrayRef<Value *> operands) {
return false;
}
+void CallIndirectOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.push_back(
+ std::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
// Return the type of the same shape (scalar, vector or tensor) containing i1.
static Type getCheckedI1SameShape(Builder *build, Type type) {
auto i1Type = build->getI1Type();