From 18e6a03b1a15b2661259af15ae604b4c4850cd61 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Fri, 6 Aug 2021 10:46:22 +0100 Subject: [PATCH] [X86][AVX] Extract SUBV_BROADCAST constant bits from just the lower subvector range (PR51281) As reported on PR51281, an internal fuzz test encountered an issue when extracting constant bits from a SUBV_BROADCAST node from a constant pool source larger than the broadcasted subvector width. The getTargetConstantBitsFromNode was assuming that the Constant would the same size as the subvector, resulting in the incorrect packing of the per-element bits data. This patch attempts to solve this by using the SUBV_BROADCAST node to determine the subvector width, and then ensuring we extract only the lowest bits from Constant of that subvector bitsize. Differential Revision: https://reviews.llvm.org/D107158 --- llvm/lib/Target/X86/X86ISelLowering.cpp | 14 +++++++++----- llvm/test/CodeGen/X86/pr51281.ll | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 2efd06c..72e0c38 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -6704,17 +6704,21 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits, if (Op.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) { auto *MemIntr = cast(Op); SDValue Ptr = MemIntr->getBasePtr(); + // The source constant may be larger than the subvector broadcast, + // ensure we extract the correct subvector constants. if (const Constant *Cst = getTargetConstantFromBasePtr(Ptr)) { Type *CstTy = Cst->getType(); unsigned CstSizeInBits = CstTy->getPrimitiveSizeInBits(); - if (!CstTy->isVectorTy() || (SizeInBits % CstSizeInBits) != 0) + unsigned SubVecSizeInBits = MemIntr->getMemoryVT().getStoreSizeInBits(); + if (!CstTy->isVectorTy() || (CstSizeInBits % SubVecSizeInBits) != 0 || + (SizeInBits % SubVecSizeInBits) != 0) return false; - unsigned SubEltSizeInBits = CstTy->getScalarSizeInBits(); - unsigned NumSubElts = CstSizeInBits / SubEltSizeInBits; - unsigned NumSubVecs = SizeInBits / CstSizeInBits; + unsigned CstEltSizeInBits = CstTy->getScalarSizeInBits(); + unsigned NumSubElts = SubVecSizeInBits / CstEltSizeInBits; + unsigned NumSubVecs = SizeInBits / SubVecSizeInBits; APInt UndefSubElts(NumSubElts, 0); SmallVector SubEltBits(NumSubElts * NumSubVecs, - APInt(SubEltSizeInBits, 0)); + APInt(CstEltSizeInBits, 0)); for (unsigned i = 0; i != NumSubElts; ++i) { if (!CollectConstantBits(Cst->getAggregateElement(i), SubEltBits[i], UndefSubElts, i)) diff --git a/llvm/test/CodeGen/X86/pr51281.ll b/llvm/test/CodeGen/X86/pr51281.ll index 116e6d1..3812f3f 100644 --- a/llvm/test/CodeGen/X86/pr51281.ll +++ b/llvm/test/CodeGen/X86/pr51281.ll @@ -8,8 +8,8 @@ ; CHECK-NEXT: .long 0x3eb5dbc6 ; CHECK-NEXT: .zero 4 ; CHECK-NEXT: .long 0x3eb5dbc6 -; CHECK-NEXT: .zero 4 -; CHECK-NEXT: .zero 4 +; CHECK-NEXT: .long 0x3eb5dbc6 +; CHECK-NEXT: .long 0x3eb5dbc6 ; CHECK: .LCPI0_1: ; CHECK-NEXT: .long 3 -- 2.7.4