From b4f1bfa65982804d0e34beffea2753783c9878c2 Mon Sep 17 00:00:00 2001
From: Simon Pilgrim
Date: Mon, 8 Apr 2019 13:17:51 +0000
Subject: [PATCH] [InstCombine][X86] Expand MOVMSK to generic IR (PR39927)

First step towards removing the MOVMSK intrinsics completely - this patch
expands MOVMSK to the pattern:

e.g. PMOVMSKB(v16i8 x):
  %cmp = icmp slt <16 x i8> %x, zeroinitializer
  %int = bitcast <16 x i1> %cmp to i16
  %res = zext i16 %int to i32

Which is correctly handled by ISel and FastISel (give or take an annoying
movzx move...):
https://godbolt.org/z/rkrSFW

Differential Revision: https://reviews.llvm.org/D60256

llvm-svn: 357909
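For reference, a minimal self-contained sketch of the same expansion written
as a free-standing IRBuilder helper (not part of this patch; the name
emitMovmskExpansion is hypothetical, but the body mirrors the code added to
simplifyX86movmsk below, assuming Arg is the intrinsic's vector operand and
ResTy its integer result type):

  // Assumes: #include "llvm/IR/IRBuilder.h" and using namespace llvm.
  static Value *emitMovmskExpansion(IRBuilder<> &Builder, Value *Arg,
                                    Type *ResTy) {
    auto *ArgTy = cast<VectorType>(Arg->getType());
    unsigned NumElts = ArgTy->getNumElements();
    // Reinterpret float vectors as integer vectors so the per-lane sign
    // bit can be tested with a signed compare against zero.
    Type *IntVecTy = VectorType::getInteger(ArgTy);
    Value *V = Builder.CreateBitCast(Arg, IntVecTy);
    // Lanes with the sign bit set compare slt zero, producing an <N x i1>.
    V = Builder.CreateICmpSLT(V, Constant::getNullValue(IntVecTy));
    // Pack the N bools into an iN mask, then widen to the result type.
    V = Builder.CreateBitCast(V, Builder.getIntNTy(NumElts));
    return Builder.CreateZExtOrTrunc(V, ResTy);
  }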
---
 .../Transforms/InstCombine/InstCombineCalls.cpp    | 54 +++++-------------
 llvm/test/Transforms/InstCombine/X86/x86-movmsk.ll | 65 ++++++++++++++--------
 2 files changed, 56 insertions(+), 63 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index afcd878..08f4be9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -710,46 +710,20 @@ static Value *simplifyX86movmsk(const IntrinsicInst &II,
   if (!ArgTy->isVectorTy())
     return nullptr;
 
-  if (auto *C = dyn_cast<Constant>(Arg)) {
-    // Extract signbits of the vector input and pack into integer result.
-    APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
-    for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
-      auto *COp = C->getAggregateElement(I);
-      if (!COp)
-        return nullptr;
-      if (isa<UndefValue>(COp))
-        continue;
-
-      auto *CInt = dyn_cast<ConstantInt>(COp);
-      auto *CFp = dyn_cast<ConstantFP>(COp);
-      if (!CInt && !CFp)
-        return nullptr;
-
-      if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
-        Result.setBit(I);
-    }
-    return Constant::getIntegerValue(ResTy, Result);
-  }
-
-  // Look for a sign-extended boolean source vector as the argument to this
-  // movmsk. If the argument is bitcast, look through that, but make sure the
-  // source of that bitcast is still a vector with the same number of elements.
-  // TODO: We can also convert a bitcast with wider elements, but that requires
-  // duplicating the bool source sign bits to match the number of elements
-  // expected by the movmsk call.
-  Arg = peekThroughBitcast(Arg);
-  Value *X;
-  if (Arg->getType()->isVectorTy() &&
-      Arg->getType()->getVectorNumElements() == ArgTy->getVectorNumElements() &&
-      match(Arg, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
-    // call iM movmsk(sext X) --> zext (bitcast X to iN) to iM
-    unsigned NumElts = X->getType()->getVectorNumElements();
-    Type *ScalarTy = Type::getIntNTy(Arg->getContext(), NumElts);
-    Value *BC = Builder.CreateBitCast(X, ScalarTy);
-    return Builder.CreateZExtOrTrunc(BC, ResTy);
-  }
-
-  return nullptr;
+  // Expand MOVMSK to compare/bitcast/zext:
+  // e.g. PMOVMSKB(v16i8 x):
+  // %cmp = icmp slt <16 x i8> %x, zeroinitializer
+  // %int = bitcast <16 x i1> %cmp to i16
+  // %res = zext i16 %int to i32
+  unsigned NumElts = ArgTy->getVectorNumElements();
+  Type *IntegerVecTy = VectorType::getInteger(cast<VectorType>(ArgTy));
+  Type *IntegerTy = Builder.getIntNTy(NumElts);
+
+  Value *Res = Builder.CreateBitCast(Arg, IntegerVecTy);
+  Res = Builder.CreateICmpSLT(Res, Constant::getNullValue(IntegerVecTy));
+  Res = Builder.CreateBitCast(Res, IntegerTy);
+  Res = Builder.CreateZExtOrTrunc(Res, ResTy);
+  return Res;
 }
 
 static Value *simplifyX86addcarry(const IntrinsicInst &II,
diff --git a/llvm/test/Transforms/InstCombine/X86/x86-movmsk.ll b/llvm/test/Transforms/InstCombine/X86/x86-movmsk.ll
index ff323b0..7be8f08 100644
--- a/llvm/test/Transforms/InstCombine/X86/x86-movmsk.ll
+++ b/llvm/test/Transforms/InstCombine/X86/x86-movmsk.ll
@@ -19,8 +19,11 @@ define i32 @test_upper_x86_mmx_pmovmskb(x86_mmx %a0) {
 
 define i32 @test_upper_x86_sse_movmsk_ps(<4 x float> %a0) {
 ; CHECK-LABEL: @test_upper_x86_sse_movmsk_ps(
-; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[A0:%.*]])
-; CHECK-NEXT: ret i32 [[TMP1]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x float> [[A0:%.*]] to <4 x i32>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <4 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT: [[TMP4:%.*]] = zext i4 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %1 = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> %a0)
   %2 = and i32 %1, 15
@@ -29,8 +32,11 @@ define i32 @test_upper_x86_sse_movmsk_ps(<4 x float> %a0) {
 
 define i32 @test_upper_x86_sse2_movmsk_pd(<2 x double> %a0) {
 ; CHECK-LABEL: @test_upper_x86_sse2_movmsk_pd(
-; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.x86.sse2.movmsk.pd(<2 x double> [[A0:%.*]])
-; CHECK-NEXT: ret i32 [[TMP1]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x double> [[A0:%.*]] to <2 x i64>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <2 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <2 x i1> [[TMP2]] to i2
+; CHECK-NEXT: [[TMP4:%.*]] = zext i2 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %1 = call i32 @llvm.x86.sse2.movmsk.pd(<2 x double> %a0)
   %2 = and i32 %1, 3
@@ -39,8 +45,10 @@ define i32 @test_upper_x86_sse2_movmsk_pd(<2 x double> %a0) {
 
 define i32 @test_upper_x86_sse2_pmovmskb_128(<16 x i8> %a0) {
 ; CHECK-LABEL: @test_upper_x86_sse2_pmovmskb_128(
-; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> [[A0:%.*]])
-; CHECK-NEXT: ret i32 [[TMP1]]
+; CHECK-NEXT: [[TMP1:%.*]] = icmp slt <16 x i8> [[A0:%.*]], zeroinitializer
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast <16 x i1> [[TMP1]] to i16
+; CHECK-NEXT: [[TMP3:%.*]] = zext i16 [[TMP2]] to i32
+; CHECK-NEXT: ret i32 [[TMP3]]
 ;
   %1 = call i32 @llvm.x86.sse2.pmovmskb.128(<16 x i8> %a0)
   %2 = and i32 %1, 65535
@@ -49,8 +57,11 @@ define i32 @test_upper_x86_sse2_pmovmskb_128(<16 x i8> %a0) {
 
 define i32 @test_upper_x86_avx_movmsk_ps_256(<8 x float> %a0) {
 ; CHECK-LABEL: @test_upper_x86_avx_movmsk_ps_256(
-; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.x86.avx.movmsk.ps.256(<8 x float> [[A0:%.*]])
-; CHECK-NEXT: ret i32 [[TMP1]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x float> [[A0:%.*]] to <8 x i32>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <8 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <8 x i1> [[TMP2]] to i8
+; CHECK-NEXT: [[TMP4:%.*]] = zext i8 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %1 = call i32 @llvm.x86.avx.movmsk.ps.256(<8 x float> %a0)
   %2 = and i32 %1, 255
@@ -59,8 +70,11 @@ define i32 @test_upper_x86_avx_movmsk_ps_256(<8 x float> %a0) {
 
 define i32 @test_upper_x86_avx_movmsk_pd_256(<4 x double> %a0) {
 ; CHECK-LABEL: @test_upper_x86_avx_movmsk_pd_256(
-; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.x86.avx.movmsk.pd.256(<4 x double> [[A0:%.*]])
-; CHECK-NEXT: ret i32 [[TMP1]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x double> [[A0:%.*]] to <4 x i64>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <4 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT: [[TMP4:%.*]] = zext i4 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %1 = call i32 @llvm.x86.avx.movmsk.pd.256(<4 x double> %a0)
   %2 = and i32 %1, 15
@@ -382,14 +396,16 @@ define i32 @sext_avx2_pmovmskb(<32 x i1> %x) {
   ret i32 %r
 }
 
-; Negative test - bitcast from scalar.
+; Bitcast from sign-extended scalar.
 
 define i32 @sext_sse_movmsk_ps_scalar_source(i1 %x) {
 ; CHECK-LABEL: @sext_sse_movmsk_ps_scalar_source(
 ; CHECK-NEXT: [[SEXT:%.*]] = sext i1 [[X:%.*]] to i128
-; CHECK-NEXT: [[BC:%.*]] = bitcast i128 [[SEXT]] to <4 x float>
-; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast i128 [[SEXT]] to <4 x i32>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <4 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT: [[TMP4:%.*]] = zext i4 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %sext = sext i1 %x to i128
   %bc = bitcast i128 %sext to <4 x float>
@@ -397,14 +413,16 @@ define i32 @sext_sse_movmsk_ps_scalar_source(i1 %x) {
   ret i32 %r
 }
 
-; Negative test - bitcast from vector type with more elements.
+; Bitcast from vector type with more elements.
 
 define i32 @sext_sse_movmsk_ps_too_many_elts(<8 x i1> %x) {
 ; CHECK-LABEL: @sext_sse_movmsk_ps_too_many_elts(
 ; CHECK-NEXT: [[SEXT:%.*]] = sext <8 x i1> [[X:%.*]] to <8 x i16>
-; CHECK-NEXT: [[BC:%.*]] = bitcast <8 x i16> [[SEXT]] to <4 x float>
-; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i16> [[SEXT]] to <4 x i32>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <4 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT: [[TMP4:%.*]] = zext i4 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %sext = sext <8 x i1> %x to <8 x i16>
   %bc = bitcast <8 x i16> %sext to <4 x float>
@@ -412,15 +430,16 @@ define i32 @sext_sse_movmsk_ps_too_many_elts(<8 x i1> %x) {
   ret i32 %r
 }
 
-; TODO: We could handle this by doing a bitcasted sign-bit test after the sext?
-; But need to make sure the backend handles that correctly.
+; Handle this by doing a bitcasted sign-bit test after the sext.
 
 define i32 @sext_sse_movmsk_ps_must_replicate_bits(<2 x i1> %x) {
 ; CHECK-LABEL: @sext_sse_movmsk_ps_must_replicate_bits(
 ; CHECK-NEXT: [[SEXT:%.*]] = sext <2 x i1> [[X:%.*]] to <2 x i64>
-; CHECK-NEXT: [[BC:%.*]] = bitcast <2 x i64> [[SEXT]] to <4 x float>
-; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.x86.sse.movmsk.ps(<4 x float> [[BC]])
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> [[SEXT]] to <4 x i32>
+; CHECK-NEXT: [[TMP2:%.*]] = icmp slt <4 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = bitcast <4 x i1> [[TMP2]] to i4
+; CHECK-NEXT: [[TMP4:%.*]] = zext i4 [[TMP3]] to i32
+; CHECK-NEXT: ret i32 [[TMP4]]
 ;
   %sext = sext <2 x i1> %x to <2 x i64>
   %bc = bitcast <2 x i64> %sext to <4 x float>
-- 
2.7.4
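
As a quick semantic sanity check of the expansion (not part of the patch;
the file name movmsk_check.cpp is hypothetical and an SSE2-capable x86 host
is assumed), the hardware PMOVMSKB result can be compared against a scalar
model of the icmp slt + bitcast + zext pattern:

  // movmsk_check.cpp - bit I of the mask should be set iff lane I < 0,
  // which is exactly what the expanded IR computes.
  #include <emmintrin.h>
  #include <cassert>
  #include <cstdint>
  #include <cstdio>

  int main() {
    alignas(16) int8_t lanes[16] = {-1, 2,  -3, 4,   -5, 6,   -7, 8,
                                    9, -10, 11, -12, 13, -14, 15, -16};
    __m128i v = _mm_load_si128(reinterpret_cast<const __m128i *>(lanes));

    // Scalar model of the expansion: pack each lane's sign bit into bit I.
    uint32_t expected = 0;
    for (unsigned I = 0; I != 16; ++I)
      if (lanes[I] < 0)
        expected |= 1u << I;

    assert(static_cast<uint32_t>(_mm_movemask_epi8(v)) == expected);
    std::printf("mask = 0x%04x\n", expected);
    return 0;
  }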