[X86] LowerVectorAllEqual - generalize mask type generation. NFC.
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 4 Apr 2023 13:48:20 +0000 (14:48 +0100)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 4 Apr 2023 13:48:28 +0000 (14:48 +0100)
This will be necessary once we merge with combineVectorSizedSetCCEquality and we need to support KORTEST handling for 128/256-bit comparisons

llvm/lib/Target/X86/X86ISelLowering.cpp

index 0ffce00..8353cda 100644 (file)
@@ -24419,20 +24419,23 @@ static SDValue LowerVectorAllEqual(const SDLoc &DL, SDValue LHS, SDValue RHS,
   }
 
   if (UseKORTEST && VT.is512BitVector()) {
-    LHS = DAG.getBitcast(MVT::v16i32, MaskBits(LHS));
-    RHS = DAG.getBitcast(MVT::v16i32, MaskBits(RHS));
-    SDValue V = DAG.getSetCC(DL, MVT::v16i1, LHS, RHS, ISD::SETNE);
+    MVT TestVT = MVT::getVectorVT(MVT::i32, VT.getSizeInBits() / 32);
+    MVT BoolVT = TestVT.changeVectorElementType(MVT::i1);
+    LHS = DAG.getBitcast(TestVT, MaskBits(LHS));
+    RHS = DAG.getBitcast(TestVT, MaskBits(RHS));
+    SDValue V = DAG.getSetCC(DL, BoolVT, LHS, RHS, ISD::SETNE);
     return DAG.getNode(X86ISD::KORTEST, DL, MVT::i32, V, V);
   }
 
   if (UsePTEST) {
-    MVT TestVT = VT.is128BitVector() ? MVT::v2i64 : MVT::v4i64;
+    MVT TestVT = MVT::getVectorVT(MVT::i64, VT.getSizeInBits() / 64);
     LHS = DAG.getBitcast(TestVT, MaskBits(LHS));
     RHS = DAG.getBitcast(TestVT, MaskBits(RHS));
     SDValue V = DAG.getNode(ISD::XOR, DL, TestVT, LHS, RHS);
     return DAG.getNode(X86ISD::PTEST, DL, MVT::i32, V, V);
   }
 
+  assert(VT.getSizeInBits() == 128 && "Failure to split to 128-bits");
   MVT MaskVT = ScalarSize >= 32 ? MVT::v4i32 : MVT::v16i8;
   LHS = DAG.getBitcast(MaskVT, MaskBits(LHS));
   RHS = DAG.getBitcast(MaskVT, MaskBits(RHS));