From 690db164226fb1d454c5e592726a8bc0de16c6b5 Mon Sep 17 00:00:00 2001 From: Sander de Smalen Date: Fri, 1 Jul 2022 14:29:07 +0000 Subject: [PATCH] [AArch64] Make nxv1i1 types a legal type for SVE. One motivation to add support for these types are the LD1Q/ST1Q instructions in SME, for which we have defined a number of load/store intrinsics which at the moment still take a `` predicate regardless of their element type. This patch adds basic support for the nxv1i1 type such that it can be passed/returned from functions, as well as some basic support to support some existing tests that result in a nxv1i1 type. It also adds support for splats. Other operations (e.g. insert/extract subvector, logical ops, etc) will be supported in follow-up patches. Reviewed By: paulwalker-arm, efriedma Differential Revision: https://reviews.llvm.org/D128665 --- .../CodeGen/SelectionDAG/LegalizeVectorTypes.cpp | 18 ++++++++++------ .../lib/Target/AArch64/AArch64CallingConvention.td | 6 +++--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++++++++---- llvm/lib/Target/AArch64/AArch64RegisterInfo.td | 2 +- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td | 19 +++++++++++++++++ llvm/lib/Target/AArch64/SVEInstrFormats.td | 2 ++ .../CodeGen/AArch64/sve-extract-scalable-vector.ll | 24 ++++++++++++++++++++++ llvm/test/CodeGen/AArch64/sve-select.ll | 3 +++ llvm/test/CodeGen/AArch64/sve-zeroinit.ll | 7 +++++++ 9 files changed, 81 insertions(+), 14 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index 44521bb..fa555be 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -6653,7 +6653,7 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT, EVT InVT = InOp.getValueType(); assert(InVT.getVectorElementType() == NVT.getVectorElementType() && "input and widen element type must match"); - assert(!InVT.isScalableVector() && !NVT.isScalableVector() && + assert(InVT.isScalableVector() == NVT.isScalableVector() && "cannot modify scalable vectors in this way"); SDLoc dl(InOp); @@ -6661,10 +6661,10 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT, if (InVT == NVT) return InOp; - unsigned InNumElts = InVT.getVectorNumElements(); - unsigned WidenNumElts = NVT.getVectorNumElements(); - if (WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0) { - unsigned NumConcat = WidenNumElts / InNumElts; + ElementCount InEC = InVT.getVectorElementCount(); + ElementCount WidenEC = NVT.getVectorElementCount(); + if (WidenEC.hasKnownScalarFactor(InEC)) { + unsigned NumConcat = WidenEC.getKnownScalarFactor(InEC); SmallVector Ops(NumConcat); SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, InVT) : DAG.getUNDEF(InVT); @@ -6675,10 +6675,16 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT, return DAG.getNode(ISD::CONCAT_VECTORS, dl, NVT, Ops); } - if (WidenNumElts < InNumElts && InNumElts % WidenNumElts) + if (InEC.hasKnownScalarFactor(WidenEC)) return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp, DAG.getVectorIdxConstant(0, dl)); + assert(!InVT.isScalableVector() && !NVT.isScalableVector() && + "Scalable vectors should have been handled already."); + + unsigned InNumElts = InEC.getFixedValue(); + unsigned WidenNumElts = WidenEC.getFixedValue(); + // Fall back to extract and build. SmallVector Ops(WidenNumElts); EVT EltVT = NVT.getVectorElementType(); diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td index f261515..c0da242 100644 --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -82,9 +82,9 @@ def CC_AArch64_AAPCS : CallingConv<[ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64], CCPassIndirect>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCAssignToReg<[P0, P1, P2, P3]>>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCPassIndirect>, // Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers, @@ -149,7 +149,7 @@ def RetCC_AArch64_AAPCS : CallingConv<[ nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64], CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>, - CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1], + CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1], CCAssignToReg<[P0, P1, P2, P3]>> ]>; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 63cd8f9..abfe2d5 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -292,6 +292,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, if (Subtarget->hasSVE() || Subtarget->hasSME()) { // Add legal sve predicate types + addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass); addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass); @@ -1156,7 +1157,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 }) setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal); - for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) { + for (auto VT : + {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) { setOperationAction(ISD::CONCAT_VECTORS, VT, Custom); setOperationAction(ISD::SELECT, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); @@ -4676,7 +4678,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op, Op.getOperand(2), Op.getOperand(3), DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)), Op.getOperand(1)); - case Intrinsic::localaddress: { const auto &MF = DAG.getMachineFunction(); const auto *RegInfo = Subtarget->getRegisterInfo(); @@ -10551,8 +10552,13 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op, DAG.getValueType(MVT::i1)); SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64); - return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, - DAG.getConstant(0, DL, MVT::i64), SplatVal); + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); + if (VT == MVT::nxv1i1) + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv1i1, + DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i1, ID, + Zero, SplatVal), + Zero); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal); } SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op, diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td index e42feea..7a2b165 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td @@ -871,7 +871,7 @@ class ZPRRegOp : RegisterClass< "AArch64", - [ nxv16i1, nxv8i1, nxv4i1, nxv2i1 ], 16, + [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16, (sequence "P%u", 0, lastreg)> { let Size = 16; } diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td index 6bd3a96..68ff1b7 100644 --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -748,6 +748,11 @@ let Predicates = [HasSVEorSME] in { defm PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo", int_aarch64_sve_punpklo>; defm PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi", int_aarch64_sve_punpkhi>; + // Define pattern for `nxv1i1 splat_vector(1)`. + // We do this here instead of in ISelLowering such that PatFrag's can still + // recognize a splat. + def : Pat<(nxv1i1 immAllOnesV), (PUNPKLO_PP (PTRUE_D 31))>; + defm MOVPRFX_ZPzZ : sve_int_movprfx_pred_zero<0b000, "movprfx">; defm MOVPRFX_ZPmZ : sve_int_movprfx_pred_merge<0b001, "movprfx">; def MOVPRFX_ZZ : sve_int_bin_cons_misc_0_c<0b00000001, "movprfx", ZPRAny>; @@ -1509,6 +1514,10 @@ let Predicates = [HasSVEorSME] in { defm TRN2_PPP : sve_int_perm_bin_perm_pp<0b101, "trn2", AArch64trn2>; // Extract lo/hi halves of legal predicate types. + def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 0))), + (PUNPKLO_PP PPR:$Ps)>; + def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 1))), + (PUNPKHI_PP PPR:$Ps)>; def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 0))), (PUNPKLO_PP PPR:$Ps)>; def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 2))), @@ -1599,6 +1608,8 @@ let Predicates = [HasSVEorSME] in { (UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>; // Concatenate two predicates. + def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)), + (UZP1_PPP_D $p1, $p2)>; def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)), (UZP1_PPP_S $p1, $p2)>; def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)), @@ -2298,15 +2309,23 @@ let Predicates = [HasSVEorSME] in { def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv16i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv16i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv8i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv8i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv8i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv8i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv4i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv4i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv4i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv4i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv2i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv2i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv1i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; // These allow casting from/to unpacked floating-point types. def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>; diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td index 13c04ee..3631536 100644 --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -647,6 +647,7 @@ multiclass sve_int_pfalse opc, string asm> { def : Pat<(nxv8i1 immAllZerosV), (!cast(NAME))>; def : Pat<(nxv4i1 immAllZerosV), (!cast(NAME))>; def : Pat<(nxv2i1 immAllZerosV), (!cast(NAME))>; + def : Pat<(nxv1i1 immAllZerosV), (!cast(NAME))>; } class sve_int_ptest opc, string asm> @@ -1681,6 +1682,7 @@ multiclass sve_int_pred_log opc, string asm, SDPatternOperator op, def : SVE_3_Op_Pat(NAME)>; def : SVE_3_Op_Pat(NAME)>; def : SVE_3_Op_Pat(NAME)>; + def : SVE_3_Op_Pat(NAME)>; def : SVE_2_Op_AllActive_Pat(NAME), PTRUE_B>; def : SVE_2_Op_AllActive_Pat @extract_nxv2i1_nxv16i1_all_zero() { declare @llvm.vector.extract.nxv2f32.nxv4f32(, i64) declare @llvm.vector.extract.nxv4i32.nxv8i32(, i64) + +; +; Extract nxv1i1 type from: nxv2i1 +; + +define @extract_nxv1i1_nxv2i1_0( %in) { +; CHECK-LABEL: extract_nxv1i1_nxv2i1_0: +; CHECK: // %bb.0: +; CHECK-NEXT: punpklo p0.h, p0.b +; CHECK-NEXT: ret + %res = call @llvm.vector.extract.nxv1i1.nxv2i1( %in, i64 0) + ret %res +} + +define @extract_nxv1i1_nxv2i1_1( %in) { +; CHECK-LABEL: extract_nxv1i1_nxv2i1_1: +; CHECK: // %bb.0: +; CHECK-NEXT: punpkhi p0.h, p0.b +; CHECK-NEXT: ret + %res = call @llvm.vector.extract.nxv1i1.nxv2i1( %in, i64 1) + ret %res +} + +declare @llvm.vector.extract.nxv1i1.nxv2i1(, i64) diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll index 857e057..5c1cfe6 100644 --- a/llvm/test/CodeGen/AArch64/sve-select.ll +++ b/llvm/test/CodeGen/AArch64/sve-select.ll @@ -187,6 +187,7 @@ define @select_nxv1i1(i1 %cond, %a, %a, %b @@ -225,6 +226,7 @@ define @sel_nxv4i32( %p, define @sel_nxv1i64( %p, %dst, %a) { ; CHECK-LABEL: sel_nxv1i64: ; CHECK: // %bb.0: +; CHECK-NEXT: uzp1 p0.d, p0.d, p0.d ; CHECK-NEXT: mov z0.d, p0/m, z1.d ; CHECK-NEXT: ret %sel = select %p, %a, %dst @@ -483,6 +485,7 @@ define @icmp_select_nxv1i1( %a, @test_zeroinit_8xf16() { ret zeroinitializer } +define @test_zeroinit_1xi1() { +; CHECK-LABEL: test_zeroinit_1xi1 +; CHECK: pfalse p0.b +; CHECK-NEXT: ret + ret zeroinitializer +} + define @test_zeroinit_2xi1() { ; CHECK-LABEL: test_zeroinit_2xi1 ; CHECK: pfalse p0.b -- 2.7.4