[COFF, ARM64] Make sure to forward arguments from vararg to musttail vararg
authorMandeep Singh Grang <mgrang@codeaurora.org>
Tue, 30 Oct 2018 20:46:10 +0000 (20:46 +0000)
committerMandeep Singh Grang <mgrang@codeaurora.org>
Tue, 30 Oct 2018 20:46:10 +0000 (20:46 +0000)
Summary:
    Thunk functions in Windows are varag functions that call a musttail function
    to pass the arguments after the fixup is done.  We need to make sure that we
    forward the arguments from the caller vararg to the callee vararg function.
    This is the same mechanism that is used for Windows on X86.

Reviewers: ssijaric, eli.friedman, TomTan, mgrang, mstorsjo, rnk, compnerd, efriedma

Reviewed By: efriedma

Subscribers: efriedma, kristof.beyls, chrib, javed.absar, llvm-commits

Differential Revision: https://reviews.llvm.org/D53843

llvm-svn: 345641

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
llvm/test/CodeGen/AArch64/vararg-tallcall.ll [new file with mode: 0644]

index 2a42d2d..3c10701 100644 (file)
@@ -3148,6 +3148,17 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
     // We currently pass all varargs at 8-byte alignment.
     StackOffset = ((StackOffset + 7) & ~7);
     FuncInfo->setVarArgsStackIndex(MFI.CreateFixedObject(4, StackOffset, true));
+
+    if (MFI.hasMustTailInVarArgFunc()) {
+      SmallVector<MVT, 2> RegParmTypes;
+      RegParmTypes.push_back(MVT::i64);
+      RegParmTypes.push_back(MVT::f128);
+      // Compute the set of forwarded registers. The rest are scratch.
+      SmallVectorImpl<ForwardedRegister> &Forwards =
+                                       FuncInfo->getForwardedMustTailRegParms();
+      CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes,
+                                               CC_AArch64_AAPCS);
+    }
   }
 
   unsigned StackArgSize = CCInfo.getNextStackOffset();
@@ -3608,6 +3619,14 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   SmallVector<SDValue, 8> MemOpChains;
   auto PtrVT = getPointerTy(DAG.getDataLayout());
 
+  if (IsVarArg && CLI.CS && CLI.CS.isMustTailCall()) {
+    const auto &Forwards = FuncInfo->getForwardedMustTailRegParms();
+    for (const auto &F : Forwards) {
+      SDValue Val = DAG.getCopyFromReg(Chain, DL, F.VReg, F.VT);
+       RegsToPass.push_back(std::make_pair(unsigned(F.PReg), Val));
+    }
+  }
+
   // Walk the register/memloc assignments, inserting copies/loads.
   for (unsigned i = 0, realArgIdx = 0, e = ArgLocs.size(); i != e;
        ++i, ++realArgIdx) {
index 63c0ba2..5183e7d 100644 (file)
@@ -18,6 +18,7 @@
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/MC/MCLinkerOptimizationHint.h"
 #include <cassert>
@@ -97,6 +98,9 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
   /// attribute, in which case it is set to false at construction.
   Optional<bool> HasRedZone;
 
+  /// ForwardedMustTailRegParms - A list of virtual and physical registers
+  /// that must be forwarded to every musttail call.
+  SmallVector<ForwardedRegister, 1> ForwardedMustTailRegParms;
 public:
   AArch64FunctionInfo() = default;
 
@@ -209,6 +213,10 @@ public:
     LOHRelated.insert(Args.begin(), Args.end());
   }
 
+  SmallVectorImpl<ForwardedRegister> &getForwardedMustTailRegParms() {
+    return ForwardedMustTailRegParms;
+  }
+
 private:
   // Hold the lists of LOHs.
   MILOHContainer LOHContainerSet;
diff --git a/llvm/test/CodeGen/AArch64/vararg-tallcall.ll b/llvm/test/CodeGen/AArch64/vararg-tallcall.ll
new file mode 100644 (file)
index 0000000..2818222
--- /dev/null
@@ -0,0 +1,34 @@
+; RUN: llc -mtriple=aarch64-windows-msvc %s -o - | FileCheck %s
+; RUN: llc -mtriple=aarch64-linux-gnu %s -o - | FileCheck %s
+
+target datalayout = "e-m:w-p:64:64-i32:32-i64:64-i128:128-n32:64-S128"
+
+%class.X = type { i8 }
+%struct.B = type { i32 (...)** }
+
+$"??_9B@@$BA@AA" = comdat any
+
+; Function Attrs: noinline optnone
+define linkonce_odr void @"??_9B@@$BA@AA"(%struct.B* %this, ...) #1 comdat align 2  {
+entry:
+  %this.addr = alloca %struct.B*, align 8
+  store %struct.B* %this, %struct.B** %this.addr, align 8
+  %this1 = load %struct.B*, %struct.B** %this.addr, align 8
+  call void asm sideeffect "", "~{d0}"()
+  %0 = bitcast %struct.B* %this1 to void (%struct.B*, ...)***
+  %vtable = load void (%struct.B*, ...)**, void (%struct.B*, ...)*** %0, align 8
+  %vfn = getelementptr inbounds void (%struct.B*, ...)*, void (%struct.B*, ...)** %vtable, i64 0
+  %1 = load void (%struct.B*, ...)*, void (%struct.B*, ...)** %vfn, align 8
+  musttail call void (%struct.B*, ...) %1(%struct.B* %this1, ...)
+  ret void
+                                                  ; No predecessors!
+  ret void
+}
+
+attributes #1 = { noinline optnone "thunk" }
+
+; CHECK: mov     v16.16b, v0.16b
+; CHECK: ldr     x8, [x0]
+; CHECK: ldr     x8, [x8]
+; CHECK: mov     v0.16b, v16.16b
+; CHECK: br      x8