[InstCombine] try to convert x86 movmsk intrinsic to generic IR (PR39927)
authorSanjay Patel <spatel@rotateright.com>
Tue, 11 Dec 2018 16:38:03 +0000 (16:38 +0000)
committerSanjay Patel <spatel@rotateright.com>
Tue, 11 Dec 2018 16:38:03 +0000 (16:38 +0000)
call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM

This has the potential to create less-than-8-bit scalar types as shown in
some of the test diffs, but it looks like the backend knows how to deal
with that in these patterns. This is the simple part of the fix suggested in:
https://bugs.llvm.org/show_bug.cgi?id=39927

Differential Revision: https://reviews.llvm.org/D55529

llvm-svn: 348862

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
llvm/test/Transforms/InstCombine/X86/x86-movmsk.ll

index a023b6d..ae158ae 100644 (file)
@@ -736,7 +736,8 @@ static Value *simplifyX86round(IntrinsicInst &II,
   return Builder.CreateInsertElement(Dst, Res, (uint64_t)0);
 }
 
-static Value *simplifyX86movmsk(const IntrinsicInst &II) {
+static Value *simplifyX86movmsk(const IntrinsicInst &II,
+                                InstCombiner::BuilderTy &Builder) {
   Value *Arg = II.getArgOperand(0);
   Type *ResTy = II.getType();
   Type *ArgTy = Arg->getType();
@@ -749,29 +750,46 @@ static Value *simplifyX86movmsk(const IntrinsicInst &II) {
   if (!ArgTy->isVectorTy())
     return nullptr;
 
-  auto *C = dyn_cast<Constant>(Arg);
-  if (!C)
-    return nullptr;
+  if (auto *C = dyn_cast<Constant>(Arg)) {
+    // Extract signbits of the vector input and pack into integer result.
+    APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
+    for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
+      auto *COp = C->getAggregateElement(I);
+      if (!COp)
+        return nullptr;
+      if (isa<UndefValue>(COp))
+        continue;
 
-  // Extract signbits of the vector input and pack into integer result.
-  APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
-  for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
-    auto *COp = C->getAggregateElement(I);
-    if (!COp)
-      return nullptr;
-    if (isa<UndefValue>(COp))
-      continue;
+      auto *CInt = dyn_cast<ConstantInt>(COp);
+      auto *CFp = dyn_cast<ConstantFP>(COp);
+      if (!CInt && !CFp)
+        return nullptr;
 
-    auto *CInt = dyn_cast<ConstantInt>(COp);
-    auto *CFp = dyn_cast<ConstantFP>(COp);
-    if (!CInt && !CFp)
-      return nullptr;
+      if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
+        Result.setBit(I);
+    }
+    return Constant::getIntegerValue(ResTy, Result);
+  }
 
-    if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
-      Result.setBit(I);
+  // Look for a sign-extended boolean source vector as the argument to this
+  // movmsk. If the argument is bitcast, look through that, but make sure the
+  // source of that bitcast is still a vector with the same number of elements.
+  // TODO: We can also convert a bitcast with wider elements, but that requires
+  // duplicating the bool source sign bits to match the number of elements
+  // expected by the movmsk call.
+  Arg = peekThroughBitcast(Arg);
+  Value *X;
+  if (Arg->getType()->isVectorTy() &&
+      Arg->getType()->getVectorNumElements() == ArgTy->getVectorNumElements() &&
+      match(Arg, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+    // call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM
+    unsigned NumElts = X->getType()->getVectorNumElements();
+    Type *ScalarTy = Type::getIntNTy(Arg->getContext(), NumElts);
+    Value *BC = Builder.CreateBitCast(X, ScalarTy);
+    return Builder.CreateZExtOrTrunc(BC, ResTy);
   }
 
-  return Constant::getIntegerValue(ResTy, Result);
+  return nullptr;
 }
 
 static Value *simplifyX86insertps(const IntrinsicInst &II,
@@ -2543,7 +2561,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   case Intrinsic::x86_avx_movmsk_pd_256:
   case Intrinsic::x86_avx_movmsk_ps_256:
   case Intrinsic::x86_avx2_pmovmskb:
-    if (Value *V = simplifyX86movmsk(*II))
+    if (Value *V = simplifyX86movmsk(*II, Builder))
       return replaceInstUsesWith(*II, V);
     break;
 
index 15ff405..ff323b0 100644 (file)
@@ -315,10 +315,9 @@ define i32 @fold_x86_avx2_pmovmskb() {
 
 define i32 @sext_sse_movmsk_ps(<4 x i1> %x) {
 ; CHECK-LABEL: @sext_sse_movmsk_ps(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i32>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x i32> [[SEXT]] to <4 x float>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i4 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <4 x i1> %x to <4 x i32>
   %bc = bitcast <4 x i32> %sext to <4 x float>
@@ -328,10 +327,9 @@ define i32 @sext_sse_movmsk_ps(<4 x i1> %x) {
 
 define i32 @sext_sse2_movmsk_pd(<2 x i1> %x) {
 ; CHECK-LABEL: @sext_sse2_movmsk_pd(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i64>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[SEXT]] to <2 x double>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse2.movmsk.pd(<2 x double> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <2 x i1> [[X:%.*]] to i2
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i2 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <2 x i1> %x to <2 x i64>
   %bc = bitcast <2 x i64> %sext to <2 x double>
@@ -341,9 +339,9 @@ define i32 @sext_sse2_movmsk_pd(<2 x i1> %x) {
 
 define i32 @sext_sse2_pmovmskb_128(<16 x i1> %x) {
 ; CHECK-LABEL: @sext_sse2_pmovmskb_128(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <16 x i1> [[X:%.*]] to <16 x i8>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <16 x i1> [[X:%.*]] to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <16 x i1> %x to <16 x i8>
   %r = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> %sext)
@@ -352,10 +350,9 @@ define i32 @sext_sse2_pmovmskb_128(<16 x i1> %x) {
 
 define i32 @sext_avx_movmsk_ps_256(<8 x i1> %x) {
 ; CHECK-LABEL: @sext_avx_movmsk_ps_256(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <8 x i1> [[X:%.*]] to <8 x i32>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i32> [[SEXT]] to <8 x float>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx.movmsk.ps.256(<8 x float> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <8 x i1> [[X:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i8 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <8 x i1> %x to <8 x i32>
   %bc = bitcast <8 x i32> %sext to <8 x float>
@@ -365,10 +362,9 @@ define i32 @sext_avx_movmsk_ps_256(<8 x i1> %x) {
 
 define i32 @sext_avx_movmsk_pd_256(<4 x i1> %x) {
 ; CHECK-LABEL: @sext_avx_movmsk_pd_256(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <4 x i1> [[X:%.*]] to <4 x i64>
-; CHECK-NEXT:    [[BC:%.*]] = bitcast <4 x i64> [[SEXT]] to <4 x double>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx.movmsk.pd.256(<4 x double> [[BC]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <4 x i1> [[X:%.*]] to i4
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i4 [[TMP1]] to i32
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
   %sext = sext <4 x i1> %x to <4 x i64>
   %bc = bitcast <4 x i64> %sext to <4 x double>
@@ -378,15 +374,60 @@ define i32 @sext_avx_movmsk_pd_256(<4 x i1> %x) {
 
 define i32 @sext_avx2_pmovmskb(<32 x i1> %x) {
 ; CHECK-LABEL: @sext_avx2_pmovmskb(
-; CHECK-NEXT:    [[SEXT:%.*]] = sext <32 x i1> [[X:%.*]] to <32 x i8>
-; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.avx2.pmovmskb(<32 x i8> [[SEXT]])
-; CHECK-NEXT:    ret i32 [[R]]
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <32 x i1> [[X:%.*]] to i32
+; CHECK-NEXT:    ret i32 [[TMP1]]
 ;
   %sext = sext <32 x i1> %x to <32 x i8>
   %r = call i32 @llvm.x86.avx2.pmovmskb(<32 x i8> %sext)
   ret i32 %r
 }
 
+; Negative test - bitcast from scalar.
+
+define i32 @sext_sse_movmsk_ps_scalar_source(i1 %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_scalar_source(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext i1 [[X:%.*]] to i128
+; CHECK-NEXT:    [[BC:%.*]] = bitcast i128 [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext i1 %x to i128
+  %bc = bitcast i128 %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
+; Negative test - bitcast from vector type with more elements.
+
+define i32 @sext_sse_movmsk_ps_too_many_elts(<8 x i1> %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_too_many_elts(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext <8 x i1> [[X:%.*]] to <8 x i16>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <8 x i16> [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext <8 x i1> %x to <8 x i16>
+  %bc = bitcast <8 x i16> %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
+; TODO: We could handle this by doing a bitcasted sign-bit test after the sext?
+; But need to make sure the backend handles that correctly.
+
+define i32 @sext_sse_movmsk_ps_must_replicate_bits(<2 x i1> %x) {
+; CHECK-LABEL: @sext_sse_movmsk_ps_must_replicate_bits(
+; CHECK-NEXT:    [[SEXT:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i64>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <2 x i64> [[SEXT]] to <4 x float>
+; CHECK-NEXT:    [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %sext = sext <2 x i1> %x to <2 x i64>
+  %bc = bitcast <2 x i64> %sext to <4 x float>
+  %r = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %bc)
+  ret i32 %r
+}
+
 declare i32 @llvm.x86.mmx.pmovmskb(x86_mmx)
 
 declare i32 @llvm.x86.sse.movmsk.ps(<4 x float>)