[AArch64][GlobalISel] Add support for sret demotion.
authorAmara Emerson <amara@apple.com>
Mon, 4 Jul 2022 23:43:56 +0000 (16:43 -0700)
committerAmara Emerson <amara@apple.com>
Tue, 5 Jul 2022 22:23:47 +0000 (15:23 -0700)
To do this, we need to implement a target hook and make a minor change to the
call lowering to demote arguments to sret if they can't be handled by the
calling conventions.

Fixes issue 56295

Differential Revision: https://reviews.llvm.org/D129098

llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp
llvm/lib/Target/AArch64/GISel/AArch64CallLowering.h
llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-sret-demotion.ll [new file with mode: 0644]

index 89e1d85..aaef363 100644 (file)
@@ -21,6 +21,7 @@
 #include "llvm/Analysis/ObjCARCUtil.h"
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/CallingConvLower.h"
+#include "llvm/CodeGen/FunctionLoweringInfo.h"
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/GlobalISel/Utils.h"
 #include "llvm/CodeGen/LowLevelType.h"
@@ -354,7 +355,9 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
          "Return value without a vreg");
 
   bool Success = true;
-  if (!VRegs.empty()) {
+  if (!FLI.CanLowerReturn) {
+    insertSRetStores(MIRBuilder, Val->getType(), VRegs, FLI.DemoteRegister);
+  } else if (!VRegs.empty()) {
     MachineFunction &MF = MIRBuilder.getMF();
     const Function &F = MF.getFunction();
     const AArch64Subtarget &Subtarget = MF.getSubtarget<AArch64Subtarget>();
@@ -464,6 +467,18 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
   return Success;
 }
 
+bool AArch64CallLowering::canLowerReturn(MachineFunction &MF,
+                                         CallingConv::ID CallConv,
+                                         SmallVectorImpl<BaseArgInfo> &Outs,
+                                         bool IsVarArg) const {
+  SmallVector<CCValAssign, 16> ArgLocs;
+  const auto &TLI = *getTLI<AArch64TargetLowering>();
+  CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs,
+                 MF.getFunction().getContext());
+
+  return checkReturn(CCInfo, Outs, TLI.CCAssignFnForReturn(CallConv));
+}
+
 /// Helper function to compute forwarded registers for musttail calls. Computes
 /// the forwarded registers, sets MBB liveness, and emits COPY instructions that
 /// can be used to save + restore registers later.
@@ -533,6 +548,12 @@ bool AArch64CallLowering::lowerFormalArguments(
 
   SmallVector<ArgInfo, 8> SplitArgs;
   SmallVector<std::pair<Register, Register>> BoolArgs;
+
+  // Insert the hidden sret parameter if the return value won't fit in the
+  // return registers.
+  if (!FLI.CanLowerReturn)
+    insertSRetIncomingArgument(F, SplitArgs, FLI.DemoteRegister, MRI, DL);
+
   unsigned i = 0;
   for (auto &Arg : F.args()) {
     if (DL.getTypeStoreSize(Arg.getType()).isZero())
@@ -1194,7 +1215,7 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   // Finally we can copy the returned value back into its virtual-register. In
   // symmetry with the arguments, the physical register must be an
   // implicit-define of the call instruction.
-  if (!Info.OrigRet.Ty->isVoidTy()) {
+  if (Info.CanLowerReturn  && !Info.OrigRet.Ty->isVoidTy()) {
     CCAssignFn *RetAssignFn = TLI.CCAssignFnForReturn(Info.CallConv);
     CallReturnHandler Handler(MIRBuilder, MRI, MIB);
     bool UsingReturnedArg =
@@ -1226,6 +1247,10 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
       .addImm(Assigner.StackOffset)
       .addImm(CalleePopBytes);
 
+  if (!Info.CanLowerReturn) {
+    insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs,
+                    Info.DemoteRegister, Info.DemoteStackIndex);
+  }
   return true;
 }
 
index aafb1d1..cbdf77f 100644 (file)
@@ -35,6 +35,10 @@ public:
                    ArrayRef<Register> VRegs, FunctionLoweringInfo &FLI,
                    Register SwiftErrorVReg) const override;
 
+  bool canLowerReturn(MachineFunction &MF, CallingConv::ID CallConv,
+                      SmallVectorImpl<BaseArgInfo> &Outs,
+                      bool IsVarArg) const override;
+
   bool fallBackToDAGISel(const MachineFunction &MF) const override;
 
   bool lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F,
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-sret-demotion.ll b/llvm/test/CodeGen/AArch64/GlobalISel/call-lowering-sret-demotion.ll
new file mode 100644 (file)
index 0000000..a8520af
--- /dev/null
@@ -0,0 +1,119 @@
+; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+; RUN: llc -mtriple=aarch64 -global-isel -stop-after=irtranslator -verify-machineinstrs -o - %s | FileCheck %s
+
+
+define [9 x i64] @callee_sret_demotion() {
+  ; CHECK-LABEL: name: callee_sret_demotion
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK-NEXT:   liveins: $x8
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[COPY:%[0-9]+]]:_(p0) = COPY $x8
+  ; CHECK-NEXT:   [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 0
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[COPY]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 8
+  ; CHECK-NEXT:   [[PTR_ADD:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C1]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C2:%[0-9]+]]:_(s64) = G_CONSTANT i64 16
+  ; CHECK-NEXT:   [[PTR_ADD1:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C2]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD1]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C3:%[0-9]+]]:_(s64) = G_CONSTANT i64 24
+  ; CHECK-NEXT:   [[PTR_ADD2:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C3]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD2]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C4:%[0-9]+]]:_(s64) = G_CONSTANT i64 32
+  ; CHECK-NEXT:   [[PTR_ADD3:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C4]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD3]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C5:%[0-9]+]]:_(s64) = G_CONSTANT i64 40
+  ; CHECK-NEXT:   [[PTR_ADD4:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C5]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD4]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C6:%[0-9]+]]:_(s64) = G_CONSTANT i64 48
+  ; CHECK-NEXT:   [[PTR_ADD5:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C6]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD5]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C7:%[0-9]+]]:_(s64) = G_CONSTANT i64 56
+  ; CHECK-NEXT:   [[PTR_ADD6:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C7]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD6]](p0) :: (store (s64))
+  ; CHECK-NEXT:   [[C8:%[0-9]+]]:_(s64) = G_CONSTANT i64 64
+  ; CHECK-NEXT:   [[PTR_ADD7:%[0-9]+]]:_(p0) = G_PTR_ADD [[COPY]], [[C8]](s64)
+  ; CHECK-NEXT:   G_STORE [[C]](s64), [[PTR_ADD7]](p0) :: (store (s64))
+  ; CHECK-NEXT:   RET_ReallyLR
+  ret [9 x i64] zeroinitializer
+}
+
+define i64 @caller() {
+  ; CHECK-LABEL: name: caller
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK-NEXT:   [[FRAME_INDEX:%[0-9]+]]:_(p0) = G_FRAME_INDEX %stack.0
+  ; CHECK-NEXT:   ADJCALLSTACKDOWN 0, 0, implicit-def $sp, implicit $sp
+  ; CHECK-NEXT:   $x8 = COPY [[FRAME_INDEX]](p0)
+  ; CHECK-NEXT:   BL @callee_sret_demotion, csr_aarch64_aapcs, implicit-def $lr, implicit $sp, implicit $x8
+  ; CHECK-NEXT:   ADJCALLSTACKUP 0, 0, implicit-def $sp, implicit $sp
+  ; CHECK-NEXT:   [[LOAD:%[0-9]+]]:_(s64) = G_LOAD [[FRAME_INDEX]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 8
+  ; CHECK-NEXT:   [[PTR_ADD:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C]](s64)
+  ; CHECK-NEXT:   [[LOAD1:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 16
+  ; CHECK-NEXT:   [[PTR_ADD1:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C1]](s64)
+  ; CHECK-NEXT:   [[LOAD2:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD1]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C2:%[0-9]+]]:_(s64) = G_CONSTANT i64 24
+  ; CHECK-NEXT:   [[PTR_ADD2:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C2]](s64)
+  ; CHECK-NEXT:   [[LOAD3:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD2]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C3:%[0-9]+]]:_(s64) = G_CONSTANT i64 32
+  ; CHECK-NEXT:   [[PTR_ADD3:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C3]](s64)
+  ; CHECK-NEXT:   [[LOAD4:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD3]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C4:%[0-9]+]]:_(s64) = G_CONSTANT i64 40
+  ; CHECK-NEXT:   [[PTR_ADD4:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C4]](s64)
+  ; CHECK-NEXT:   [[LOAD5:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD4]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C5:%[0-9]+]]:_(s64) = G_CONSTANT i64 48
+  ; CHECK-NEXT:   [[PTR_ADD5:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C5]](s64)
+  ; CHECK-NEXT:   [[LOAD6:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD5]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C6:%[0-9]+]]:_(s64) = G_CONSTANT i64 56
+  ; CHECK-NEXT:   [[PTR_ADD6:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C6]](s64)
+  ; CHECK-NEXT:   [[LOAD7:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD6]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C7:%[0-9]+]]:_(s64) = G_CONSTANT i64 64
+  ; CHECK-NEXT:   [[PTR_ADD7:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C7]](s64)
+  ; CHECK-NEXT:   [[LOAD8:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD7]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   $x0 = COPY [[LOAD4]](s64)
+  ; CHECK-NEXT:   RET_ReallyLR implicit $x0
+  %res = call [9 x i64] @callee_sret_demotion()
+  %val = extractvalue [9 x i64] %res, 4
+  ret i64 %val
+}
+
+define i64 @caller_tail() {
+  ; CHECK-LABEL: name: caller_tail
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK-NEXT:   [[FRAME_INDEX:%[0-9]+]]:_(p0) = G_FRAME_INDEX %stack.0
+  ; CHECK-NEXT:   ADJCALLSTACKDOWN 0, 0, implicit-def $sp, implicit $sp
+  ; CHECK-NEXT:   $x8 = COPY [[FRAME_INDEX]](p0)
+  ; CHECK-NEXT:   BL @callee_sret_demotion, csr_aarch64_aapcs, implicit-def $lr, implicit $sp, implicit $x8
+  ; CHECK-NEXT:   ADJCALLSTACKUP 0, 0, implicit-def $sp, implicit $sp
+  ; CHECK-NEXT:   [[LOAD:%[0-9]+]]:_(s64) = G_LOAD [[FRAME_INDEX]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 8
+  ; CHECK-NEXT:   [[PTR_ADD:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C]](s64)
+  ; CHECK-NEXT:   [[LOAD1:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 16
+  ; CHECK-NEXT:   [[PTR_ADD1:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C1]](s64)
+  ; CHECK-NEXT:   [[LOAD2:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD1]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C2:%[0-9]+]]:_(s64) = G_CONSTANT i64 24
+  ; CHECK-NEXT:   [[PTR_ADD2:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C2]](s64)
+  ; CHECK-NEXT:   [[LOAD3:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD2]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C3:%[0-9]+]]:_(s64) = G_CONSTANT i64 32
+  ; CHECK-NEXT:   [[PTR_ADD3:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C3]](s64)
+  ; CHECK-NEXT:   [[LOAD4:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD3]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C4:%[0-9]+]]:_(s64) = G_CONSTANT i64 40
+  ; CHECK-NEXT:   [[PTR_ADD4:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C4]](s64)
+  ; CHECK-NEXT:   [[LOAD5:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD4]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C5:%[0-9]+]]:_(s64) = G_CONSTANT i64 48
+  ; CHECK-NEXT:   [[PTR_ADD5:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C5]](s64)
+  ; CHECK-NEXT:   [[LOAD6:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD5]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C6:%[0-9]+]]:_(s64) = G_CONSTANT i64 56
+  ; CHECK-NEXT:   [[PTR_ADD6:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C6]](s64)
+  ; CHECK-NEXT:   [[LOAD7:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD6]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   [[C7:%[0-9]+]]:_(s64) = G_CONSTANT i64 64
+  ; CHECK-NEXT:   [[PTR_ADD7:%[0-9]+]]:_(p0) = G_PTR_ADD [[FRAME_INDEX]], [[C7]](s64)
+  ; CHECK-NEXT:   [[LOAD8:%[0-9]+]]:_(s64) = G_LOAD [[PTR_ADD7]](p0) :: (load (s64) from %stack.0)
+  ; CHECK-NEXT:   $x0 = COPY [[LOAD4]](s64)
+  ; CHECK-NEXT:   RET_ReallyLR implicit $x0
+  %res = tail call [9 x i64] @callee_sret_demotion()
+  %val = extractvalue [9 x i64] %res, 4
+  ret i64 %val
+}