[X86][AVX] Add support for and/or scalar bool reduction with AVX512 mask registers
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 1 Nov 2019 17:55:18 +0000 (17:55 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 1 Nov 2019 17:55:31 +0000 (17:55 +0000)
combineBitcastvxi1 only handles bitcast->MOVMSK combines, with mask registers we use BITCAST directly.

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/movmsk-cmp.ll

index 5efaa23..c7a45f6 100644 (file)
@@ -39220,9 +39220,12 @@ static SDValue combineAnd(SDNode *N, SelectionDAG &DAG,
     if (matchScalarReduction(SDValue(N, 0), ISD::AND, SrcOps) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       unsigned NumElts = SrcOps[0].getValueType().getVectorNumElements();
       EVT MaskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts);
       SDValue Mask = combineBitcastvxi1(DAG, MaskVT, SrcOps[0], dl, Subtarget);
+      if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
+        Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
         APInt AllBits = APInt::getAllOnesValue(NumElts);
         return DAG.getSetCC(dl, MVT::i1, Mask,
@@ -39758,9 +39761,12 @@ static SDValue combineOr(SDNode *N, SelectionDAG &DAG,
     if (matchScalarReduction(SDValue(N, 0), ISD::OR, SrcOps) &&
         SrcOps.size() == 1) {
       SDLoc dl(N);
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
       unsigned NumElts = SrcOps[0].getValueType().getVectorNumElements();
       EVT MaskVT = EVT::getIntegerVT(*DAG.getContext(), NumElts);
       SDValue Mask = combineBitcastvxi1(DAG, MaskVT, SrcOps[0], dl, Subtarget);
+      if (!Mask && TLI.isTypeLegal(SrcOps[0].getValueType()))
+        Mask = DAG.getBitcast(MaskVT, SrcOps[0]);
       if (Mask) {
         APInt AllBits = APInt::getNullValue(NumElts);
         return DAG.getSetCC(dl, MVT::i1, Mask,
index c850d9d..97b6929 100644 (file)
@@ -4467,22 +4467,19 @@ define i1 @movmsk_and_v2i64(<2 x i64> %x, <2 x i64> %y) {
 ; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
 ; KNL-NEXT:    vpcmpneqq %zmm1, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $1, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
 ; KNL-NEXT:    kmovw %k0, %eax
-; KNL-NEXT:    andb %cl, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    andb $3, %al
+; KNL-NEXT:    cmpb $3, %al
+; KNL-NEXT:    sete %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_and_v2i64:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpcmpneqq %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $1, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
 ; SKX-NEXT:    kmovd %k0, %eax
-; SKX-NEXT:    andb %cl, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    cmpb $3, %al
+; SKX-NEXT:    sete %al
 ; SKX-NEXT:    retq
   %cmp = icmp ne <2 x i64> %x, %y
   %e1 = extractelement <2 x i1> %cmp, i32 0
@@ -4515,22 +4512,17 @@ define i1 @movmsk_or_v2i64(<2 x i64> %x, <2 x i64> %y) {
 ; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
 ; KNL-NEXT:    vpcmpneqq %zmm1, %zmm0, %k0
-; KNL-NEXT:    kshiftrw $1, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
 ; KNL-NEXT:    kmovw %k0, %eax
-; KNL-NEXT:    orb %cl, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    testb $3, %al
+; KNL-NEXT:    setne %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_or_v2i64:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vpcmpneqq %xmm1, %xmm0, %k0
-; SKX-NEXT:    kshiftrb $1, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kmovd %k0, %eax
-; SKX-NEXT:    orb %cl, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kortestb %k0, %k0
+; SKX-NEXT:    setne %al
 ; SKX-NEXT:    retq
   %cmp = icmp ne <2 x i64> %x, %y
   %e1 = extractelement <2 x i1> %cmp, i32 0
@@ -4620,22 +4612,19 @@ define i1 @movmsk_and_v2f64(<2 x double> %x, <2 x double> %y) {
 ; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
 ; KNL-NEXT:    vcmplepd %zmm0, %zmm1, %k0
-; KNL-NEXT:    kshiftrw $1, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
 ; KNL-NEXT:    kmovw %k0, %eax
-; KNL-NEXT:    andb %cl, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    andb $3, %al
+; KNL-NEXT:    cmpb $3, %al
+; KNL-NEXT:    sete %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_and_v2f64:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vcmplepd %xmm0, %xmm1, %k0
-; SKX-NEXT:    kshiftrb $1, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
 ; SKX-NEXT:    kmovd %k0, %eax
-; SKX-NEXT:    andb %cl, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    cmpb $3, %al
+; SKX-NEXT:    sete %al
 ; SKX-NEXT:    retq
   %cmp = fcmp oge <2 x double> %x, %y
   %e1 = extractelement <2 x i1> %cmp, i32 0
@@ -4666,22 +4655,17 @@ define i1 @movmsk_or_v2f64(<2 x double> %x, <2 x double> %y) {
 ; KNL-NEXT:    # kill: def $xmm1 killed $xmm1 def $zmm1
 ; KNL-NEXT:    # kill: def $xmm0 killed $xmm0 def $zmm0
 ; KNL-NEXT:    vcmplepd %zmm0, %zmm1, %k0
-; KNL-NEXT:    kshiftrw $1, %k0, %k1
-; KNL-NEXT:    kmovw %k1, %ecx
 ; KNL-NEXT:    kmovw %k0, %eax
-; KNL-NEXT:    orb %cl, %al
-; KNL-NEXT:    # kill: def $al killed $al killed $eax
+; KNL-NEXT:    testb $3, %al
+; KNL-NEXT:    setne %al
 ; KNL-NEXT:    vzeroupper
 ; KNL-NEXT:    retq
 ;
 ; SKX-LABEL: movmsk_or_v2f64:
 ; SKX:       # %bb.0:
 ; SKX-NEXT:    vcmplepd %xmm0, %xmm1, %k0
-; SKX-NEXT:    kshiftrb $1, %k0, %k1
-; SKX-NEXT:    kmovd %k1, %ecx
-; SKX-NEXT:    kmovd %k0, %eax
-; SKX-NEXT:    orb %cl, %al
-; SKX-NEXT:    # kill: def $al killed $al killed $eax
+; SKX-NEXT:    kortestb %k0, %k0
+; SKX-NEXT:    setne %al
 ; SKX-NEXT:    retq
   %cmp = fcmp oge <2 x double> %x, %y
   %e1 = extractelement <2 x i1> %cmp, i32 0