From eac2638ec169a5d6987ac4fbbcd430bee4489348 Mon Sep 17 00:00:00 2001
From: Sander de Smalen
Date: Mon, 28 Feb 2022 12:17:25 +0000
Subject: [PATCH] [AArch64][SVE] Fold away SETCC if original input was predicate vector.

This adds the following two folds:

Fold 1:
  setcc_merge_zero(
    all_active, extend(nxvNi1 ...), != splat(0))
  -> nxvNi1 ...

Fold 2:
  setcc_merge_zero(
    pred, extend(nxvNi1 ...), != splat(0))
  -> nxvNi1 and(pred, ...)

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D119334
---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp       | 41 ++++++++++++++------
 llvm/test/CodeGen/AArch64/sve-cmp-select.ll           | 25 ++++++++++++
 .../CodeGen/AArch64/sve-fixed-length-masked-gather.ll |  5 +--
 .../AArch64/sve-fixed-length-masked-scatter.ll        | 23 ++++++-----
 llvm/test/CodeGen/AArch64/sve-punpklo-combine.ll      | 37 +++++++-----------
 llvm/test/CodeGen/AArch64/sve-setcc.ll                | 45 ++++++++++++++++++++++
 6 files changed, 126 insertions(+), 50 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index daadcc0..fecdf49 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17158,27 +17158,46 @@ static SDValue performSetCCPunpkCombine(SDNode *N, SelectionDAG &DAG) {
   return SDValue();
 }
 
-static SDValue performSetccMergeZeroCombine(SDNode *N, SelectionDAG &DAG) {
+static SDValue
+performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
          "Unexpected opcode!");
 
+  SelectionDAG &DAG = DCI.DAG;
   SDValue Pred = N->getOperand(0);
   SDValue LHS = N->getOperand(1);
   SDValue RHS = N->getOperand(2);
   ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
 
-  // setcc_merge_zero pred (sign_extend (setcc_merge_zero ... pred ...)), 0, ne
-  // => inner setcc_merge_zero
-  if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
-      LHS->getOpcode() == ISD::SIGN_EXTEND &&
-      LHS->getOperand(0)->getValueType(0) == N->getValueType(0) &&
-      LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
-      LHS->getOperand(0)->getOperand(0) == Pred)
-    return LHS->getOperand(0);
-
   if (SDValue V = performSetCCPunpkCombine(N, DAG))
     return V;
 
+  if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
+      LHS->getOpcode() == ISD::SIGN_EXTEND &&
+      LHS->getOperand(0)->getValueType(0) == N->getValueType(0)) {
+    // setcc_merge_zero(
+    //    pred, extend(setcc_merge_zero(pred, ...)), != splat(0))
+    // => setcc_merge_zero(pred, ...)
+    if (LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
+        LHS->getOperand(0)->getOperand(0) == Pred)
+      return LHS->getOperand(0);
+
+    // setcc_merge_zero(
+    //    all_active, extend(nxvNi1 ...), != splat(0))
+    // -> nxvNi1 ...
+    if (isAllActivePredicate(DAG, Pred))
+      return LHS->getOperand(0);
+
+    // setcc_merge_zero(
+    //    pred, extend(nxvNi1 ...), != splat(0))
+    // -> nxvNi1 and(pred, ...)
+    if (DCI.isAfterLegalizeDAG())
+      // Do this after legalization to allow more folds on setcc_merge_zero
+      // to be recognized.
+      return DAG.getNode(ISD::AND, SDLoc(N), N->getValueType(0),
+                         LHS->getOperand(0), Pred);
+  }
+
   return SDValue();
 }
 
@@ -18175,7 +18194,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case AArch64ISD::UZP1:
     return performUzpCombine(N, DAG);
   case AArch64ISD::SETCC_MERGE_ZERO:
-    return performSetccMergeZeroCombine(N, DAG);
+    return performSetccMergeZeroCombine(N, DCI);
   case AArch64ISD::GLD1_MERGE_ZERO:
   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
diff --git a/llvm/test/CodeGen/AArch64/sve-cmp-select.ll b/llvm/test/CodeGen/AArch64/sve-cmp-select.ll
index 1a30005..9456342 100644
--- a/llvm/test/CodeGen/AArch64/sve-cmp-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-cmp-select.ll
@@ -36,3 +36,28 @@ define <vscale x 16 x i8> @vselect_cmp_ugt(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
   %d = select <vscale x 16 x i1> %cmp, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c
   ret <vscale x 16 x i8> %d
 }
+
+; Some folds to remove a redundant icmp if the original input was a predicate vector.
+
+define <vscale x 4 x i1> @fold_away_icmp_ptrue_all(<vscale x 4 x i1> %p) {
+; CHECK-LABEL: fold_away_icmp_ptrue_all:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %t0 = sext <vscale x 4 x i1> %p to <vscale x 4 x i32>
+  %t1 = icmp ne <vscale x 4 x i32> %t0, zeroinitializer
+  ret <vscale x 4 x i1> %t1
+}
+
+define <vscale x 4 x i1> @fold_away_icmp_ptrue_vl16(<vscale x 4 x i1> %p) vscale_range(4, 4) {
+; CHECK-LABEL: fold_away_icmp_ptrue_vl16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %t0 = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 9) ; VL16 is encoded as 9.
+  %t1 = sext <vscale x 4 x i1> %p to <vscale x 4 x i32>
+  %t2 = call <vscale x 4 x i1> @llvm.aarch64.sve.cmpne.nxv4i32(<vscale x 4 x i1> %t0, <vscale x 4 x i32> %t1, <vscale x 4 x i32> zeroinitializer)
+  ret <vscale x 4 x i1> %t2
+}
+
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.cmpne.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
index 0b76f0b..085e022 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -387,11 +387,10 @@ define void @masked_gather_v8i32(<8 x i32>* %a, <8 x i32*>* %b) #0 {
 ; VBITS_EQ_256-NEXT:    mov z0.s, p2/z, #-1 // =0xffffffffffffffff
 ; VBITS_EQ_256-NEXT:    punpklo p2.h, p2.b
 ; VBITS_EQ_256-NEXT:    ext z0.b, z0.b, z0.b, #16
-; VBITS_EQ_256-NEXT:    mov z3.d, p2/z, #-1 // =0xffffffffffffffff
+; VBITS_EQ_256-NEXT:    and p2.b, p2/z, p2.b, p1.b
 ; VBITS_EQ_256-NEXT:    sunpklo z0.d, z0.s
-; VBITS_EQ_256-NEXT:    cmpne p2.d, p1/z, z3.d, #0
-; VBITS_EQ_256-NEXT:    cmpne p1.d, p1/z, z0.d, #0
 ; VBITS_EQ_256-NEXT:    ld1w { z2.d }, p2/z, [z2.d]
+; VBITS_EQ_256-NEXT:    cmpne p1.d, p1/z, z0.d, #0
 ; VBITS_EQ_256-NEXT:    ld1w { z0.d }, p1/z, [z1.d]
 ; VBITS_EQ_256-NEXT:    ptrue p1.s, vl4
 ; VBITS_EQ_256-NEXT:    uzp1 z1.s, z2.s, z2.s
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
index 44bb367..4e6175a 100644
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -351,22 +351,21 @@ define void @masked_scatter_v8i32(<8 x i32>* %a, <8 x i32*>* %b) #0 {
 ; VBITS_EQ_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_EQ_256-NEXT:    mov x8, #4
 ; VBITS_EQ_256-NEXT:    ld1w { z0.s }, p0/z, [x0]
-; VBITS_EQ_256-NEXT:    cmpeq p0.s, p0/z, z0.s, #0
-; VBITS_EQ_256-NEXT:    punpklo p1.h, p0.b
-; VBITS_EQ_256-NEXT:    mov z4.s, p0/z, #-1 // =0xffffffffffffffff
-; VBITS_EQ_256-NEXT:    mov z1.d, p1/z, #-1 // =0xffffffffffffffff
 ; VBITS_EQ_256-NEXT:    ptrue p1.d, vl4
-; VBITS_EQ_256-NEXT:    ld1d { z2.d }, p1/z, [x1, x8, lsl #3]
+; VBITS_EQ_256-NEXT:    ld1d { z1.d }, p1/z, [x1, x8, lsl #3]
 ; VBITS_EQ_256-NEXT:    ld1d { z3.d }, p1/z, [x1]
-; VBITS_EQ_256-NEXT:    ext z4.b, z4.b, z4.b, #16
-; VBITS_EQ_256-NEXT:    cmpne p0.d, p1/z, z1.d, #0
-; VBITS_EQ_256-NEXT:    uunpklo z1.d, z0.s
-; VBITS_EQ_256-NEXT:    sunpklo z4.d, z4.s
+; VBITS_EQ_256-NEXT:    cmpeq p0.s, p0/z, z0.s, #0
+; VBITS_EQ_256-NEXT:    uunpklo z4.d, z0.s
+; VBITS_EQ_256-NEXT:    mov z2.s, p0/z, #-1 // =0xffffffffffffffff
+; VBITS_EQ_256-NEXT:    punpklo p0.h, p0.b
+; VBITS_EQ_256-NEXT:    ext z2.b, z2.b, z2.b, #16
 ; VBITS_EQ_256-NEXT:    ext z0.b, z0.b, z0.b, #16
-; VBITS_EQ_256-NEXT:    cmpne p1.d, p1/z, z4.d, #0
+; VBITS_EQ_256-NEXT:    sunpklo z2.d, z2.s
+; VBITS_EQ_256-NEXT:    and p0.b, p0/z, p0.b, p1.b
+; VBITS_EQ_256-NEXT:    cmpne p1.d, p1/z, z2.d, #0
 ; VBITS_EQ_256-NEXT:    uunpklo z0.d, z0.s
-; VBITS_EQ_256-NEXT:    st1w { z1.d }, p0, [z3.d]
-; VBITS_EQ_256-NEXT:    st1w { z0.d }, p1, [z2.d]
+; VBITS_EQ_256-NEXT:    st1w { z4.d }, p0, [z3.d]
+; VBITS_EQ_256-NEXT:    st1w { z0.d }, p1, [z1.d]
 ; VBITS_EQ_256-NEXT:    ret
 ; VBITS_GE_512-LABEL: masked_scatter_v8i32:
 ; VBITS_GE_512:       // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-punpklo-combine.ll b/llvm/test/CodeGen/AArch64/sve-punpklo-combine.ll
index ddc2f5b..8f76b5f 100644
--- a/llvm/test/CodeGen/AArch64/sve-punpklo-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-punpklo-combine.ll
@@ -23,11 +23,10 @@ define <vscale x 8 x i16> @masked_load_sext_i8i16_ptrue_vl(i8* %ap, <vscale x 16 x i8> %b) #0 {
   %p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -45,8 +44,7 @@ define <vscale x 8 x i16> @masked_load_sext_i8i16_parg(i8* %ap, <vscale x 16 x i1> %p0, <vscale x 16 x i8> %b) #0 {
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
   %extract = call <vscale x 8 x i1> @llvm.experimental.vector.extract.nxv8i1.nxv16i1(<vscale x 16 x i1> %cmp, i64 0)
@@ -78,12 +76,11 @@ define <vscale x 4 x i32> @masked_load_sext_i8i32_ptrue_vl(i8* %ap, <vscale x 16 x i8> %b) #0 {
   %p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -102,8 +99,7 @@ define <vscale x 4 x i32> @masked_load_sext_i8i32_parg(i8* %ap, <vscale x 16 x i1> %p0, <vscale x 16 x i8> %b) #0 {
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
   %extract = call <vscale x 4 x i1> @llvm.experimental.vector.extract.nxv4i1.nxv16i1(<vscale x 16 x i1> %cmp, i64 0)
@@ -136,13 +132,12 @@ define <vscale x 2 x i64> @masked_load_sext_i8i64_ptrue_vl(i8* %ap, <vscale x 16 x i8> %b) #0 {
   %p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -162,8 +157,7 @@ define <vscale x 2 x i64> @masked_load_sext_i8i64_parg(i8* %ap, <vscale x 16 x i1> %p0, <vscale x 16 x i8> %b) #0 {
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
   %extract = call <vscale x 2 x i1> @llvm.experimental.vector.extract.nxv2i1.nxv16i1(<vscale x 16 x i1> %cmp, i64 0)
@@ -178,11 +172,10 @@ define <vscale x 8 x i16> @masked_load_sext_i8i16_ptrue_all(i8* %ap, <vscale x 16 x i8> %b) #0 {
   %p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -198,12 +191,11 @@ define <vscale x 4 x i32> @masked_load_sext_i8i32_ptrue_all(i8* %ap, <vscale x 16 x i8> %b) #0 {
   %p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 11)
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
@@ -223,9 +215,6 @@ define <vscale x 2 x i64> @masked_load_sext_i8i64_ptrue_all(i8* %ap, <vscale x 16 x i8> %b) #0 {
   %p0 = call <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 31)
   %cmp = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpeq.nxv16i8(<vscale x 16 x i1> %p0, <vscale x 16 x i8> %b, <vscale x 16 x i8> zeroinitializer)
diff --git a/llvm/test/CodeGen/AArch64/sve-setcc.ll b/llvm/test/CodeGen/AArch64/sve-setcc.ll
index 026c0dc..8d7aae8 100644
--- a/llvm/test/CodeGen/AArch64/sve-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/sve-setcc.ll
@@ -70,6 +70,51 @@ if.end:
   ret void
 }
 
+; Fold away the redundant setcc:
+;   setcc(ne, <all true>, sext(nxvNi1 ...), splat(0))
+; -> nxvNi1 ...
+define <vscale x 16 x i1> @sve_cmpne_setcc_all_true_sext(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %pg) {
+; CHECK-LABEL: sve_cmpne_setcc_all_true_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %alltrue.ins = insertelement <vscale x 16 x i1> poison, i1 true, i32 0
+  %alltrue = shufflevector <vscale x 16 x i1> %alltrue.ins, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
+  %pg.sext = sext <vscale x 16 x i1> %pg to <vscale x 16 x i8>
+  %cmp2 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %alltrue, <vscale x 16 x i8> %pg.sext, <vscale x 16 x i8> zeroinitializer)
+  ret <vscale x 16 x i1> %cmp2
+}
+
+; Fold away the redundant setcc:
+;   setcc(ne, pred, sext(setcc(ne, pred, ..., splat(0))), splat(0))
+; -> setcc(ne, pred, ..., splat(0))
+define <vscale x 16 x i1> @sve_cmpne_setcc_equal_pred(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %pg) {
+; CHECK-LABEL: sve_cmpne_setcc_equal_pred:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; CHECK-NEXT:    ret
+  %cmp1 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %vec, <vscale x 16 x i8> zeroinitializer)
+  %cmp1.sext = sext <vscale x 16 x i1> %cmp1 to <vscale x 16 x i8>
+  %cmp2 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg, <vscale x 16 x i8> %cmp1.sext, <vscale x 16 x i8> zeroinitializer)
+  ret <vscale x 16 x i1> %cmp2
+}
+
+; Combine:
+;   setcc(ne, pred1, sext(setcc(ne, pred2, ..., splat(0))), splat(0))
+; -> setcc(ne, and(pred1, pred2), ..., splat(0))
+define <vscale x 16 x i1> @sve_cmpne_setcc_different_pred(<vscale x 16 x i8> %vec, <vscale x 16 x i1> %pg1, <vscale x 16 x i1> %pg2) {
+; CHECK-LABEL: sve_cmpne_setcc_different_pred:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmpne p0.b, p0/z, z0.b, #0
+; CHECK-NEXT:    and p0.b, p0/z, p0.b, p1.b
+; CHECK-NEXT:    ret
+  %cmp1 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg1, <vscale x 16 x i8> %vec, <vscale x 16 x i8> zeroinitializer)
+  %cmp1.sext = sext <vscale x 16 x i1> %cmp1 to <vscale x 16 x i8>
+  %cmp2 = call <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1> %pg2, <vscale x 16 x i8> %cmp1.sext, <vscale x 16 x i8> zeroinitializer)
+  ret <vscale x 16 x i1> %cmp2
+}
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.cmpne.nxv16i8(<vscale x 16 x i1>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+
 declare i1 @llvm.aarch64.sve.ptest.any.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
 declare i1 @llvm.aarch64.sve.ptest.last.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>)
-- 
2.7.4
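Illustrative sketch, not part of the patch above: the IR below shows the shape that "Fold 2" targets at 32-bit element width, reusing the cmpne intrinsic signature already declared in sve-cmp-select.ll. A predicate %p is sign-extended to a data vector and then re-compared against zero under a second, non-constant predicate %pg; after this combine the second compare should lower to an AND of the two predicate registers instead of materialising the extended vector and comparing it again. The function name is made up for illustration, and expected assembly is deliberately not spelled out as CHECK lines.

; Hedged example only; assumes the standard SVE cmpne intrinsic signature.
define <vscale x 4 x i1> @example_pred_sext_cmpne(<vscale x 4 x i1> %p, <vscale x 4 x i1> %pg) {
  ; Extend the predicate to a data vector, then test it for non-zero under %pg.
  %p.sext = sext <vscale x 4 x i1> %p to <vscale x 4 x i32>
  ; With the fold, this compare should become an and(%pg, %p) on predicates.
  %r = call <vscale x 4 x i1> @llvm.aarch64.sve.cmpne.nxv4i32(<vscale x 4 x i1> %pg, <vscale x 4 x i32> %p.sext, <vscale x 4 x i32> zeroinitializer)
  ret <vscale x 4 x i1> %r
}

declare <vscale x 4 x i1> @llvm.aarch64.sve.cmpne.nxv4i32(<vscale x 4 x i1>, <vscale x 4 x i32>, <vscale x 4 x i32>)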