[DAE] Don't DAE if we musttail call a "live" (non-DAE-able) function
authorMircea Trofin <mtrofin@google.com>
Fri, 3 Mar 2023 02:45:34 +0000 (18:45 -0800)
committerMircea Trofin <mtrofin@google.com>
Thu, 16 Mar 2023 20:36:11 +0000 (13:36 -0700)
There are 2 such base cases: indirect calls and calls to functions external
to the module; and then any musttail calls to live functions (because of
the first 2 reasons or otherwise).

The IR validator reports, in these cases, that it "cannot guarantee tail
call due to mismatched parameter counts".

The fix is two-fold: first, we mark as "live" (i.e. non-DAE-able)
functions that make an indirect musttail call.

Then, we propagate live-ness to musttail callers of live functions.

Declared functions are already marked "live".

Differential Revision: https://reviews.llvm.org/D145209

llvm/include/llvm/Transforms/IPO/DeadArgumentElimination.h
llvm/lib/Transforms/IPO/DeadArgumentElimination.cpp
llvm/test/Transforms/DeadArgElim/musttail-indirect.ll [new file with mode: 0644]

index a71fa3b..63e1ad0 100644 (file)
@@ -136,6 +136,7 @@ private:
   bool removeDeadStuffFromFunction(Function *F);
   bool deleteDeadVarargs(Function &F);
   bool removeDeadArgumentsFromCallers(Function &F);
+  void propagateVirtMustcallLiveness(const Module &M);
 };
 
 } // end namespace llvm
index bf2c65a..d6dc0f9 100644 (file)
@@ -85,6 +85,11 @@ public:
   virtual bool shouldHackArguments() const { return false; }
 };
 
+bool isMustTailCalleeAnalyzable(const CallBase &CB) {
+  assert(CB.isMustTailCall());
+  return CB.getCalledFunction() && !CB.getCalledFunction()->isDeclaration();
+}
+
 } // end anonymous namespace
 
 char DAE::ID = 0;
@@ -520,8 +525,16 @@ void DeadArgumentEliminationPass::surveyFunction(const Function &F) {
   for (const BasicBlock &BB : F) {
     // If we have any returns of `musttail` results - the signature can't
     // change
-    if (BB.getTerminatingMustTailCall() != nullptr)
+    if (const auto *TC = BB.getTerminatingMustTailCall()) {
       HasMustTailCalls = true;
+      // In addition, if the called function is not locally defined (or unknown,
+      // if this is an indirect call), we can't change the callsite and thus
+      // can't change this function's signature either.
+      if (!isMustTailCalleeAnalyzable(*TC)) {
+        markLive(F);
+        return;
+      }
+    }
   }
 
   if (HasMustTailCalls) {
@@ -1081,6 +1094,26 @@ bool DeadArgumentEliminationPass::removeDeadStuffFromFunction(Function *F) {
   return true;
 }
 
+void DeadArgumentEliminationPass::propagateVirtMustcallLiveness(
+    const Module &M) {
+  // If a function was marked "live", and it has musttail callers, they in turn
+  // can't change either.
+  LiveFuncSet NewLiveFuncs(LiveFunctions);
+  while (!NewLiveFuncs.empty()) {
+    LiveFuncSet Temp;
+    for (const auto *F : NewLiveFuncs)
+      for (const auto *U : F->users())
+        if (const auto *CB = dyn_cast<CallBase>(U))
+          if (CB->isMustTailCall())
+            if (!LiveFunctions.count(CB->getParent()->getParent()))
+              Temp.insert(CB->getParent()->getParent());
+    NewLiveFuncs.clear();
+    NewLiveFuncs.insert(Temp.begin(), Temp.end());
+    for (const auto *F : Temp)
+      markLive(*F);
+  }
+}
+
 PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
                                                    ModuleAnalysisManager &) {
   bool Changed = false;
@@ -1101,6 +1134,8 @@ PreservedAnalyses DeadArgumentEliminationPass::run(Module &M,
   for (auto &F : M)
     surveyFunction(F);
 
+  propagateVirtMustcallLiveness(M);
+
   // Now, remove all dead arguments and return values from each function in
   // turn.  We use make_early_inc_range here because functions will probably get
   // removed (i.e. replaced by new ones).
diff --git a/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll b/llvm/test/Transforms/DeadArgElim/musttail-indirect.ll
new file mode 100644 (file)
index 0000000..cdcccb4
--- /dev/null
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: -p --function-signature
+; RUN: opt -passes=deadargelim -S < %s | FileCheck %s
+
+define internal i32 @test_caller(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-LABEL: define {{[^@]+}}@test_caller(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-NEXT:    %r = musttail call i32 @test(ptr %fptr, i32 %a, i32 poison)
+; CHECK-NEXT:    ret i32 %r
+;
+  %r = musttail call i32 @test(ptr %fptr, i32 %a, i32 %b)
+  ret i32 %r
+}
+
+define internal i32 @test(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-LABEL: define {{[^@]+}}@test(ptr %fptr, i32 %a, i32 %b) {
+; CHECK-NEXT:    %r = musttail call i32 %fptr(ptr %fptr, i32 %a, i32 0)
+; CHECK-NEXT:    ret i32 %r
+;
+  %r = musttail call i32 %fptr(ptr %fptr, i32 %a, i32 0)
+  ret i32 %r
+}
+
+define internal i32 @direct_test() {
+; CHECK-LABEL: define {{[^@]+}}@direct_test() {
+; CHECK-NEXT:    %r = musttail call i32 @foo()
+; CHECK-NEXT:    ret i32 %r
+;
+  %r = musttail call i32 @foo()
+  ret i32 %r
+}
+
+declare i32 @foo()
+
+define internal i32 @ping(i32 %x) {
+; CHECK-LABEL: define {{[^@]+}}@ping(i32 %x) {
+; CHECK-NEXT:    %r = musttail call i32 @pong(i32 %x)
+; CHECK-NEXT:    ret i32 %r
+;
+  %r = musttail call i32 @pong(i32 %x)
+  ret i32 %r
+}
+
+define internal i32 @pong(i32 %x) {
+; CHECK-LABEL: define {{[^@]+}}@pong(i32 %x) {
+; CHECK-NEXT:    %cond = icmp eq i32 %x, 2
+; CHECK-NEXT:    br i1 %cond, label %yes, label %no
+; CHECK:       yes:
+; CHECK-NEXT:    %r1 = musttail call i32 @ping(i32 %x)
+; CHECK-NEXT:    ret i32 %r1
+; CHECK:       no:
+; CHECK-NEXT:    %r2 = musttail call i32 @bar(i32 %x)
+; CHECK-NEXT:    ret i32 %r2
+;
+  %cond = icmp eq i32 %x, 2
+  br i1 %cond, label %yes, label %no
+
+yes:
+  %r1 = musttail call i32 @ping(i32 %x)
+  ret i32 %r1
+no:
+  %r2 = musttail call i32 @bar(i32 %x)
+  ret i32 %r2
+}
+
+declare i32 @bar(i32 %x)
+