From 3010f60381bcd828d1b409cfaa576328bcd05bbc Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell
Date: Mon, 12 Dec 2022 16:26:20 +0000
Subject: [PATCH] Reland "[TargetLowering] Teach DemandedBits about VSCALE"

Reland with a fixup that avoids converting APInts to int64_t, which
allowed overflows (UB) with sufficiently large positive or negative
multiplier values.

This allows DemandedBits to see that the result of VSCALE will be at
most VScaleMax * some compile-time constant. This relies on the
vscale_range() attribute being present on the function, with a max set
(clang adds this by default when targeting AArch64+SVE).

Using this, various redundant operations (zexts, sexts, ands, ors,
etc.) can be eliminated.

Differential Revision: https://reviews.llvm.org/D138508
---
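Note (illustration only, not part of the patch; `git am` ignores this
section): the new ISD::VSCALE case boils down to the following
standalone APInt computation. The helper name `vscaleKnownBits` and the
worked numbers are made up for this sketch; only the KnownBits logic
mirrors the change below.

    #include "llvm/ADT/APInt.h"
    #include "llvm/Support/KnownBits.h"
    using namespace llvm;

    // Known bits of (vscale * Multiplier), given only an upper bound
    // on vscale from the function's vscale_range attribute.
    static KnownBits vscaleKnownBits(unsigned BitWidth, unsigned MaxVScale,
                                     const APInt &Multiplier) {
      KnownBits Known(BitWidth);
      // Compute the bound entirely in APInt; round-tripping through
      // int64_t is what overflowed in the original commit.
      APInt UpperBound = MaxVScale * Multiplier;
      bool Negative = UpperBound.isNegative();
      if (Negative)
        UpperBound = ~UpperBound; // -UpperBound - 1, always non-negative.
      unsigned RequiredBits = UpperBound.getActiveBits();
      // Every bit above RequiredBits is known: zero for a non-negative
      // result, one for a negative (sign-extended) result.
      if (RequiredBits < BitWidth)
        (Negative ? Known.One : Known.Zero).setHighBits(BitWidth - RequiredBits);
      return Known;
    }

For example, with vscale_range(1,16) and a multiplier of 5 the product
is at most 80, so at i32 the top 25 bits are known zero and the
`and w9, w8, #0x7f` in vscale_with_multiplier below is removed as
redundant. With multiplier -5 the product lies in [-80, -5], the top 25
bits are known one, and the `orr w9, w8, #0xffffff80` folds away.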
 llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp   | 17 ++++++++++
 .../AArch64/vscale-and-sve-cnt-demandedbits.ll     | 37 ++++++++++++++--------
 2 files changed, 41 insertions(+), 13 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 202178e..a0e7705 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1125,6 +1125,23 @@ bool TargetLowering::SimplifyDemandedBits(

   KnownBits Known2;
   switch (Op.getOpcode()) {
+  case ISD::VSCALE: {
+    Function const &F = TLO.DAG.getMachineFunction().getFunction();
+    Attribute const &Attr = F.getFnAttribute(Attribute::VScaleRange);
+    if (!Attr.isValid())
+      return false;
+    std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
+    if (!MaxVScale.has_value())
+      return false;
+    APInt VScaleResultUpperbound = *MaxVScale * Op.getConstantOperandAPInt(0);
+    bool Negative = VScaleResultUpperbound.isNegative();
+    if (Negative)
+      VScaleResultUpperbound = ~VScaleResultUpperbound;
+    unsigned RequiredBits = VScaleResultUpperbound.getActiveBits();
+    if (RequiredBits < BitWidth)
+      (Negative ? Known.One : Known.Zero).setHighBits(BitWidth - RequiredBits);
+    return false;
+  }
   case ISD::SCALAR_TO_VECTOR: {
     if (VT.isScalableVector())
       return false;
diff --git a/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll b/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll
index 895f5da..dbdab799 100644
--- a/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll
+++ b/llvm/test/CodeGen/AArch64/vscale-and-sve-cnt-demandedbits.ll
@@ -14,9 +14,8 @@ define i32 @vscale_and_elimination() vscale_range(1,16) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    and w9, w8, #0x1f
-; CHECK-NEXT:    and w8, w8, #0xfffffffc
-; CHECK-NEXT:    add w0, w9, w8
+; CHECK-NEXT:    and w9, w8, #0x1c
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %and_redundant = and i32 %vscale, 31
@@ -85,8 +84,7 @@ define i64 @vscale_trunc_zext() vscale_range(1,16) {
 ; CHECK-LABEL: vscale_trunc_zext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    and x0, x8, #0xffffffff
+; CHECK-NEXT:    lsr x0, x8, #4
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %zext = zext i32 %vscale to i64
@@ -97,8 +95,7 @@ define i64 @vscale_trunc_sext() vscale_range(1,16) {
 ; CHECK-LABEL: vscale_trunc_sext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    rdvl x8, #1
-; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    sxtw x0, w8
+; CHECK-NEXT:    lsr x0, x8, #4
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %sext = sext i32 %vscale to i64
@@ -200,9 +197,8 @@ define i32 @vscale_with_multiplier() vscale_range(1,16) {
 ; CHECK-NEXT:    mov w9, #5
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    and w9, w8, #0x7f
-; CHECK-NEXT:    and w8, w8, #0x3f
-; CHECK-NEXT:    add w0, w9, w8
+; CHECK-NEXT:    and w9, w8, #0x3f
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %mul = mul i32 %vscale, 5
@@ -219,9 +215,8 @@ define i32 @vscale_with_negative_multiplier() vscale_range(1,16) {
 ; CHECK-NEXT:    mov x9, #-5
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    orr w9, w8, #0xffffff80
-; CHECK-NEXT:    and w8, w8, #0xffffffc0
-; CHECK-NEXT:    add w0, w9, w8
+; CHECK-NEXT:    and w9, w8, #0xffffffc0
+; CHECK-NEXT:    add w0, w8, w9
 ; CHECK-NEXT:    ret
   %vscale = call i32 @llvm.vscale.i32()
   %mul = mul i32 %vscale, -5
@@ -231,6 +226,22 @@ define i32 @vscale_with_negative_multiplier() vscale_range(1,16) {
   ret i32 %result
 }
 
+define i32 @pow2_vscale_with_negative_multiplier() vscale_range(1,16) {
+; CHECK-LABEL: pow2_vscale_with_negative_multiplier:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cntd x8
+; CHECK-NEXT:    neg x8, x8
+; CHECK-NEXT:    orr w9, w8, #0xfffffff0
+; CHECK-NEXT:    add w0, w8, w9
+; CHECK-NEXT:    ret
+  %vscale = call i32 @llvm.vscale.i32()
+  %mul = mul i32 %vscale, -2
+  %or_redundant = or i32 %mul, 4294967264
+  %or_required = or i32 %mul, 4294967280
+  %result = add i32 %or_redundant, %or_required
+  ret i32 %result
+}
+
 declare i32 @llvm.vscale.i32()
 declare i64 @llvm.aarch64.sve.cntb(i32 %pattern)
 declare i64 @llvm.aarch64.sve.cnth(i32 %pattern)
-- 
2.7.4