[MLIR][LLVM] Support inlining LLVM::CallOp to LLVM::FuncOp.
authorJohannes de Fine Licht <johannes.definelicht@nextsilicon.com>
Fri, 20 Jan 2023 15:13:48 +0000 (16:13 +0100)
committerChristian Ulmann <christian.ulmann@nextsilicon.com>
Fri, 20 Jan 2023 15:26:33 +0000 (16:26 +0100)
Extend `LLVMInlinerInterface` to handle calls from an `LLVM::CallOp` to
an `LLVM::FuncOp` when there are no attributes present that require
special handling.

Depends on D141676

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D141682

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/inlining.mlir

index 955eb91..4f301ff 100644 (file)
@@ -2852,6 +2852,22 @@ namespace {
 struct LLVMInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
 
+  bool isLegalToInline(Operation *call, Operation *callable,
+                       bool wouldBeCloned) const final {
+    if (!wouldBeCloned)
+      return false;
+    auto callOp = dyn_cast<LLVM::CallOp>(call);
+    auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
+    if (!callOp || !funcOp)
+      return false;
+    return isLegalToInlineCallAttributes(callOp) &&
+           isLegalToInlineFuncAttributes(funcOp);
+  }
+
+  bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
+    return true;
+  }
+
   /// Conservative allowlist-based inlining of operations supported so far.
   bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final {
     if (isPure(op))
@@ -2869,22 +2885,84 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
             return false;
           return true;
         })
+        .Case<LLVM::CallOp>([](auto) { return true; })
         .Default([](auto) { return false; });
   }
-  /// Handle the given inlined terminator by replacing it with a new operation
-  /// as necessary. Required when the region has only one block.
-  void handleTerminator(Operation *op,
-                        ArrayRef<Value> valuesToRepl) const final {
 
-    // Only handle "llvm.return" here.
-    auto returnOp = dyn_cast<ReturnOp>(op);
+  /// Handle the given inlined return by replacing it with a branch. This
+  /// overload is called when the inlined region has more than one block.
+  void handleTerminator(Operation *op, Block *newDest) const final {
+    // Only return needs to be handled here.
+    auto returnOp = dyn_cast<LLVM::ReturnOp>(op);
     if (!returnOp)
       return;
 
+    // Replace the return with a branch to the dest.
+    OpBuilder builder(op);
+    builder.create<LLVM::BrOp>(op->getLoc(), returnOp.getOperands(), newDest);
+    op->erase();
+  }
+
+  /// Handle the given inlined return by replacing the uses of the call with the
+  /// operands of the return. This overload is called when the inlined region
+  /// only contains one block.
+  void handleTerminator(Operation *op,
+                        ArrayRef<Value> valuesToRepl) const final {
+    // Return will be the only terminator present.
+    auto returnOp = cast<LLVM::ReturnOp>(op);
+
     // Replace the values directly with the return operands.
     assert(returnOp.getNumOperands() == valuesToRepl.size());
-    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
-      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
+    for (const auto &[dst, src] :
+         llvm::zip(valuesToRepl, returnOp.getOperands()))
+      dst.replaceAllUsesWith(src);
+  }
+
+private:
+  /// Returns true if all attributes of `callOp` are handled during inlining.
+  [[nodiscard]] static bool isLegalToInlineCallAttributes(LLVM::CallOp callOp) {
+    return all_of(callOp.getAttributeNames(), [&](StringRef attrName) {
+      return llvm::StringSwitch<bool>(attrName)
+          // TODO: Propagate and update branch weights.
+          .Case("branch_weights", !callOp.getBranchWeights())
+          .Case("callee", true)
+          .Case("fastmathFlags", true)
+          .Default(false);
+    });
+  }
+
+  /// Returns true if all attributes of `funcOp` are handled during inlining.
+  [[nodiscard]] static bool
+  isLegalToInlineFuncAttributes(LLVM::LLVMFuncOp funcOp) {
+    return all_of(funcOp.getAttributeNames(), [&](StringRef attrName) {
+      return llvm::StringSwitch<bool>(attrName)
+          .Case("CConv", true)
+          .Case("arg_attrs", ([&]() {
+                  if (!funcOp.getArgAttrs())
+                    return true;
+                  return llvm::all_of(funcOp.getArgAttrs().value(),
+                                      [&](Attribute) {
+                                        // TODO: Handle argument attributes.
+                                        return false;
+                                      });
+                })())
+          .Case("dso_local", true)
+          .Case("function_entry_count", true)
+          .Case("function_type", true)
+          // TODO: Once the garbage collector attribute is supported on
+          // LLVM::CallOp, make sure that the garbage collector matches.
+          .Case("garbageCollector", !funcOp.getGarbageCollector())
+          .Case("linkage", true)
+          .Case("memory", true)
+          .Case("passthrough", !funcOp.getPassthrough())
+          // Exception handling is not yet supported, so bail out if the
+          // personality is set.
+          .Case("personality", !funcOp.getPersonality())
+          // TODO: Handle result attributes.
+          .Case("res_attrs", !funcOp.getResAttrs())
+          .Case("sym_name", true)
+          .Default(false);
+    });
   }
 };
 } // end anonymous namespace
index ce2bf69..64d1e55 100644 (file)
@@ -41,7 +41,7 @@ llvm.metadata @metadata {
   llvm.return
 }
 
-func.func private @with_mem_attr(%ptr : !llvm.ptr) -> () {
+func.func private @with_mem_attr(%ptr : !llvm.ptr) {
   %0 = llvm.mlir.constant(42 : i32) : i32
   // Do not inline load/store operations that carry attributes requiring
   // handling while inlining, until this is supported by the inliner.
@@ -52,7 +52,7 @@ func.func private @with_mem_attr(%ptr : !llvm.ptr) -> () {
 // CHECK-LABEL: func.func @test_not_inline
 // CHECK-NEXT: call @with_mem_attr
 // CHECK-NEXT: return
-func.func @test_not_inline(%ptr : !llvm.ptr) -> () {
+func.func @test_not_inline(%ptr : !llvm.ptr) {
   call @with_mem_attr(%ptr) : (!llvm.ptr) -> ()
   return
 }
@@ -70,3 +70,136 @@ func.func @llvm_ret(%arg0 : i32) -> i32 {
   %res = call @func(%arg0) : (i32) -> (i32)
   return %res : i32
 }
+
+// -----
+
+// Include all function attributes that don't prevent inlining
+llvm.func internal fastcc @callee() -> (i32) attributes { function_entry_count = 42 : i64, dso_local } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: %[[CST:.+]] = llvm.mlir.constant
+// CHECK-NEXT: llvm.return %[[CST]]
+llvm.func @caller() -> (i32) {
+  // Include all call attributes that don't prevent inlining.
+  %0 = llvm.call @callee() { fastmathFlags = #llvm.fastmath<nnan, ninf> } : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @foo() -> (i32) attributes { passthrough = ["noinline"] } {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  llvm.return %0 : i32
+}
+
+llvm.func @bar() -> (i32) attributes { passthrough = ["noinline"] } {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  llvm.return %0 : i32
+}
+
+llvm.func @callee_with_multiple_blocks(%cond: i1) -> (i32) {
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %0 = llvm.call @foo() : () -> (i32)
+  llvm.br ^bb3(%0: i32)
+^bb2:
+  %1 = llvm.call @bar() : () -> (i32)
+  llvm.br ^bb3(%1: i32)
+^bb3(%arg: i32):
+  llvm.return %arg : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.cond_br {{.+}}, ^[[BB1:.+]], ^[[BB2:.+]]
+// CHECK-NEXT: ^[[BB1]]:
+// CHECK-NEXT: llvm.call @foo
+// CHECK-NEXT: llvm.br ^[[BB3:[a-zA-Z0-9_]+]]
+// CHECK-NEXT: ^[[BB2]]:
+// CHECK-NEXT: llvm.call @bar
+// CHECK-NEXT: llvm.br ^[[BB3]]
+// CHECK-NEXT: ^[[BB3]]
+// CHECK-NEXT: llvm.br ^[[BB4:[a-zA-Z0-9_]+]]
+// CHECK-NEXT: ^[[BB4]]
+// CHECK-NEXT: llvm.return
+llvm.func @caller(%cond: i1) -> (i32) {
+  %0 = llvm.call @callee_with_multiple_blocks(%cond) : (i1) -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @personality() -> i32
+
+llvm.func @callee() -> (i32) attributes { personality = @personality } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller() -> (i32) {
+  %0 = llvm.call @callee() : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @callee() -> (i32) attributes { passthrough = ["foo"] } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller() -> (i32) {
+  %0 = llvm.call @callee() : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @callee() -> (i32) attributes { garbageCollector = "foo" } {
+  %0 = llvm.mlir.constant(42 : i32) : i32
+  llvm.return %0 : i32
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller() -> (i32) {
+  %0 = llvm.call @callee() : () -> (i32)
+  llvm.return %0 : i32
+}
+
+// -----
+
+llvm.func @callee(%ptr : !llvm.ptr {llvm.byval = !llvm.ptr}) -> (!llvm.ptr) {
+  llvm.return %ptr : !llvm.ptr
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: return
+llvm.func @caller(%ptr : !llvm.ptr) -> (!llvm.ptr) {
+  %0 = llvm.call @callee(%ptr) : (!llvm.ptr) -> (!llvm.ptr)
+  llvm.return %0 : !llvm.ptr
+}
+
+// -----
+
+llvm.func @callee() {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @caller
+// CHECK-NEXT: llvm.call @callee
+// CHECK-NEXT: llvm.return
+llvm.func @caller() {
+  llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> ()
+  llvm.return
+}