AVX-512: Generate KTEST instead of TEST fir i1 vectors
authorElena Demikhovsky <elena.demikhovsky@intel.com>
Thu, 24 Mar 2016 15:53:45 +0000 (15:53 +0000)
committerElena Demikhovsky <elena.demikhovsky@intel.com>
Thu, 24 Mar 2016 15:53:45 +0000 (15:53 +0000)
KTEST instruction may be used instead of TEST in this case:

%int_sel3 = bitcast <8 x i1> %sel3 to i8
%res = icmp eq i8 %int_sel3, zeroinitializer
br i1 %res, label %L2, label %L1

Differential Revision: http://reviews.llvm.org/D18444

llvm-svn: 264298

llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/test/CodeGen/X86/avx512-mask-op.ll

index 9ed1e07..14c3fd7 100644 (file)
@@ -14519,6 +14519,24 @@ static bool hasNonFlagsUse(SDValue Op) {
   return false;
 }
 
+// Emit KTEST instruction for bit vectors on AVX-512
+static SDValue EmitKTEST(SDValue Op, SelectionDAG &DAG,
+                         const X86Subtarget &Subtarget) {
+  if (Op.getOpcode() == ISD::BITCAST) {
+    auto hasKTEST = [&](MVT VT) {
+      unsigned SizeInBits = VT.getSizeInBits();
+      return (Subtarget.hasDQI() && (SizeInBits == 8 || SizeInBits == 8)) ||
+        (Subtarget.hasBWI() && (SizeInBits == 32 || SizeInBits == 64));
+    };
+    SDValue Op0 = Op.getOperand(0);
+    MVT Op0VT = Op0.getValueType().getSimpleVT();
+    if (Op0VT.isVector() && Op0VT.getVectorElementType() == MVT::i1 &&
+        hasKTEST(Op0VT))
+      return DAG.getNode(X86ISD::KTEST, SDLoc(Op), Op0VT, Op0, Op0);
+  }
+  return SDValue();
+}
+
 /// Emit nodes that will be selected as "test Op0,Op0", or something
 /// equivalent.
 SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, SDLoc dl,
@@ -14564,10 +14582,10 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, SDLoc dl,
   // doing a separate TEST. TEST always sets OF and CF to 0, so unless
   // we prove that the arithmetic won't overflow, we can't use OF or CF.
   if (Op.getResNo() != 0 || NeedOF || NeedCF) {
+    // Emit KTEST for bit vectors
+    if (auto Node = EmitKTEST(Op, DAG, Subtarget))
+      return Node;
     // Emit a CMP with 0, which is the TEST pattern.
-    //if (Op.getValueType() == MVT::i1)
-    //  return DAG.getNode(X86ISD::CMP, dl, MVT::i1, Op,
-    //                     DAG.getConstant(0, MVT::i1));
     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
                        DAG.getConstant(0, dl, Op.getValueType()));
   }
@@ -14739,11 +14757,15 @@ SDValue X86TargetLowering::EmitTest(SDValue Op, unsigned X86CC, SDLoc dl,
     }
   }
 
-  if (Opcode == 0)
+  if (Opcode == 0) {
+    // Emit KTEST for bit vectors
+    if (auto Node = EmitKTEST(Op, DAG, Subtarget))
+      return Node;
+
     // Emit a CMP with 0, which is the TEST pattern.
     return DAG.getNode(X86ISD::CMP, dl, MVT::i32, Op,
                        DAG.getConstant(0, dl, Op.getValueType()));
-
+  }
   SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
   SmallVector<SDValue, 4> Ops(Op->op_begin(), Op->op_begin() + NumOperands);
 
index c2475be..cc21a17 100644 (file)
@@ -244,8 +244,7 @@ define void @test7(<8 x i1> %mask)  {
 ; SKX-NEXT:    movb $85, %al
 ; SKX-NEXT:    kmovb %eax, %k1
 ; SKX-NEXT:    korb %k1, %k0, %k0
-; SKX-NEXT:    kmovb %k0, %eax
-; SKX-NEXT:    testb %al, %al
+; SKX-NEXT:    ktestb %k0, %k0
 ; SKX-NEXT:    retq
 allocas:
   %a= or <8 x i1> %mask, <i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false>
@@ -1681,3 +1680,113 @@ define <64 x i8> @test_build_vec_v64i1(<64 x i8> %x) {
   %ret = select <64 x i1> <i1 false, i1 false, i1 true, i1 false, i1 false, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 false, i1 true, i1 false, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 true, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 true, i1 false, i1 true, i1 false, i1 false, i1 true, i1 true, i1 false, i1 true, i1 false>, <64 x i8> %x, <64 x i8> zeroinitializer
   ret <64 x i8> %ret
 }
+
+define void @ktest_1(<8 x double> %in, double * %base) {
+; KNL-LABEL: ktest_1:
+; KNL:       ## BB#0:
+; KNL-NEXT:    vmovupd (%rdi), %zmm1
+; KNL-NEXT:    vcmpltpd %zmm0, %zmm1, %k1
+; KNL-NEXT:    vmovupd 8(%rdi), %zmm1 {%k1} {z}
+; KNL-NEXT:    vcmpltpd %zmm1, %zmm0, %k0 {%k1}
+; KNL-NEXT:    kmovw %k0, %eax
+; KNL-NEXT:    testb %al, %al
+; KNL-NEXT:    je LBB38_2
+; KNL-NEXT:  ## BB#1: ## %L1
+; KNL-NEXT:    vmovapd %zmm0, (%rdi)
+; KNL-NEXT:    retq
+; KNL-NEXT:  LBB38_2: ## %L2
+; KNL-NEXT:    vmovapd %zmm0, 8(%rdi)
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: ktest_1:
+; SKX:       ## BB#0:
+; SKX-NEXT:    vmovupd (%rdi), %zmm1
+; SKX-NEXT:    vcmpltpd %zmm0, %zmm1, %k1
+; SKX-NEXT:    vmovupd 8(%rdi), %zmm1 {%k1} {z}
+; SKX-NEXT:    vcmpltpd %zmm1, %zmm0, %k0 {%k1}
+; SKX-NEXT:    ktestb %k0, %k0
+; SKX-NEXT:    je LBB38_2
+; SKX-NEXT:  ## BB#1: ## %L1
+; SKX-NEXT:    vmovapd %zmm0, (%rdi)
+; SKX-NEXT:    retq
+; SKX-NEXT:  LBB38_2: ## %L2
+; SKX-NEXT:    vmovapd %zmm0, 8(%rdi)
+; SKX-NEXT:    retq
+  %addr1 = getelementptr double, double * %base, i64 0
+  %addr2 = getelementptr double, double * %base, i64 1
+
+  %vaddr1 = bitcast double* %addr1 to <8 x double>*
+  %vaddr2 = bitcast double* %addr2 to <8 x double>*
+
+  %val1 = load <8 x double>, <8 x double> *%vaddr1, align 1
+  %val2 = load <8 x double>, <8 x double> *%vaddr2, align 1
+
+  %sel1 = fcmp ogt <8 x double>%in, %val1
+  %val3 = select <8 x i1> %sel1, <8 x double> %val2, <8 x double> zeroinitializer
+  %sel2 = fcmp olt <8 x double> %in, %val3
+  %sel3 = and <8 x i1> %sel1, %sel2
+
+  %int_sel3 = bitcast <8 x i1> %sel3 to i8
+  %res = icmp eq i8 %int_sel3, zeroinitializer
+  br i1 %res, label %L2, label %L1
+L1:
+  store <8 x double> %in, <8 x double>* %vaddr1
+  br label %End
+L2:
+  store <8 x double> %in, <8 x double>* %vaddr2
+  br label %End
+End:
+  ret void
+}
+
+define void @ktest_2(<32 x float> %in, float * %base) {
+;
+; SKX-LABEL: ktest_2:
+; SKX:       ## BB#0:
+; SKX-NEXT:    vmovups 64(%rdi), %zmm2
+; SKX-NEXT:    vmovups (%rdi), %zmm3
+; SKX-NEXT:    vcmpltps %zmm0, %zmm3, %k1
+; SKX-NEXT:    vcmpltps %zmm1, %zmm2, %k2
+; SKX-NEXT:    kunpckwd %k1, %k2, %k0
+; SKX-NEXT:    vmovups 68(%rdi), %zmm2 {%k2} {z}
+; SKX-NEXT:    vmovups 4(%rdi), %zmm3 {%k1} {z}
+; SKX-NEXT:    vcmpltps %zmm3, %zmm0, %k1
+; SKX-NEXT:    vcmpltps %zmm2, %zmm1, %k2
+; SKX-NEXT:    kunpckwd %k1, %k2, %k1
+; SKX-NEXT:    kord %k1, %k0, %k0
+; SKX-NEXT:    ktestd %k0, %k0
+; SKX-NEXT:    je LBB39_2
+; SKX-NEXT:  ## BB#1: ## %L1
+; SKX-NEXT:    vmovaps %zmm0, (%rdi)
+; SKX-NEXT:    vmovaps %zmm1, 64(%rdi)
+; SKX-NEXT:    retq
+; SKX-NEXT:  LBB39_2: ## %L2
+; SKX-NEXT:    vmovaps %zmm0, 4(%rdi)
+; SKX-NEXT:    vmovaps %zmm1, 68(%rdi)
+; SKX-NEXT:    retq
+  %addr1 = getelementptr float, float * %base, i64 0
+  %addr2 = getelementptr float, float * %base, i64 1
+
+  %vaddr1 = bitcast float* %addr1 to <32 x float>*
+  %vaddr2 = bitcast float* %addr2 to <32 x float>*
+
+  %val1 = load <32 x float>, <32 x float> *%vaddr1, align 1
+  %val2 = load <32 x float>, <32 x float> *%vaddr2, align 1
+
+  %sel1 = fcmp ogt <32 x float>%in, %val1
+  %val3 = select <32 x i1> %sel1, <32 x float> %val2, <32 x float> zeroinitializer
+  %sel2 = fcmp olt <32 x float> %in, %val3
+  %sel3 = or <32 x i1> %sel1, %sel2
+
+  %int_sel3 = bitcast <32 x i1> %sel3 to i32
+  %res = icmp eq i32 %int_sel3, zeroinitializer
+  br i1 %res, label %L2, label %L1
+L1:
+  store <32 x float> %in, <32 x float>* %vaddr1
+  br label %End
+L2:
+  store <32 x float> %in, <32 x float>* %vaddr2
+  br label %End
+End:
+  ret void
+}