[X86][AVX512] Fold extract_element(bitcast(<X x i1>) -> bitcast(extract_subvector())
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 23 Oct 2022 13:47:19 +0000 (14:47 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Sun, 23 Oct 2022 13:47:24 +0000 (14:47 +0100)
On AVX512, extract legal bool vectors as bool subvectors before bitcasting to scalars to avoid spilling to stack.

This helps rust which internally represents bool vectors as bool arrays

It also exposes more missed opportunities to use the KADD instruction to add masks together before moving to gpr

Fixes #58546

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/bitcast-vector-bool.ll

index 71e643e..aafbe7b 100644 (file)
@@ -44324,6 +44324,8 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
   SDLoc dl(InputVector);
   bool IsPextr = N->getOpcode() != ISD::EXTRACT_VECTOR_ELT;
   unsigned NumSrcElts = SrcVT.getVectorNumElements();
+  unsigned NumEltBits = VT.getScalarSizeInBits();
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   if (CIdx && CIdx->getAPIntValue().uge(NumSrcElts))
     return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
@@ -44338,15 +44340,26 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
       uint64_t Idx = CIdx->getZExtValue();
       if (UndefVecElts[Idx])
         return IsPextr ? DAG.getConstant(0, dl, VT) : DAG.getUNDEF(VT);
-      return DAG.getConstant(EltBits[Idx].zext(VT.getScalarSizeInBits()), dl,
-                             VT);
+      return DAG.getConstant(EltBits[Idx].zext(NumEltBits), dl, VT);
+    }
+
+    // Convert extract_element(bitcast(<X x i1>) -> bitcast(extract_subvector()).
+    // Improves lowering of bool masks on rust which splits them into byte array.
+    if (InputVector.getOpcode() == ISD::BITCAST && (NumEltBits % 8) == 0) {
+      SDValue Src = peekThroughBitcasts(InputVector);
+      if (Src.getValueType().getScalarType() == MVT::i1 &&
+          TLI.isTypeLegal(Src.getValueType())) {
+        MVT SubVT = MVT::getVectorVT(MVT::i1, NumEltBits);
+        SDValue Sub = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, SubVT, Src,
+            DAG.getIntPtrConstant(CIdx->getZExtValue() * NumEltBits, dl));
+        return DAG.getBitcast(VT, Sub);
+      }
     }
   }
 
   if (IsPextr) {
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    if (TLI.SimplifyDemandedBits(SDValue(N, 0),
-                                 APInt::getAllOnes(VT.getSizeInBits()), DCI))
+    if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
+                                 DCI))
       return SDValue(N, 0);
 
     // PEXTR*(PINSR*(v, s, c), c) -> s (with implicit zext handling).
index 94901e6..de132c1 100644 (file)
@@ -123,14 +123,24 @@ define i8 @bitcast_v16i8_to_v2i8(<16 x i8> %a0) nounwind {
 ; SSE2-SSSE3-NEXT:    addb -{{[0-9]+}}(%rsp), %al
 ; SSE2-SSSE3-NEXT:    retq
 ;
-; AVX-LABEL: bitcast_v16i8_to_v2i8:
-; AVX:       # %bb.0:
-; AVX-NEXT:    vpmovmskb %xmm0, %ecx
-; AVX-NEXT:    movl %ecx, %eax
-; AVX-NEXT:    shrl $8, %eax
-; AVX-NEXT:    addb %cl, %al
-; AVX-NEXT:    # kill: def $al killed $al killed $eax
-; AVX-NEXT:    retq
+; AVX12-LABEL: bitcast_v16i8_to_v2i8:
+; AVX12:       # %bb.0:
+; AVX12-NEXT:    vpmovmskb %xmm0, %ecx
+; AVX12-NEXT:    movl %ecx, %eax
+; AVX12-NEXT:    shrl $8, %eax
+; AVX12-NEXT:    addb %cl, %al
+; AVX12-NEXT:    # kill: def $al killed $al killed $eax
+; AVX12-NEXT:    retq
+;
+; AVX512-LABEL: bitcast_v16i8_to_v2i8:
+; AVX512:       # %bb.0:
+; AVX512-NEXT:    vpmovb2m %xmm0, %k0
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
+; AVX512-NEXT:    addb %cl, %al
+; AVX512-NEXT:    # kill: def $al killed $al killed $eax
+; AVX512-NEXT:    retq
   %1 = icmp slt <16 x i8> %a0, zeroinitializer
   %2 = bitcast <16 x i1> %1 to <2 x i8>
   %3 = extractelement <2 x i8> %2, i32 0
@@ -242,10 +252,9 @@ define i8 @bitcast_v16i16_to_v2i8(<16 x i16> %a0) nounwind {
 ; AVX512-LABEL: bitcast_v16i16_to_v2i8:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovw2m %ymm0, %k0
-; AVX512-NEXT:    kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %ecx
-; AVX512-NEXT:    vpextrb $1, %xmm0, %eax
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addb %cl, %al
 ; AVX512-NEXT:    # kill: def $al killed $al killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -289,9 +298,10 @@ define i16 @bitcast_v32i8_to_v2i16(<32 x i8> %a0) nounwind {
 ;
 ; AVX512-LABEL: bitcast_v32i8_to_v2i16:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vpmovmskb %ymm0, %ecx
-; AVX512-NEXT:    movl %ecx, %eax
-; AVX512-NEXT:    shrl $16, %eax
+; AVX512-NEXT:    vpmovb2m %ymm0, %k0
+; AVX512-NEXT:    kshiftrd $16, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addl %ecx, %eax
 ; AVX512-NEXT:    # kill: def $ax killed $ax killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -424,10 +434,9 @@ define i8 @bitcast_v16i32_to_v2i8(<16 x i32> %a0) nounwind {
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX512-NEXT:    vpcmpgtd %zmm0, %zmm1, %k0
-; AVX512-NEXT:    kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %ecx
-; AVX512-NEXT:    vpextrb $1, %xmm0, %eax
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addb %cl, %al
 ; AVX512-NEXT:    # kill: def $al killed $al killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -479,10 +488,9 @@ define i16 @bitcast_v32i16_to_v2i16(<32 x i16> %a0) nounwind {
 ; AVX512-LABEL: bitcast_v32i16_to_v2i16:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovw2m %zmm0, %k0
-; AVX512-NEXT:    kmovd %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %ecx
-; AVX512-NEXT:    vpextrw $1, %xmm0, %eax
+; AVX512-NEXT:    kshiftrd $16, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
 ; AVX512-NEXT:    addl %ecx, %eax
 ; AVX512-NEXT:    # kill: def $ax killed $ax killed $eax
 ; AVX512-NEXT:    vzeroupper
@@ -541,9 +549,10 @@ define i32 @bitcast_v64i8_to_v2i32(<64 x i8> %a0) nounwind {
 ; AVX512-LABEL: bitcast_v64i8_to_v2i32:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovb2m %zmm0, %k0
-; AVX512-NEXT:    kmovq %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    movl -{{[0-9]+}}(%rsp), %eax
-; AVX512-NEXT:    addl -{{[0-9]+}}(%rsp), %eax
+; AVX512-NEXT:    kshiftrq $32, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %ecx
+; AVX512-NEXT:    kmovd %k1, %eax
+; AVX512-NEXT:    addl %ecx, %eax
 ; AVX512-NEXT:    vzeroupper
 ; AVX512-NEXT:    retq
   %1 = icmp slt <64 x i8> %a0, zeroinitializer
@@ -698,10 +707,9 @@ define [2 x i8] @PR58546(<16 x float> %a0) {
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vxorps %xmm1, %xmm1, %xmm1
 ; AVX512-NEXT:    vcmpunordps %zmm1, %zmm0, %k0
-; AVX512-NEXT:    kmovw %k0, -{{[0-9]+}}(%rsp)
-; AVX512-NEXT:    vmovdqa -{{[0-9]+}}(%rsp), %xmm0
-; AVX512-NEXT:    vmovd %xmm0, %eax
-; AVX512-NEXT:    vpextrb $1, %xmm0, %edx
+; AVX512-NEXT:    kshiftrw $8, %k0, %k1
+; AVX512-NEXT:    kmovd %k0, %eax
+; AVX512-NEXT:    kmovd %k1, %edx
 ; AVX512-NEXT:    # kill: def $al killed $al killed $eax
 ; AVX512-NEXT:    # kill: def $dl killed $dl killed $edx
 ; AVX512-NEXT:    vzeroupper