[InstCombine] enhance vector bitwise select matching
authorSanjay Patel <spatel@rotateright.com>
Tue, 9 Nov 2021 13:27:09 +0000 (08:27 -0500)
committerSanjay Patel <spatel@rotateright.com>
Tue, 9 Nov 2021 13:54:59 +0000 (08:54 -0500)
(Cond & C) | (~bitcast(Cond) & D) --> bitcast (select Cond, (bc C), (bc D))

This is part of fixing:
https://llvm.org/PR34047

That report shows a case where a bitcast is sitting between the select condition
candidate and its 'not' value due to current cast canonicalization rules.

There's a bitcast type restriction that might be violated in existing matching,
but I still need to investigate if that is possible -
Alive2 shows we can only do this transform safely when the bitcast is from
narrow to wide vector elements (otherwise poison could leak into elements
that were safe in the original code):
https://alive2.llvm.org/ce/z/Hf66qh

Differential Revision: https://reviews.llvm.org/D113035

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
llvm/test/Transforms/InstCombine/logical-select.ll
llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll

index 7e40e35..ebd2e3e 100644 (file)
@@ -2298,22 +2298,30 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
   if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy())
     return nullptr;
 
-  // We need 0 or all-1's bitmasks.
-  if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits())
-    return nullptr;
-
-  // If B is the 'not' value of A, we have our answer.
+  // If A is the 'not' operand of B and has enough signbits, we have our answer.
   if (match(B, m_Not(m_Specific(A)))) {
     // If these are scalars or vectors of i1, A can be used directly.
     if (Ty->isIntOrIntVectorTy(1))
       return A;
-    return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty));
+
+    // If we look through a vector bitcast, the caller will bitcast the operands
+    // to match the condition's number of bits (N x i1).
+    // To make this poison-safe, disallow bitcast from wide element to narrow
+    // element. That could allow poison in lanes where it was not present in the
+    // original code.
+    A = peekThroughBitcast(A);
+    unsigned NumSignBits = ComputeNumSignBits(A);
+    if (NumSignBits == A->getType()->getScalarSizeInBits() &&
+        NumSignBits <= Ty->getScalarSizeInBits())
+      return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType()));
+    return nullptr;
   }
 
   // If both operands are constants, see if the constants are inverse bitmasks.
   Constant *AConst, *BConst;
   if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst)))
-    if (AConst == ConstantExpr::getNot(BConst))
+    if (AConst == ConstantExpr::getNot(BConst) &&
+        ComputeNumSignBits(A) == Ty->getScalarSizeInBits())
       return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty));
 
   // Look for more complex patterns. The 'not' op may be hidden behind various
@@ -2357,10 +2365,17 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
   B = peekThroughBitcast(B, true);
   if (Value *Cond = getSelectCondition(A, B)) {
     // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D))
+    // If this is a vector, we may need to cast to match the condition's length.
     // The bitcasts will either all exist or all not exist. The builder will
     // not create unnecessary casts if the types already match.
-    Value *BitcastC = Builder.CreateBitCast(C, A->getType());
-    Value *BitcastD = Builder.CreateBitCast(D, A->getType());
+    Type *SelTy = A->getType();
+    if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) {
+      unsigned Elts = VecTy->getElementCount().getKnownMinValue();
+      Type *EltTy = Builder.getIntNTy(SelTy->getPrimitiveSizeInBits() / Elts);
+      SelTy = VectorType::get(EltTy, VecTy->getElementCount());
+    }
+    Value *BitcastC = Builder.CreateBitCast(C, SelTy);
+    Value *BitcastD = Builder.CreateBitCast(D, SelTy);
     Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD);
     return Builder.CreateBitCast(Select, OrigType);
   }
index 610eb20..3e3cc11 100644 (file)
@@ -682,15 +682,15 @@ define <4 x i32> @computesignbits_through_two_input_shuffle(<4 x i32> %x, <4 x i
   ret <4 x i32> %sel
 }
 
+; Bitcast of condition from narrow source element type can be converted to select.
+
 define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond(
-; CHECK-NEXT:    [[S:%.*]] = sext <16 x i1> [[COND:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <16 x i8> [[S]] to <2 x i64>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i64> [[T9]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i64> [[NOTT9]], [[C:%.*]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i64> [[T9]], [[D:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i64> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i64> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i64> [[D:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i64> [[C:%.*]] to <16 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <16 x i1> [[COND:%.*]], <16 x i8> [[TMP1]], <16 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <16 x i8> [[TMP3]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP4]]
 ;
   %s = sext <16 x i1> %cond to <16 x i8>
   %t9 = bitcast <16 x i8> %s to <2 x i64>
@@ -701,6 +701,8 @@ define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d)
   ret <2 x i64> %r
 }
 
+; Negative test - bitcast of condition from wide source element type cannot be converted to select.
+
 define <8 x i3> @bitcast_vec_cond_commute1(<3 x i1> %cond, <8 x i3> %pc, <8 x i3> %d) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute1(
 ; CHECK-NEXT:    [[C:%.*]] = mul <8 x i3> [[PC:%.*]], [[PC]]
@@ -726,13 +728,11 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
 ; CHECK-LABEL: @bitcast_vec_cond_commute2(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = sext <4 x i1> [[COND:%.*]] to <4 x i8>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[COND:%.*]], <4 x i8> [[TMP1]], <4 x i8> [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
@@ -745,17 +745,18 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
   ret <2 x i16> %r
 }
 
+; Condition doesn't have to be a bool vec - just all signbits.
+
 define <2 x i16> @bitcast_vec_cond_commute3(<4 x i8> %cond, <2 x i16> %pc, <2 x i16> %pd) {
 ; CHECK-LABEL: @bitcast_vec_cond_commute3(
 ; CHECK-NEXT:    [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
 ; CHECK-NEXT:    [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
-; CHECK-NEXT:    [[S:%.*]] = ashr <4 x i8> [[COND:%.*]], <i8 7, i8 7, i8 7, i8 7>
-; CHECK-NEXT:    [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
-; CHECK-NEXT:    [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
-; CHECK-NEXT:    [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
-; CHECK-NEXT:    [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
-; CHECK-NEXT:    [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
-; CHECK-NEXT:    ret <2 x i16> [[R]]
+; CHECK-NEXT:    [[DOTNOT:%.*]] = icmp sgt <4 x i8> [[COND:%.*]], <i8 -1, i8 -1, i8 -1, i8 -1>
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
+; CHECK-NEXT:    [[TMP3:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> [[TMP2]], <4 x i8> [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
+; CHECK-NEXT:    ret <2 x i16> [[TMP4]]
 ;
   %c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
   %d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
index a658b19..964307b 100644 (file)
@@ -68,15 +68,9 @@ define internal <2 x i64> @_mm_set_epi32(i32 %__i3, i32 %__i2, i32 %__i1, i32 %_
 define <2 x i64> @abs_v4i32(<2 x i64> %x) {
 ; CHECK-LABEL: @abs_v4i32(
 ; CHECK-NEXT:    [[T1_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[SUB_I:%.*]] = sub <4 x i32> zeroinitializer, [[T1_I]]
-; CHECK-NEXT:    [[T1_I_LOBIT:%.*]] = ashr <4 x i32> [[T1_I]], <i32 31, i32 31, i32 31, i32 31>
-; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i32> [[T1_I_LOBIT]] to <2 x i64>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = xor <2 x i64> [[TMP1]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I1:%.*]] = and <4 x i32> [[T1_I_LOBIT]], [[SUB_I]]
-; CHECK-NEXT:    [[AND_I_I:%.*]] = bitcast <4 x i32> [[AND_I_I1]] to <2 x i64>
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[T1_I]], i1 false)
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @_mm_set1_epi32(i32 -1)
   %call1 = call <2 x i64> @_mm_setzero_si128()
@@ -90,13 +84,9 @@ define <2 x i64> @max_v4i32(<2 x i64> %x, <2 x i64> %y) {
 ; CHECK-NEXT:    [[T0_I_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[T1_I_I:%.*]] = bitcast <2 x i64> [[Y:%.*]] to <4 x i32>
 ; CHECK-NEXT:    [[CMP_I_I:%.*]] = icmp sgt <4 x i32> [[T0_I_I]], [[T1_I_I]]
-; CHECK-NEXT:    [[SEXT_I_I:%.*]] = sext <4 x i1> [[CMP_I_I]] to <4 x i32>
-; CHECK-NEXT:    [[T2_I_I:%.*]] = bitcast <4 x i32> [[SEXT_I_I]] to <2 x i64>
-; CHECK-NEXT:    [[NEG_I_I:%.*]] = xor <2 x i64> [[T2_I_I]], <i64 -1, i64 -1>
-; CHECK-NEXT:    [[AND_I_I:%.*]] = and <2 x i64> [[NEG_I_I]], [[Y]]
-; CHECK-NEXT:    [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
-; CHECK-NEXT:    [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
-; CHECK-NEXT:    ret <2 x i64> [[OR_I_I]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select <4 x i1> [[CMP_I_I]], <4 x i32> [[T0_I_I]], <4 x i32> [[T1_I_I]]
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
+; CHECK-NEXT:    ret <2 x i64> [[TMP2]]
 ;
   %call = call <2 x i64> @cmpgt_i32_sel_m128i(<2 x i64> %x, <2 x i64> %y, <2 x i64> %y, <2 x i64> %x)
   ret <2 x i64> %call