From: Sanjay Patel
Date: Tue, 9 Nov 2021 13:27:09 +0000 (-0500)
Subject: [InstCombine] enhance vector bitwise select matching
X-Git-Tag: upstream/15.0.7~26349
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=c36b7e21bd8f04a44d6935c3469b1bcbbafeeb2d;p=platform%2Fupstream%2Fllvm.git

[InstCombine] enhance vector bitwise select matching

  (Cond & C) | (~bitcast(Cond) & D) --> bitcast (select Cond, (bc C), (bc D))

This is part of fixing:
https://llvm.org/PR34047

That report shows a case where a bitcast is sitting between the select
condition candidate and its 'not' value due to current cast
canonicalization rules.

There's a bitcast type restriction that might be violated by the existing
matching, and I still need to investigate whether that can happen in
practice. Alive2 shows we can only do this transform safely when the
bitcast is from narrow to wide vector elements (otherwise, poison could
leak into lanes that were safe in the original code):
https://alive2.llvm.org/ce/z/Hf66qh
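To make the restriction concrete: in the first logical-select.ll test
below, the condition is bitcast from narrow elements (<16 x i8>) to wide
elements (<2 x i64>), so the transform is safe. Before (reconstructed
from that test's old FileCheck lines):

  %s = sext <16 x i1> %cond to <16 x i8>
  %t9 = bitcast <16 x i8> %s to <2 x i64>
  %nott9 = xor <2 x i64> %t9, <i64 -1, i64 -1>
  %t11 = and <2 x i64> %nott9, %c
  %t12 = and <2 x i64> %t9, %d
  %r = or <2 x i64> %t11, %t12

After (per the new FileCheck lines; the value names here are only
illustrative - InstCombine emits numbered temporaries):

  %bd = bitcast <2 x i64> %d to <16 x i8>
  %bc = bitcast <2 x i64> %c to <16 x i8>
  %sel = select <16 x i1> %cond, <16 x i8> %bd, <16 x i8> %bc
  %r = bitcast <16 x i8> %sel to <2 x i64>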
Differential Revision: https://reviews.llvm.org/D113035
---

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 7e40e35..ebd2e3e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2298,22 +2298,30 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
   if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy())
     return nullptr;
 
-  // We need 0 or all-1's bitmasks.
-  if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits())
-    return nullptr;
-
-  // If B is the 'not' value of A, we have our answer.
+  // If A is the 'not' operand of B and has enough signbits, we have our answer.
   if (match(B, m_Not(m_Specific(A)))) {
     // If these are scalars or vectors of i1, A can be used directly.
     if (Ty->isIntOrIntVectorTy(1))
       return A;
-    return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty));
+
+    // If we look through a vector bitcast, the caller will bitcast the operands
+    // to match the condition's number of bits (N x i1).
+    // To make this poison-safe, disallow bitcast from wide element to narrow
+    // element. That could allow poison in lanes where it was not present in the
+    // original code.
+    A = peekThroughBitcast(A);
+    unsigned NumSignBits = ComputeNumSignBits(A);
+    if (NumSignBits == A->getType()->getScalarSizeInBits() &&
+        NumSignBits <= Ty->getScalarSizeInBits())
+      return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType()));
+    return nullptr;
   }
 
   // If both operands are constants, see if the constants are inverse bitmasks.
   Constant *AConst, *BConst;
   if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst)))
-    if (AConst == ConstantExpr::getNot(BConst))
+    if (AConst == ConstantExpr::getNot(BConst) &&
+        ComputeNumSignBits(A) == Ty->getScalarSizeInBits())
       return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty));
 
   // Look for more complex patterns. The 'not' op may be hidden behind various
@@ -2357,10 +2365,17 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
   B = peekThroughBitcast(B, true);
   if (Value *Cond = getSelectCondition(A, B)) {
     // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D))
+    // If this is a vector, we may need to cast to match the condition's length.
     // The bitcasts will either all exist or all not exist. The builder will
     // not create unnecessary casts if the types already match.
-    Value *BitcastC = Builder.CreateBitCast(C, A->getType());
-    Value *BitcastD = Builder.CreateBitCast(D, A->getType());
+    Type *SelTy = A->getType();
+    if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) {
+      unsigned Elts = VecTy->getElementCount().getKnownMinValue();
+      Type *EltTy = Builder.getIntNTy(SelTy->getPrimitiveSizeInBits() / Elts);
+      SelTy = VectorType::get(EltTy, VecTy->getElementCount());
+    }
+    Value *BitcastC = Builder.CreateBitCast(C, SelTy);
+    Value *BitcastD = Builder.CreateBitCast(D, SelTy);
     Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD);
     return Builder.CreateBitCast(Select, OrigType);
   }
diff --git a/llvm/test/Transforms/InstCombine/logical-select.ll b/llvm/test/Transforms/InstCombine/logical-select.ll
index 610eb20..3e3cc11 100644
--- a/llvm/test/Transforms/InstCombine/logical-select.ll
+++ b/llvm/test/Transforms/InstCombine/logical-select.ll
@@ -682,15 +682,15 @@ define <4 x i32> @computesignbits_through_two_input_shuffle(<4 x i32> %x, <4 x i
   ret <4 x i32> %sel
 }
 
+; Bitcast of condition from narrow source element type can be converted to select.
+
 define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond(
-; CHECK-NEXT:    [[S:%.*]] = sext <16 x i1> [[COND:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <16 x i8> [[S]] to <2 x i64>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i64> [[T9]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i64> [[NOTT9]], [[C:%.*]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i64> [[T9]], [[D:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i64> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i64> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i64> [[D:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i64> [[C:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[COND:%.*]], <16 x i8> [[TMP1]], <16 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <16 x i8> [[TMP3]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP4]]
 ;
   %s = sext <16 x i1> %cond to <16 x i8>
   %t9 = bitcast <16 x i8> %s to <2 x i64>
@@ -701,6 +701,8 @@ define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d)
   ret <2 x i64> %r
 }
 
+; Negative test - bitcast of condition from wide source element type cannot be converted to select.
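+; (The condition below is bitcast from <3 x i8> to <8 x i3>, i.e. from wide
+; elements to narrow elements, so one poison source element could spill into
+; select lanes that were well-defined in the original code.)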
+
 define <8 x i3> @bitcast_vec_cond_commute1(<3 x i1> %cond, <8 x i3> %pc, <8 x i3> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute1(
 ; CHECK-NEXT:    [[C:%.*]] = mul <8 x i3> [[PC:%.*]], [[PC]]
@@ -726,13 +728,11 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
 ; CHECK-LABEL: @bitcast_vec_cond_commute2(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = sext <4 x i1> [[COND:%.*]] to <4 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[COND:%.*]], <4 x i8> [[TMP1]], <4 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
@@ -745,17 +745,18 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
   ret <2 x i16> %r
 }
 
+; Condition doesn't have to be a bool vec - just all signbits.
+
 define <2 x i16> @bitcast_vec_cond_commute3(<4 x i8> %cond, <2 x i16> %pc, <2 x i16> %pd) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute3(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = ashr <4 x i8> [[COND:%.*]], <i8 7, i8 7, i8 7, i8 7>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp sgt <4 x i8> [[COND:%.*]], <i8 -1, i8 -1, i8 -1, i8 -1>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> [[TMP2]], <4 x i8> [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
diff --git a/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll b/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
index a658b19..964307b 100644
--- a/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
+++ b/llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll
@@ -68,15 +68,9 @@ define internal <2 x i64> @_mm_set_epi32(i32 %__i3, i32 %__i2, i32 %__i1, i32 %_
 define <2 x i64> @abs_v4i32(<2 x i64> %x) {
 ; CHECK-LABEL: @abs_v4i32(
 ; CHECK-NEXT:    [[T1_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[SUB_I:%.*]] = sub <4 x i32> zeroinitializer, [[T1_I]]
-; CHECK-NEXT:    [[T1_I_LOBIT:%.*]] = ashr <4 x i32> [[T1_I]], <i32 31, i32 31, i32 31, i32 31>
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[T1_I_LOBIT]] to <2 x i64>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = xor <2 x i64> [[TMP1]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I1:%.*]] = and <4 x i32> [[T1_I_LOBIT]], [[SUB_I]]
-; CHECK-NEXT:    [[AND_I_I:%.*]] = bitcast <4 x i32> [[AND_I_I1]] to <2 x i64>
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[T1_I]], i1 false)
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @_mm_set1_epi32(i32 -1)
   %call1 = call <2 x i64> @_mm_setzero_si128()
@@ -90,13 +84,9 @@ define <2 x i64> @max_v4i32(<2 x i64> %x, <2 x i64> %y) {
 ; CHECK-LABEL: @max_v4i32(
 ; CHECK-NEXT:    [[T0_I_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[T1_I_I:%.*]] = bitcast <2 x i64> [[Y:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I_I:%.*]] = icmp sgt <4 x i32> [[T0_I_I]], [[T1_I_I]]
-; CHECK-NEXT:    [[SEXT_I_I:%.*]] = sext <4 x i1> [[CMP_I_I]] to <4 x i32>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = bitcast <4 x i32> [[SEXT_I_I]] to <2 x i64>
-; CHECK-NEXT:    [[NEG_I_I:%.*]] = xor <2 x i64> [[T2_I_I]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I:%.*]] = and <2 x i64> [[NEG_I_I]], [[Y]]
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select <4 x i1> [[CMP_I_I]], <4 x i32> [[T0_I_I]], <4 x i32> [[T1_I_I]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @cmpgt_i32_sel_m128i(<2 x i64> %x, <2 x i64> %y, <2 x i64> %y, <2 x i64> %x)
   ret <2 x i64> %call