From da4ef9b4c86d0002f57f6cb9c4bfb6c435f2bef6 Mon Sep 17 00:00:00 2001
From: Kerry McLaughlin
Date: Mon, 2 Sep 2019 16:12:31 +0000
Subject: [PATCH] [SVE][Inline-Asm] Support for SVE asm operands

Summary:
Adds the following inline asm constraints for SVE:
- w: SVE vector register with full range, Z0 to Z31.
- x: Restricted to registers Z0 to Z15 inclusive.
- y: Restricted to registers Z0 to Z7 inclusive.
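
For illustration only (not part of the original commit message): a
minimal sketch of how these constraints might be used from C, assuming
a toolchain with SVE ACLE support (arm_sve.h); the function name and
the build flags shown in the comment are hypothetical.

    // Hypothetical sketch; compile with something like:
    //   clang -O2 -march=armv8-a+sve
    #include <arm_sve.h>

    // "w" allows any of z0-z31 for a scalable vector operand, while
    // "y" restricts its operand to z0-z7 ("x" would allow z0-z15).
    svint8_t add_bytes(svint8_t a, svint8_t b) {
      svint8_t out;
      asm("add %0.b, %1.b, %2.b" : "=w"(out) : "w"(a), "y"(b));
      return out;
    }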

This change also adds the "z" operand modifier, which prints a register
using its SVE register name.
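
Again purely as illustration, under the same toolchain assumptions: a C
sketch of the modifier. The ${1:z} template spelling used in the tests
below is written %z1 in GNU inline asm.

    #include <arm_neon.h>

    // Hypothetical sketch: "%z0" prints the FP/SIMD register that
    // holds operand 0 under its SVE name, emitting e.g.
    //   ldr z0, [x1]
    // (an SVE whole-register load, so an SVE-capable assembler and
    // CPU are assumed); without the modifier this operand would print
    // as v0, which does not assemble for this instruction.
    float32x4_t load_via_z(const float *src) {
      float32x4_t v;
      asm("ldr %z0, [%1]" : "=w"(v) : "r"(src) : "memory");
      return v;
    }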

Not all of the bitconvert patterns added by this patch are used, but
they have been included here for completeness.
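
For illustration, one source-level construct that can give rise to such
bitcasts (assuming, again, a toolchain with SVE ACLE support); the
function name is hypothetical.

    #include <arm_sve.h>

    // svreinterpret is expected to lower to an IR bitcast between
    // scalable vector types (here nxv16i8 -> nxv8i16), which the
    // bitconvert patterns added below fold away without emitting any
    // instruction.
    svint16_t as_halfwords(svint8_t x) {
      return svreinterpret_s16_s8(x);
    }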

Reviewers: t.p.northover, sdesmalen, rovka, momchil.velikov, rengolin,
cameron.mcinally, greened

Reviewed By: sdesmalen

Subscribers: javed.absar, tschuett, rkruppe, psnobl, cfe-commits, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D66302

llvm-svn: 370673
---
 llvm/docs/LangRef.rst                              |  5 ++-
 llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp      | 25 ++++++++---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp    | 11 +++++
 llvm/lib/Target/AArch64/AArch64InstrInfo.cpp       | 10 +++++
 llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td     | 50 ++++++++++++++++++++++
 .../CodeGen/AArch64/aarch64-sve-asm-negative.ll    | 12 ++++++
 llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll       | 44 +++++++++++++++++++
 llvm/test/CodeGen/AArch64/arm64-inline-asm.ll      |  2 +
 8 files changed, 150 insertions(+), 9 deletions(-)
 create mode 100644 llvm/test/CodeGen/AArch64/aarch64-sve-asm-negative.ll
 create mode 100644 llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll

diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst
index f35ceb8..ff45610 100644
--- a/llvm/docs/LangRef.rst
+++ b/llvm/docs/LangRef.rst
@@ -3811,8 +3811,9 @@ AArch64:
   offsets). (However, LLVM currently does this for the ``m`` constraint as
   well.)
 - ``r``: A 32 or 64-bit integer register (W* or X*).
-- ``w``: A 32, 64, or 128-bit floating-point/SIMD register.
-- ``x``: A lower 128-bit floating-point/SIMD register (``V0`` to ``V15``).
+- ``w``: A 32, 64, or 128-bit floating-point, SIMD, or SVE vector register.
+- ``x``: Like ``w``, but restricted to registers 0 to 15 inclusive.
+- ``y``: Like ``w``, but restricted to SVE vector registers Z0 to Z7 inclusive.


 AMDGPU:

diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 5d1662a..4a06de9 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -150,7 +150,7 @@ private:
   void printOperand(const MachineInstr *MI, unsigned OpNum, raw_ostream &O);
   bool printAsmMRegister(const MachineOperand &MO, char Mode, raw_ostream &O);
   bool printAsmRegInClass(const MachineOperand &MO,
-                          const TargetRegisterClass *RC, bool isVector,
+                          const TargetRegisterClass *RC, unsigned AltName,
                           raw_ostream &O);

   bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNum,
@@ -530,14 +530,13 @@ bool AArch64AsmPrinter::printAsmMRegister(const MachineOperand &MO, char Mode,
 // printing.
 bool AArch64AsmPrinter::printAsmRegInClass(const MachineOperand &MO,
                                            const TargetRegisterClass *RC,
-                                           bool isVector, raw_ostream &O) {
+                                           unsigned AltName, raw_ostream &O) {
   assert(MO.isReg() && "Should only get here with a register!");
   const TargetRegisterInfo *RI = STI->getRegisterInfo();
   Register Reg = MO.getReg();
   unsigned RegToPrint = RC->getRegister(RI->getEncodingValue(Reg));
   assert(RI->regsOverlap(RegToPrint, Reg));
-  O << AArch64InstPrinter::getRegisterName(
-      RegToPrint, isVector ? AArch64::vreg : AArch64::NoRegAltName);
+  O << AArch64InstPrinter::getRegisterName(RegToPrint, AltName);
   return false;
 }

@@ -573,6 +572,7 @@ bool AArch64AsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNum,
     case 's': // Print S register.
     case 'd': // Print D register.
     case 'q': // Print Q register.
+    case 'z': // Print Z register.
       if (MO.isReg()) {
         const TargetRegisterClass *RC;
         switch (ExtraCode[0]) {
@@ -591,10 +591,13 @@ bool AArch64AsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNum,
         case 'q':
           RC = &AArch64::FPR128RegClass;
           break;
+        case 'z':
+          RC = &AArch64::ZPRRegClass;
+          break;
         default:
           return true;
         }
-        return printAsmRegInClass(MO, RC, false /* vector */, O);
+        return printAsmRegInClass(MO, RC, AArch64::NoRegAltName, O);
       }
       printOperand(MI, OpNum, O);
       return false;
@@ -611,9 +614,17 @@ bool AArch64AsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNum,
         AArch64::GPR64allRegClass.contains(Reg))
       return printAsmMRegister(MO, 'x', O);

+    unsigned AltName = AArch64::NoRegAltName;
+    const TargetRegisterClass *RegClass;
+    if (AArch64::ZPRRegClass.contains(Reg)) {
+      RegClass = &AArch64::ZPRRegClass;
+    } else {
+      RegClass = &AArch64::FPR128RegClass;
+      AltName = AArch64::vreg;
+    }
+
     // If this is a b, h, s, d, or q register, print it as a v register.
-    return printAsmRegInClass(MO, &AArch64::FPR128RegClass, true /* vector */,
-                              O);
+    return printAsmRegInClass(MO, RegClass, AltName, O);
   }

   printOperand(MI, OpNum, O);

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index febc1ff..2757fd1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5748,6 +5748,7 @@ AArch64TargetLowering::getConstraintType(StringRef Constraint) const {
       break;
     case 'x':
     case 'w':
+    case 'y':
       return C_RegisterClass;
     // An address with a single base register. Due to the way we
     // currently handle addresses it is the same as 'r'.
@@ -5790,6 +5791,7 @@ AArch64TargetLowering::getSingleConstraintMatchWeight(
     break;
   case 'x':
   case 'w':
+  case 'y':
     if (type->isFloatingPointTy() || type->isVectorTy())
       weight = CW_Register;
     break;
@@ -5812,6 +5814,8 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
     case 'w':
       if (!Subtarget->hasFPARMv8())
         break;
+      if (VT.isScalableVector())
+        return std::make_pair(0U, &AArch64::ZPRRegClass);
       if (VT.getSizeInBits() == 16)
         return std::make_pair(0U, &AArch64::FPR16RegClass);
       if (VT.getSizeInBits() == 32)
@@ -5826,8 +5830,15 @@ AArch64TargetLowering::getRegForInlineAsmConstraint(
     case 'x':
       if (!Subtarget->hasFPARMv8())
         break;
+      if (VT.isScalableVector())
+        return std::make_pair(0U, &AArch64::ZPR_4bRegClass);
       if (VT.getSizeInBits() == 128)
         return std::make_pair(0U, &AArch64::FPR128_loRegClass);
+    case 'y':
+      if (!Subtarget->hasFPARMv8())
+        break;
+      if (VT.isScalableVector())
+        return std::make_pair(0U, &AArch64::ZPR_3bRegClass);
       break;
     }
   }

diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 1e089ff..a9f54a1 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -2484,6 +2484,16 @@ void AArch64InstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     return;
   }

+  // Copy a Z register by ORRing with itself.
+  if (AArch64::ZPRRegClass.contains(DestReg) &&
+      AArch64::ZPRRegClass.contains(SrcReg)) {
+    assert(Subtarget.hasSVE() && "Unexpected SVE register.");
+    BuildMI(MBB, I, DL, get(AArch64::ORR_ZZZ), DestReg)
+        .addReg(SrcReg)
+        .addReg(SrcReg, getKillRegState(KillSrc));
+    return;
+  }
+
   if (AArch64::GPR64spRegClass.contains(DestReg) &&
       (AArch64::GPR64spRegClass.contains(SrcReg) || SrcReg == AArch64::XZR)) {
     if (DestReg == AArch64::SP || SrcReg == AArch64::SP) {

diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 8e1ff99..7359ea3 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1020,6 +1020,56 @@ let Predicates = [HasSVE] in {
             (FCMGT_PPzZZ_S PPR32:$Zd, PPR3bAny:$Pg, ZPR32:$Zn, ZPR32:$Zm), 0>;
   def : InstAlias<"fcmlt $Zd, $Pg/z, $Zm, $Zn",
             (FCMGT_PPzZZ_D PPR64:$Zd, PPR3bAny:$Pg, ZPR64:$Zn, ZPR64:$Zm), 0>;
+
+  def : Pat<(nxv16i8 (bitconvert (nxv8i16 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv4i32 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv2i64 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv8f16 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv4f32 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+  def : Pat<(nxv16i8 (bitconvert (nxv2f64 ZPR:$src))), (nxv16i8 ZPR:$src)>;
+
+  def : Pat<(nxv8i16 (bitconvert (nxv16i8 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv4i32 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv2i64 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv8f16 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv4f32 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+  def : Pat<(nxv8i16 (bitconvert (nxv2f64 ZPR:$src))), (nxv8i16 ZPR:$src)>;
+
+  def : Pat<(nxv4i32 (bitconvert (nxv16i8 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv8i16 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv2i64 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv8f16 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv4f32 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+  def : Pat<(nxv4i32 (bitconvert (nxv2f64 ZPR:$src))), (nxv4i32 ZPR:$src)>;
+
+  def : Pat<(nxv2i64 (bitconvert (nxv16i8 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv8i16 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv4i32 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+  def : Pat<(nxv2i64 (bitconvert (nxv2f64 ZPR:$src))), (nxv2i64 ZPR:$src)>;
+
+  def : Pat<(nxv8f16 (bitconvert (nxv16i8 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv8i16 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv4i32 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv2i64 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv4f32 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+  def : Pat<(nxv8f16 (bitconvert (nxv2f64 ZPR:$src))), (nxv8f16 ZPR:$src)>;
+
+  def : Pat<(nxv4f32 (bitconvert (nxv16i8 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv8i16 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv4i32 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv2i64 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv8f16 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+  def : Pat<(nxv4f32 (bitconvert (nxv2f64 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+
+  def : Pat<(nxv2f64 (bitconvert (nxv16i8 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv8i16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv4i32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv2i64 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+  def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
+
 }

 let Predicates = [HasSVE2] in {

diff --git a/llvm/test/CodeGen/AArch64/aarch64-sve-asm-negative.ll b/llvm/test/CodeGen/AArch64/aarch64-sve-asm-negative.ll
new file mode 100644
index 0000000..ad483f4
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-sve-asm-negative.ll
@@ -0,0 +1,12 @@
+; RUN: not llc -mtriple aarch64-none-linux-gnu -mattr=+neon -o %t.s -filetype=asm %s 2>&1 | FileCheck %s
+
+; The 'y' constraint only applies to SVE vector registers (Z0-Z7).
+; The test below ensures that we get an appropriate error should the
+; constraint be used with a Neon register.
+
+; Function Attrs: nounwind readnone
+; CHECK: error: couldn't allocate input reg for constraint 'y'
+define <4 x i32> @test_neon(<4 x i32> %in1, <4 x i32> %in2) {
+  %1 = tail call <4 x i32> asm "add $0.4s, $1.4s, $2.4s", "=w,w,y"(<4 x i32> %in1, <4 x i32> %in2)
+  ret <4 x i32> %1
+}

diff --git a/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
new file mode 100644
index 0000000..2ebb083
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/aarch64-sve-asm.ll
@@ -0,0 +1,44 @@
+; RUN: llc < %s -mtriple aarch64-none-linux-gnu -mattr=+sve -stop-after=finalize-isel | FileCheck %s --check-prefix=CHECK
+
+target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-linux-gnu"
+
+; Function Attrs: nounwind readnone
+; CHECK: [[ARG1:%[0-9]+]]:zpr = COPY $z1
+; CHECK: [[ARG2:%[0-9]+]]:zpr = COPY $z0
+; CHECK: [[ARG3:%[0-9]+]]:zpr = COPY [[ARG2]]
+; CHECK: [[ARG4:%[0-9]+]]:zpr_3b = COPY [[ARG1]]
+define <vscale x 16 x i8> @test_svadd_i8(<vscale x 16 x i8> %Zn, <vscale x 16 x i8> %Zm) {
+  %1 = tail call <vscale x 16 x i8> asm "add $0.b, $1.b, $2.b", "=w,w,y"(<vscale x 16 x i8> %Zn, <vscale x 16 x i8> %Zm)
+  ret <vscale x 16 x i8> %1
+}
+
+; Function Attrs: nounwind readnone
+; CHECK: [[ARG1:%[0-9]+]]:zpr = COPY $z1
+; CHECK: [[ARG2:%[0-9]+]]:zpr = COPY $z0
+; CHECK: [[ARG3:%[0-9]+]]:zpr = COPY [[ARG2]]
+; CHECK: [[ARG4:%[0-9]+]]:zpr_4b = COPY [[ARG1]]
+define <vscale x 2 x i64> @test_svsub_i64(<vscale x 2 x i64> %Zn, <vscale x 2 x i64> %Zm) {
+  %1 = tail call <vscale x 2 x i64> asm "sub $0.d, $1.d, $2.d", "=w,w,x"(<vscale x 2 x i64> %Zn, <vscale x 2 x i64> %Zm)
+  ret <vscale x 2 x i64> %1
+}
+
+; Function Attrs: nounwind readnone
+; CHECK: [[ARG1:%[0-9]+]]:zpr = COPY $z1
+; CHECK: [[ARG2:%[0-9]+]]:zpr = COPY $z0
+; CHECK: [[ARG3:%[0-9]+]]:zpr = COPY [[ARG2]]
+; CHECK: [[ARG4:%[0-9]+]]:zpr_3b = COPY [[ARG1]]
+define <vscale x 8 x half> @test_svfmul_f16(<vscale x 8 x half> %Zn, <vscale x 8 x half> %Zm) {
+  %1 = tail call <vscale x 8 x half> asm "fmul $0.h, $1.h, $2.h", "=w,w,y"(<vscale x 8 x half> %Zn, <vscale x 8 x half> %Zm)
+  ret <vscale x 8 x half> %1
+}
+
+; Function Attrs: nounwind readnone
+; CHECK: [[ARG1:%[0-9]+]]:zpr = COPY $z1
+; CHECK: [[ARG2:%[0-9]+]]:zpr = COPY $z0
+; CHECK: [[ARG3:%[0-9]+]]:zpr = COPY [[ARG2]]
+; CHECK: [[ARG4:%[0-9]+]]:zpr_4b = COPY [[ARG1]]
+define <vscale x 4 x float> @test_svfmul_f(<vscale x 4 x float> %Zn, <vscale x 4 x float> %Zm) {
+  %1 = tail call <vscale x 4 x float> asm "fmul $0.s, $1.s, $2.s", "=w,w,x"(<vscale x 4 x float> %Zn, <vscale x 4 x float> %Zm)
+  ret <vscale x 4 x float> %1
+}

diff --git a/llvm/test/CodeGen/AArch64/arm64-inline-asm.ll b/llvm/test/CodeGen/AArch64/arm64-inline-asm.ll
index 82e0a1c..3b8b4d8 100644
--- a/llvm/test/CodeGen/AArch64/arm64-inline-asm.ll
+++ b/llvm/test/CodeGen/AArch64/arm64-inline-asm.ll
@@ -138,6 +138,8 @@ entry:
   %a = alloca [2 x float], align 4
   %arraydecay = getelementptr inbounds [2 x float], [2 x float]* %a, i32 0, i32 0
   %0 = load <2 x float>, <2 x float>* %data, align 8
+  call void asm sideeffect "ldr ${1:z}, [$0]\0A", "r,w"(float* %arraydecay, <2 x float> %0) nounwind
+  ; CHECK: ldr {{z[0-9]+}}, [{{x[0-9]+}}]
   call void asm sideeffect "ldr ${1:q}, [$0]\0A", "r,w"(float* %arraydecay, <2 x float> %0) nounwind
   ; CHECK: ldr {{q[0-9]+}}, [{{x[0-9]+}}]
   call void asm sideeffect "ldr ${1:d}, [$0]\0A", "r,w"(float* %arraydecay, <2 x float> %0) nounwind
-- 
2.7.4