[GlobalISel][AArch64] Allow CallLowering to handle types which are normally
authorAmara Emerson <aemerson@apple.com>
Tue, 9 Apr 2019 21:22:33 +0000 (21:22 +0000)
committerAmara Emerson <aemerson@apple.com>
Tue, 9 Apr 2019 21:22:33 +0000 (21:22 +0000)
required to be passed as different register types. E.g. <2 x i16> may need to
be passed as a larger <2 x i32> type, so formal arg lowering needs to be able
truncate it back. Likewise, when dealing with returns of these types, they need
to be widened in the appropriate way back.

Differential Revision: https://reviews.llvm.org/D60425

llvm-svn: 358032

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
llvm/lib/Target/AArch64/AArch64CallLowering.cpp
llvm/lib/Target/ARM/ARMCallLowering.cpp
llvm/lib/Target/X86/X86CallLowering.cpp
llvm/test/CodeGen/AArch64/GlobalISel/ret-vec-promote.ll [new file with mode: 0644]
llvm/test/CodeGen/AArch64/GlobalISel/vec-s16-param.ll [new file with mode: 0644]
llvm/test/CodeGen/ARM/GlobalISel/arm-unsupported.ll

index 9b72b70..af40d4b 100644 (file)
@@ -65,6 +65,10 @@ public:
 
     virtual ~ValueHandler() = default;
 
+    /// Returns true if the handler is dealing with formal arguments,
+    /// not with return values etc.
+    virtual bool isArgumentHandler() const { return false; }
+
     /// Materialize a VReg containing the address of the specified
     /// stack-based object. This is either based on a FrameIndex or
     /// direct SP manipulation, depending on the context. \p MPO
index a7c2a2e..47fdeed 100644 (file)
@@ -20,6 +20,8 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 
+#define DEBUG_TYPE "call-lowering"
+
 using namespace llvm;
 
 void CallLowering::anchor() {}
@@ -121,8 +123,15 @@ bool CallLowering::handleAssignments(MachineIRBuilder &MIRBuilder,
   unsigned NumArgs = Args.size();
   for (unsigned i = 0; i != NumArgs; ++i) {
     MVT CurVT = MVT::getVT(Args[i].Ty);
-    if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo))
-      return false;
+    if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo)) {
+      // Try to use the register type if we couldn't assign the VT.
+      if (!Handler.isArgumentHandler())
+        return false; 
+      CurVT = TLI->getRegisterTypeForCallingConv(
+          F.getContext(), F.getCallingConv(), EVT(CurVT));
+      if (Handler.assignArg(i, CurVT, CurVT, CCValAssign::Full, Args[i], CCInfo))
+        return false;
+    }
   }
 
   for (unsigned i = 0, e = Args.size(), j = 0; i != e; ++i, ++j) {
@@ -136,12 +145,39 @@ bool CallLowering::handleAssignments(MachineIRBuilder &MIRBuilder,
       continue;
     }
 
-    if (VA.isRegLoc())
-      Handler.assignValueToReg(Args[i].Reg, VA.getLocReg(), VA);
-    else if (VA.isMemLoc()) {
-      unsigned Size = VA.getValVT() == MVT::iPTR
-                          ? DL.getPointerSize()
-                          : alignTo(VA.getValVT().getSizeInBits(), 8) / 8;
+    if (VA.isRegLoc()) {
+      MVT OrigVT = MVT::getVT(Args[i].Ty);
+      MVT VAVT = VA.getValVT();
+      if (Handler.isArgumentHandler() && VAVT != OrigVT) {
+        if (VAVT.getSizeInBits() < OrigVT.getSizeInBits())
+          return false; // Can't handle this type of arg yet.
+        const LLT VATy(VAVT);
+        unsigned NewReg =
+            MIRBuilder.getMRI()->createGenericVirtualRegister(VATy);
+        Handler.assignValueToReg(NewReg, VA.getLocReg(), VA);
+        // If it's a vector type, we either need to truncate the elements
+        // or do an unmerge to get the lower block of elements.
+        if (VATy.isVector() &&
+            VATy.getNumElements() > OrigVT.getVectorNumElements()) {
+          const LLT OrigTy(OrigVT);
+          // Just handle the case where the VA type is 2 * original type.
+          if (VATy.getNumElements() != OrigVT.getVectorNumElements() * 2) {
+            LLVM_DEBUG(dbgs()
+                       << "Incoming promoted vector arg has too many elts");
+            return false;
+          }
+          auto Unmerge = MIRBuilder.buildUnmerge({OrigTy, OrigTy}, {NewReg});
+          MIRBuilder.buildCopy(Args[i].Reg, Unmerge.getReg(0));
+        } else {
+          MIRBuilder.buildTrunc(Args[i].Reg, {NewReg}).getReg(0);
+        }
+      } else {
+        Handler.assignValueToReg(Args[i].Reg, VA.getLocReg(), VA);
+      }
+    } else if (VA.isMemLoc()) {
+      MVT VT = MVT::getVT(Args[i].Ty);
+      unsigned Size = VT == MVT::iPTR ? DL.getPointerSize()
+                                      : alignTo(VT.getSizeInBits(), 8) / 8;
       unsigned Offset = VA.getLocMemOffset();
       MachinePointerInfo MPO;
       unsigned StackAddr = Handler.getStackAddress(Size, Offset, MPO);
@@ -157,6 +193,8 @@ bool CallLowering::handleAssignments(MachineIRBuilder &MIRBuilder,
 unsigned CallLowering::ValueHandler::extendRegister(unsigned ValReg,
                                                     CCValAssign &VA) {
   LLT LocTy{VA.getLocVT()};
+  if (LocTy.getSizeInBits() == MRI.getType(ValReg).getSizeInBits())
+    return ValReg;
   switch (VA.getLocInfo()) {
   default: break;
   case CCValAssign::Full:
index 8a00a3f..83054ee 100644 (file)
@@ -44,6 +44,8 @@
 #include <cstdint>
 #include <iterator>
 
+#define DEBUG_TYPE "aarch64-call-lowering"
+
 using namespace llvm;
 
 AArch64CallLowering::AArch64CallLowering(const AArch64TargetLowering &TLI)
@@ -97,6 +99,8 @@ struct IncomingArgHandler : public CallLowering::ValueHandler {
   /// (it's an implicit-def of the BL).
   virtual void markPhysRegUsed(unsigned PhysReg) = 0;
 
+  bool isArgumentHandler() const override { return true; }
+
   uint64_t StackUsed;
 };
 
@@ -250,18 +254,63 @@ bool AArch64CallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
            "For each split Type there should be exactly one VReg.");
 
     SmallVector<ArgInfo, 8> SplitArgs;
+    CallingConv::ID CC = F.getCallingConv();
+
     for (unsigned i = 0; i < SplitEVTs.size(); ++i) {
-      // We zero-extend i1s to i8.
-      unsigned CurVReg = VRegs[i];
-      if (MRI.getType(VRegs[i]).getSizeInBits() == 1) {
-        CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg)
-                       ->getOperand(0)
-                       .getReg();
+      if (TLI.getNumRegistersForCallingConv(Ctx, CC, SplitEVTs[i]) > 1) {
+        LLVM_DEBUG(dbgs() << "Can't handle extended arg types which need split");
+        return false;
       }
 
+      unsigned CurVReg = VRegs[i];
       ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVTs[i].getTypeForEVT(Ctx)};
       setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, F);
-      splitToValueTypes(CurArgInfo, SplitArgs, DL, MRI, F.getCallingConv(),
+
+      // i1 is a special case because SDAG i1 true is naturally zero extended
+      // when widened using ANYEXT. We need to do it explicitly here.
+      if (MRI.getType(CurVReg).getSizeInBits() == 1) {
+        CurVReg = MIRBuilder.buildZExt(LLT::scalar(8), CurVReg).getReg(0);
+      } else {
+        // Some types will need extending as specified by the CC.
+        MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CC, SplitEVTs[i]);
+        if (EVT(NewVT) != SplitEVTs[i]) {
+          unsigned ExtendOp = TargetOpcode::G_ANYEXT;
+          if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
+                                             Attribute::SExt))
+            ExtendOp = TargetOpcode::G_SEXT;
+          else if (F.getAttributes().hasAttribute(AttributeList::ReturnIndex,
+                                                  Attribute::ZExt))
+            ExtendOp = TargetOpcode::G_ZEXT;
+
+          LLT NewLLT(NewVT);
+          LLT OldLLT(MVT::getVT(CurArgInfo.Ty));
+          CurArgInfo.Ty = EVT(NewVT).getTypeForEVT(Ctx);
+          // Instead of an extend, we might have a vector type which needs
+          // padding with more elements, e.g. <2 x half> -> <4 x half>
+          if (NewVT.isVector() &&
+              NewLLT.getNumElements() > OldLLT.getNumElements()) {
+            // We don't handle VA types which are not exactly twice the size,
+            // but can easily be done in future.
+            if (NewLLT.getNumElements() != OldLLT.getNumElements() * 2) {
+              LLVM_DEBUG(dbgs() << "Outgoing vector ret has too many elts");
+              return false;
+            }
+            auto Undef = MIRBuilder.buildUndef({OldLLT});
+            CurVReg =
+                MIRBuilder.buildMerge({NewLLT}, {CurVReg, Undef.getReg(0)})
+                    .getReg(0);
+          } else {
+            CurVReg =
+                MIRBuilder.buildInstr(ExtendOp, {NewLLT}, {CurVReg}).getReg(0);
+          }
+        }
+      }
+      if (CurVReg != CurArgInfo.Reg) {
+        CurArgInfo.Reg = CurVReg;
+        // Reset the arg flags after modifying CurVReg.
+        setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, F);
+      }
+     splitToValueTypes(CurArgInfo, SplitArgs, DL, MRI, CC,
                         [&](unsigned Reg, uint64_t Offset) {
                           MIRBuilder.buildExtract(Reg, CurVReg, Offset);
                         });
index def7c5c..b70d55f 100644 (file)
@@ -301,6 +301,8 @@ struct IncomingValueHandler : public CallLowering::ValueHandler {
                        CCAssignFn AssignFn)
       : ValueHandler(MIRBuilder, MRI, AssignFn) {}
 
+  bool isArgumentHandler() const override { return true; }
+
   unsigned getStackAddress(uint64_t Size, int64_t Offset,
                            MachinePointerInfo &MPO) override {
     assert((Size == 1 || Size == 2 || Size == 4 || Size == 8) &&
index 048e4ca..5a623db 100644 (file)
@@ -228,6 +228,8 @@ struct IncomingValueHandler : public CallLowering::ValueHandler {
       : ValueHandler(MIRBuilder, MRI, AssignFn),
         DL(MIRBuilder.getMF().getDataLayout()) {}
 
+  bool isArgumentHandler() const override { return true; }
+
   unsigned getStackAddress(uint64_t Size, int64_t Offset,
                            MachinePointerInfo &MPO) override {
     auto &MFI = MIRBuilder.getMF().getFrameInfo();
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/ret-vec-promote.ll b/llvm/test/CodeGen/AArch64/GlobalISel/ret-vec-promote.ll
new file mode 100644 (file)
index 0000000..2d39203
--- /dev/null
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -O0 -global-isel -stop-after=irtranslator -o - %s | FileCheck %s
+
+; Tests vectors of i1 types can appropriately extended first before return handles it.
+define <4 x i1> @ret_v4i1(<4 x i1> *%v) {
+  ; CHECK-LABEL: name: ret_v4i1
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK:   liveins: $x0
+  ; CHECK:   [[COPY:%[0-9]+]]:_(p0) = COPY $x0
+  ; CHECK:   [[LOAD:%[0-9]+]]:_(<4 x s1>) = G_LOAD [[COPY]](p0) :: (load 1 from %ir.v, align 4)
+  ; CHECK:   [[ANYEXT:%[0-9]+]]:_(<4 x s16>) = G_ANYEXT [[LOAD]](<4 x s1>)
+  ; CHECK:   $d0 = COPY [[ANYEXT]](<4 x s16>)
+  ; CHECK:   RET_ReallyLR implicit $d0
+  %v2 = load <4 x i1>, <4 x i1> *%v
+  ret <4 x i1> %v2
+}
diff --git a/llvm/test/CodeGen/AArch64/GlobalISel/vec-s16-param.ll b/llvm/test/CodeGen/AArch64/GlobalISel/vec-s16-param.ll
new file mode 100644 (file)
index 0000000..f8319a9
--- /dev/null
@@ -0,0 +1,28 @@
+; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -O0 -stop-after=irtranslator -verify-machineinstrs -o - %s | FileCheck %s
+
+define <2 x half> @f16_vec_param(<2 x half> %v) {
+  ; CHECK-LABEL: name: f16_vec_param
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK:   liveins: $d0
+  ; CHECK:   [[COPY:%[0-9]+]]:_(<4 x s16>) = COPY $d0
+  ; CHECK:   [[UV:%[0-9]+]]:_(<2 x s16>), [[UV1:%[0-9]+]]:_(<2 x s16>) = G_UNMERGE_VALUES [[COPY]](<4 x s16>)
+  ; CHECK:   [[COPY1:%[0-9]+]]:_(<2 x s16>) = COPY [[UV]](<2 x s16>)
+  ; CHECK:   [[DEF:%[0-9]+]]:_(<2 x s16>) = G_IMPLICIT_DEF
+  ; CHECK:   [[CONCAT_VECTORS:%[0-9]+]]:_(<4 x s16>) = G_CONCAT_VECTORS [[COPY1]](<2 x s16>), [[DEF]](<2 x s16>)
+  ; CHECK:   $d0 = COPY [[CONCAT_VECTORS]](<4 x s16>)
+  ; CHECK:   RET_ReallyLR implicit $d0
+  ret <2 x half> %v
+}
+
+define <2 x i16> @i16_vec_param(<2 x i16> %v) {
+  ; CHECK-LABEL: name: i16_vec_param
+  ; CHECK: bb.1 (%ir-block.0):
+  ; CHECK:   liveins: $d0
+  ; CHECK:   [[COPY:%[0-9]+]]:_(<2 x s32>) = COPY $d0
+  ; CHECK:   [[TRUNC:%[0-9]+]]:_(<2 x s16>) = G_TRUNC [[COPY]](<2 x s32>)
+  ; CHECK:   [[ANYEXT:%[0-9]+]]:_(<2 x s32>) = G_ANYEXT [[TRUNC]](<2 x s16>)
+  ; CHECK:   $d0 = COPY [[ANYEXT]](<2 x s32>)
+  ; CHECK:   RET_ReallyLR implicit $d0
+  ret <2 x i16> %v
+}
index e7df312..8a23b5d 100644 (file)
@@ -43,7 +43,7 @@ define i17 @test_funny_ints(i17 %a, i17 %b) {
 }
 
 define half @test_half(half %a, half %b) {
-; CHECK: remark: {{.*}} unable to lower arguments: half (half, half)* (in function: test_half)
+; CHECK: remark: {{.*}} unable to translate instruction: ret: '  ret half %res' (in function: test_half)
 ; CHECK-LABEL: warning: Instruction selection used fallback path for test_half
   %res = fadd half %a, %b
   ret half %res