Fold CallIndirectOp to CallOp when the callee operand is a known constant function.
authorRiver Riddle <riverriddle@google.com>
Wed, 30 Jan 2019 02:08:28 +0000 (18:08 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:01:23 +0000 (16:01 -0700)
PiperOrigin-RevId: 231511697

mlir/include/mlir/StandardOps/StandardOps.h
mlir/lib/StandardOps/StandardOps.cpp
mlir/test/Transforms/canonicalize.mlir

index 4ef0081fda2c3fb61d5636aaa36a31fb252c2b3a..5dbf90e60e9be08aeffd4449ae9838fbf62d35cf 100644 (file)
@@ -101,6 +101,20 @@ public:
     return getAttrOfType<FunctionAttr>("callee").getValue();
   }
 
+  /// Get the argument operands to the called function.
+  llvm::iterator_range<const_operand_iterator> getArgOperands() const {
+    return {arg_operand_begin(), arg_operand_end()};
+  }
+  llvm::iterator_range<operand_iterator> getArgOperands() {
+    return {arg_operand_begin(), arg_operand_end()};
+  }
+
+  const_operand_iterator arg_operand_begin() const { return operand_begin(); }
+  const_operand_iterator arg_operand_end() const { return operand_end(); }
+
+  operand_iterator arg_operand_begin() { return operand_begin(); }
+  operand_iterator arg_operand_end() { return operand_end(); }
+
   // Hooks to customize behavior of this op.
   static bool parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p) const;
@@ -130,10 +144,26 @@ public:
   const Value *getCallee() const { return getOperand(0); }
   Value *getCallee() { return getOperand(0); }
 
+  /// Get the argument operands to the called function.
+  llvm::iterator_range<const_operand_iterator> getArgOperands() const {
+    return {arg_operand_begin(), arg_operand_end()};
+  }
+  llvm::iterator_range<operand_iterator> getArgOperands() {
+    return {arg_operand_begin(), arg_operand_end()};
+  }
+
+  const_operand_iterator arg_operand_begin() const { return ++operand_begin(); }
+  const_operand_iterator arg_operand_end() const { return operand_end(); }
+
+  operand_iterator arg_operand_begin() { return ++operand_begin(); }
+  operand_iterator arg_operand_end() { return operand_end(); }
+
   // Hooks to customize behavior of this op.
   static bool parse(OpAsmParser *parser, OperationState *result);
   void print(OpAsmPrinter *p) const;
   bool verify() const;
+  static void getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context);
 
 protected:
   friend class OperationInst;
index b1490b59a39634716824cf7f0ffe4265823c49c6..5bad8e9ec90fa4d424e68a9ad6be04152e80b542 100644 (file)
@@ -382,6 +382,38 @@ bool CallOp::verify() const {
 //===----------------------------------------------------------------------===//
 // CallIndirectOp
 //===----------------------------------------------------------------------===//
+namespace {
+/// Fold indirect calls that have a constant function as the callee operand.
+struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
+  SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
+      : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult match(OperationInst *op) const override {
+    auto indirectCall = op->cast<CallIndirectOp>();
+
+    // Check that the callee is a constant operation.
+    Value *callee = indirectCall->getCallee();
+    OperationInst *calleeInst = callee->getDefiningInst();
+    if (!calleeInst || !calleeInst->isa<ConstantOp>())
+      return matchFailure();
+
+    // Check that the constant callee is a function.
+    if (calleeInst->cast<ConstantOp>()->getValue().isa<FunctionAttr>())
+      return matchSuccess();
+    return matchFailure();
+  }
+  void rewrite(OperationInst *op, PatternRewriter &rewriter) const override {
+    auto indirectCall = op->cast<CallIndirectOp>();
+    auto calleeOp =
+        indirectCall->getCallee()->getDefiningInst()->cast<ConstantOp>();
+
+    // Replace with a direct call.
+    Function *calledFn = calleeOp->getValue().cast<FunctionAttr>().getValue();
+    SmallVector<Value *, 8> callOperands(indirectCall->getArgOperands());
+    rewriter.replaceOpWithNewOp<CallOp>(op, calledFn, callOperands);
+  }
+};
+} // end anonymous namespace.
 
 void CallIndirectOp::build(Builder *builder, OperationState *result,
                            Value *callee, ArrayRef<Value *> operands) {
@@ -445,6 +477,16 @@ bool CallIndirectOp::verify() const {
   return false;
 }
 
+void CallIndirectOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.push_back(
+      std::make_unique<SimplifyIndirectCallWithKnownCallee>(context));
+}
+
+//===----------------------------------------------------------------------===//
+// CmpIOp
+//===----------------------------------------------------------------------===//
+
 // Return the type of the same shape (scalar, vector or tensor) containing i1.
 static Type getCheckedI1SameShape(Builder *build, Type type) {
   auto i1Type = build->getI1Type();
index 0c3a52e86082283c47a31b514e61ea846901b865..7df61cda44d3b9bb309bb2a44071e346aded3e9f 100644 (file)
@@ -323,3 +323,16 @@ func @cond_br_folding(%a : i32) {
 ^bb2:
   return
 }
+
+// CHECK-LABEL: func @indirect_call_folding
+func @indirect_target() {
+  return
+}
+
+func @indirect_call_folding() {
+  // CHECK-NEXT: call @indirect_target() : () -> ()
+  // CHECK-NEXT: return
+  %indirect_fn = constant @indirect_target : () -> ()
+  call_indirect %indirect_fn() : () -> ()
+  return
+}