[AArch64] Use correct calling convention for each vararg
authorPhilippe Valembois <lephilousophe@users.noreply.github.com>
Thu, 10 Mar 2022 23:05:29 +0000 (15:05 -0800)
committerEli Friedman <efriedma@quicinc.com>
Thu, 10 Mar 2022 23:07:25 +0000 (15:07 -0800)
While checking is tail call optimization is possible, the calling
convention applied to fixed arguments is not the correct one.
This implies for DarwinPCS that all arguments of a vararg function will
go to the stack although fixed ones can go in registers.

This prevents non-virtual thunks to be tail optimized although they are
marked as musttail.

Differential Revision: https://reviews.llvm.org/D120622

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/test/CodeGen/AArch64/darwinpcs-tail.ll [new file with mode: 0644]

index 0489207..37537c9 100644 (file)
@@ -5906,14 +5906,62 @@ static bool mayTailCallThisCC(CallingConv::ID CC) {
   }
 }
 
+static void analyzeCallOperands(const AArch64TargetLowering &TLI,
+                                const AArch64Subtarget *Subtarget,
+                                const TargetLowering::CallLoweringInfo &CLI,
+                                CCState &CCInfo) {
+  const SelectionDAG &DAG = CLI.DAG;
+  CallingConv::ID CalleeCC = CLI.CallConv;
+  bool IsVarArg = CLI.IsVarArg;
+  const SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
+  bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CalleeCC);
+
+  unsigned NumArgs = Outs.size();
+  for (unsigned i = 0; i != NumArgs; ++i) {
+    MVT ArgVT = Outs[i].VT;
+    ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
+
+    bool UseVarArgCC = false;
+    if (IsVarArg) {
+      // On Windows, the fixed arguments in a vararg call are passed in GPRs
+      // too, so use the vararg CC to force them to integer registers.
+      if (IsCalleeWin64) {
+        UseVarArgCC = true;
+      } else {
+        UseVarArgCC = !Outs[i].IsFixed;
+      }
+    } else {
+      // Get type of the original argument.
+      EVT ActualVT =
+          TLI.getValueType(DAG.getDataLayout(), CLI.Args[Outs[i].OrigArgIndex].Ty,
+                       /*AllowUnknown*/ true);
+      MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : ArgVT;
+      // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
+      if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
+        ArgVT = MVT::i8;
+      else if (ActualMVT == MVT::i16)
+        ArgVT = MVT::i16;
+    }
+
+    CCAssignFn *AssignFn = TLI.CCAssignFnForCall(CalleeCC, UseVarArgCC);
+    bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo);
+    assert(!Res && "Call operand has unhandled type");
+    (void)Res;
+  }
+}
+
 bool AArch64TargetLowering::isEligibleForTailCallOptimization(
-    SDValue Callee, CallingConv::ID CalleeCC, bool isVarArg,
-    const SmallVectorImpl<ISD::OutputArg> &Outs,
-    const SmallVectorImpl<SDValue> &OutVals,
-    const SmallVectorImpl<ISD::InputArg> &Ins, SelectionDAG &DAG) const {
+    const CallLoweringInfo &CLI) const {
+  CallingConv::ID CalleeCC = CLI.CallConv;
   if (!mayTailCallThisCC(CalleeCC))
     return false;
 
+  SDValue Callee = CLI.Callee;
+  bool IsVarArg = CLI.IsVarArg;
+  const SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
+  const SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
+  const SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
+  const SelectionDAG &DAG = CLI.DAG;
   MachineFunction &MF = DAG.getMachineFunction();
   const Function &CallerF = MF.getFunction();
   CallingConv::ID CallerCC = CallerF.getCallingConv();
@@ -5978,30 +6026,14 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
 
   // I want anyone implementing a new calling convention to think long and hard
   // about this assert.
-  assert((!isVarArg || CalleeCC == CallingConv::C) &&
+  assert((!IsVarArg || CalleeCC == CallingConv::C) &&
          "Unexpected variadic calling convention");
 
   LLVMContext &C = *DAG.getContext();
-  if (isVarArg && !Outs.empty()) {
-    // At least two cases here: if caller is fastcc then we can't have any
-    // memory arguments (we'd be expected to clean up the stack afterwards). If
-    // caller is C then we could potentially use its argument area.
-
-    // FIXME: for now we take the most conservative of these in both cases:
-    // disallow all variadic memory operands.
-    SmallVector<CCValAssign, 16> ArgLocs;
-    CCState CCInfo(CalleeCC, isVarArg, MF, ArgLocs, C);
-
-    CCInfo.AnalyzeCallOperands(Outs, CCAssignFnForCall(CalleeCC, true));
-    for (const CCValAssign &ArgLoc : ArgLocs)
-      if (!ArgLoc.isRegLoc())
-        return false;
-  }
-
   // Check that the call results are passed in the same way.
   if (!CCState::resultsCompatible(CalleeCC, CallerCC, MF, C, Ins,
-                                  CCAssignFnForCall(CalleeCC, isVarArg),
-                                  CCAssignFnForCall(CallerCC, isVarArg)))
+                                  CCAssignFnForCall(CalleeCC, IsVarArg),
+                                  CCAssignFnForCall(CallerCC, IsVarArg)))
     return false;
   // The callee has to preserve all registers the caller needs to preserve.
   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
@@ -6021,9 +6053,22 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
     return true;
 
   SmallVector<CCValAssign, 16> ArgLocs;
-  CCState CCInfo(CalleeCC, isVarArg, MF, ArgLocs, C);
+  CCState CCInfo(CalleeCC, IsVarArg, MF, ArgLocs, C);
+
+  analyzeCallOperands(*this, Subtarget, CLI, CCInfo);
+
+  if (IsVarArg && !(CLI.CB && CLI.CB->isMustTailCall())) {
+    // When we are musttail, additional checks have been done and we can safely ignore this check
+    // At least two cases here: if caller is fastcc then we can't have any
+    // memory arguments (we'd be expected to clean up the stack afterwards). If
+    // caller is C then we could potentially use its argument area.
 
-  CCInfo.AnalyzeCallOperands(Outs, CCAssignFnForCall(CalleeCC, isVarArg));
+    // FIXME: for now we take the most conservative of these in both cases:
+    // disallow all variadic memory operands.
+    for (const CCValAssign &ArgLoc : ArgLocs)
+      if (!ArgLoc.isRegLoc())
+        return false;
+  }
 
   const AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
 
@@ -6114,7 +6159,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   SDValue Chain = CLI.Chain;
   SDValue Callee = CLI.Callee;
   bool &IsTailCall = CLI.IsTailCall;
-  CallingConv::ID CallConv = CLI.CallConv;
+  CallingConv::ID &CallConv = CLI.CallConv;
   bool IsVarArg = CLI.IsVarArg;
 
   MachineFunction &MF = DAG.getMachineFunction();
@@ -6124,7 +6169,6 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
   bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
   bool IsSibCall = false;
-  bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CallConv);
 
   // Check callee args/returns for SVE registers and set calling convention
   // accordingly.
@@ -6142,8 +6186,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   if (IsTailCall) {
     // Check if it's really possible to do a tail call.
-    IsTailCall = isEligibleForTailCallOptimization(
-        Callee, CallConv, IsVarArg, Outs, OutVals, Ins, DAG);
+    IsTailCall = isEligibleForTailCallOptimization(CLI);
 
     // A sibling call is one where we're under the usual C ABI and not planning
     // to change that but can still do a tail call:
@@ -6164,56 +6207,17 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
 
   if (IsVarArg) {
-    // Handle fixed and variable vector arguments differently.
-    // Variable vector arguments always go into memory.
     unsigned NumArgs = Outs.size();
 
     for (unsigned i = 0; i != NumArgs; ++i) {
-      MVT ArgVT = Outs[i].VT;
-      if (!Outs[i].IsFixed && ArgVT.isScalableVector())
+      if (!Outs[i].IsFixed && Outs[i].VT.isScalableVector())
         report_fatal_error("Passing SVE types to variadic functions is "
                            "currently not supported");
-
-      ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
-      bool UseVarArgCC = !Outs[i].IsFixed;
-      // On Windows, the fixed arguments in a vararg call are passed in GPRs
-      // too, so use the vararg CC to force them to integer registers.
-      if (IsCalleeWin64)
-        UseVarArgCC = true;
-      CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC);
-      bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo);
-      assert(!Res && "Call operand has unhandled type");
-      (void)Res;
-    }
-  } else {
-    // At this point, Outs[].VT may already be promoted to i32. To correctly
-    // handle passing i8 as i8 instead of i32 on stack, we pass in both i32 and
-    // i8 to CC_AArch64_AAPCS with i32 being ValVT and i8 being LocVT.
-    // Since AnalyzeCallOperands uses Ins[].VT for both ValVT and LocVT, here
-    // we use a special version of AnalyzeCallOperands to pass in ValVT and
-    // LocVT.
-    unsigned NumArgs = Outs.size();
-    for (unsigned i = 0; i != NumArgs; ++i) {
-      MVT ValVT = Outs[i].VT;
-      // Get type of the original argument.
-      EVT ActualVT = getValueType(DAG.getDataLayout(),
-                                  CLI.getArgs()[Outs[i].OrigArgIndex].Ty,
-                                  /*AllowUnknown*/ true);
-      MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : ValVT;
-      ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
-      // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
-      if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
-        ValVT = MVT::i8;
-      else if (ActualMVT == MVT::i16)
-        ValVT = MVT::i16;
-
-      CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, /*IsVarArg=*/false);
-      bool Res = AssignFn(i, ValVT, ValVT, CCValAssign::Full, ArgFlags, CCInfo);
-      assert(!Res && "Call operand has unhandled type");
-      (void)Res;
     }
   }
 
+  analyzeCallOperands(*this, Subtarget, CLI, CCInfo);
+
   // Get a count of how many bytes are to be pushed on the stack.
   unsigned NumBytes = CCInfo.getNextStackOffset();
 
index 0d2df10..eb88572 100644 (file)
@@ -889,11 +889,8 @@ private:
   SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
 
-  bool isEligibleForTailCallOptimization(
-      SDValue Callee, CallingConv::ID CalleeCC, bool isVarArg,
-      const SmallVectorImpl<ISD::OutputArg> &Outs,
-      const SmallVectorImpl<SDValue> &OutVals,
-      const SmallVectorImpl<ISD::InputArg> &Ins, SelectionDAG &DAG) const;
+  bool
+  isEligibleForTailCallOptimization(const CallLoweringInfo &CLI) const;
 
   /// Finds the incoming stack arguments which overlap the given fixed stack
   /// object and incorporates their load into the current chain. This prevents
diff --git a/llvm/test/CodeGen/AArch64/darwinpcs-tail.ll b/llvm/test/CodeGen/AArch64/darwinpcs-tail.ll
new file mode 100644 (file)
index 0000000..9d13ed6
--- /dev/null
@@ -0,0 +1,36 @@
+; With Darwin PCS, non-virtual thunks generated are generated with musttail
+; and are expected to build
+; In general Darwin PCS should be tail optimized
+; RUN: llc -mtriple=arm64-apple-ios5.0.0 < %s | FileCheck %s
+
+; CHECK-LABEL: __ZThn16_N1C3addEPKcz:
+; CHECK:       b __ZN1C3addEPKcz
+; CHECK-LABEL: _tailTest:
+; CHECK:       b __ZN1C3addEPKcz
+; CHECK-LABEL: __ZThn8_N1C1fEiiiiiiiiiz:
+; CHECK:       ldr     w9, [sp, #4]
+; CHECK:       str     w9, [sp, #4]
+; CHECK:       b __ZN1C1fEiiiiiiiiiz
+
+%class.C = type { %class.A.base, [4 x i8], %class.B.base, [4 x i8] }
+%class.A.base = type <{ i32 (...)**, i32 }>
+%class.B.base = type <{ i32 (...)**, i32 }>
+
+declare void @_ZN1C3addEPKcz(%class.C*, i8*, ...) unnamed_addr #0 align 2
+
+define void @_ZThn16_N1C3addEPKcz(%class.C* %0, i8* %1, ...) unnamed_addr #0 align 2 {
+  musttail call void (%class.C*, i8*, ...) @_ZN1C3addEPKcz(%class.C* noundef nonnull align 8 dereferenceable(28) undef, i8* noundef %1, ...)
+  ret void
+}
+
+define void @tailTest(%class.C* %0, i8* %1, ...) unnamed_addr #0 align 2 {
+  tail call void (%class.C*, i8*, ...) @_ZN1C3addEPKcz(%class.C* noundef nonnull align 8 dereferenceable(28) undef, i8* noundef %1)
+  ret void
+}
+
+declare void @_ZN1C1fEiiiiiiiiiz(%class.C* %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i32 noundef %9, ...) unnamed_addr #1 align 2
+
+define void @_ZThn8_N1C1fEiiiiiiiiiz(%class.C* %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i32 noundef %9, ...) unnamed_addr #1 align 2 {
+  musttail call void (%class.C*, i32, i32, i32, i32, i32, i32, i32, i32, i32, ...) @_ZN1C1fEiiiiiiiiiz(%class.C* nonnull align 8 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 noundef %9, ...)
+  ret void
+}