From f89f62a68094355bd37c74456aeef6ecab3898fe Mon Sep 17 00:00:00 2001
From: Craig Topper
Date: Fri, 6 Jul 2018 22:46:52 +0000
Subject: [PATCH] [X86] When creating a select for scalar masked sqrt and div
 builtins make sure we optimize the all ones mask case.

This case occurs in the intrinsic headers, so we should avoid emitting the
mask in those cases.

Factor the code into a helper function to make this easy.

llvm-svn: 336472
---
 clang/lib/CodeGen/CGBuiltin.cpp       |  31 +++---
 clang/test/CodeGen/avx512f-builtins.c | 184 +++++++++++++++++-----------------
 2 files changed, 109 insertions(+), 106 deletions(-)

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f2efca8c..3ebf584 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -8517,6 +8517,21 @@ static Value *EmitX86Select(CodeGenFunction &CGF,
   return CGF.Builder.CreateSelect(Mask, Op0, Op1);
 }
 
+static Value *EmitX86ScalarSelect(CodeGenFunction &CGF,
+                                  Value *Mask, Value *Op0, Value *Op1) {
+  // If the mask is all ones just return first argument.
+  if (const auto *C = dyn_cast<Constant>(Mask))
+    if (C->isAllOnesValue())
+      return Op0;
+
+  llvm::VectorType *MaskTy =
+    llvm::VectorType::get(CGF.Builder.getInt1Ty(),
+                          Mask->getType()->getIntegerBitWidth());
+  Mask = CGF.Builder.CreateBitCast(Mask, MaskTy);
+  Mask = CGF.Builder.CreateExtractElement(Mask, (uint64_t)0);
+  return CGF.Builder.CreateSelect(Mask, Op0, Op1);
+}
+
 static Value *EmitX86MaskedCompareResult(CodeGenFunction &CGF, Value *Cmp,
                                          unsigned NumElts, Value *MaskIn) {
   if (MaskIn) {
@@ -9884,12 +9899,9 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
     }
     Value *A = Builder.CreateExtractElement(Ops[1], (uint64_t)0);
     Function *F = CGM.getIntrinsic(Intrinsic::sqrt, A->getType());
+    A = Builder.CreateCall(F, A);
     Value *Src = Builder.CreateExtractElement(Ops[2], (uint64_t)0);
-    int MaskSize = Ops[3]->getType()->getScalarSizeInBits();
-    llvm::Type *MaskTy = llvm::VectorType::get(Builder.getInt1Ty(), MaskSize);
-    Value *Mask = Builder.CreateBitCast(Ops[3], MaskTy);
-    Mask = Builder.CreateExtractElement(Mask, (uint64_t)0);
-    A = Builder.CreateSelect(Mask, Builder.CreateCall(F, {A}), Src);
+    A = EmitX86ScalarSelect(*this, Ops[3], A, Src);
     return Builder.CreateInsertElement(Ops[0], A, (uint64_t)0);
   }
   case X86::BI__builtin_ia32_sqrtpd256:
@@ -10024,14 +10036,9 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
     Value *A = Builder.CreateExtractElement(Ops[0], (uint64_t)0);
     Value *B = Builder.CreateExtractElement(Ops[1], (uint64_t)0);
     Value *C = Builder.CreateExtractElement(Ops[2], (uint64_t)0);
-    Value *Mask = Ops[3];
     Value *Div = Builder.CreateFDiv(A, B);
-    llvm::VectorType *MaskTy = llvm::VectorType::get(Builder.getInt1Ty(),
-                     cast<IntegerType>(Mask->getType())->getBitWidth());
-    Mask = Builder.CreateBitCast(Mask, MaskTy);
-    Mask = Builder.CreateExtractElement(Mask, (uint64_t)0);
-    Value *Select = Builder.CreateSelect(Mask, Div, C);
-    return Builder.CreateInsertElement(Ops[0], Select, (uint64_t)0);
+    Div = EmitX86ScalarSelect(*this, Ops[3], Div, C);
+    return Builder.CreateInsertElement(Ops[0], Div, (uint64_t)0);
   }
 
   // 3DNow!
diff --git a/clang/test/CodeGen/avx512f-builtins.c b/clang/test/CodeGen/avx512f-builtins.c
index 2cfaf0c..c72167e 100644
--- a/clang/test/CodeGen/avx512f-builtins.c
+++ b/clang/test/CodeGen/avx512f-builtins.c
@@ -3581,27 +3581,26 @@ __m128 test_mm_maskz_div_round_ss(__mmask8 __U, __m128 __A, __m128 __B) {
 }
 __m128 test_mm_mask_div_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B) {
   // CHECK-LABEL: @test_mm_mask_div_ss
-  // CHECK-NOT: @llvm.x86.avx512.mask.div.ss.round
   // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: fdiv float %{{.*}}, %{{.*}}
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: select i1 %{{.*}}, float %{{.*}}, float %{{.*}}
-  // CHECK: insertelement <4 x float> %{{.*}}, float %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: fdiv float %{{.*}}, %{{.*}}
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 %{{.*}}, float %{{.*}}, float %{{.*}}
+  // CHECK-NEXT: insertelement <4 x float> %{{.*}}, float %{{.*}}, i64 0
   return _mm_mask_div_ss(__W,__U,__A,__B);
 }
 __m128 test_mm_maskz_div_ss(__mmask8 __U, __m128 __A, __m128 __B) {
   // CHECK-LABEL: @test_mm_maskz_div_ss
-  // CHECK-NOT: @llvm.x86.avx512.mask.div.ss.round
   // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: fdiv float %{{.*}}, %{{.*}}
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: select i1 %{{.*}}, float %{{.*}}, float %{{.*}}
-  // CHECK: insertelement <4 x float> %{{.*}}, float %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: fdiv float %{{.*}}, %{{.*}}
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 %{{.*}}, float %{{.*}}, float %{{.*}}
+  // CHECK-NEXT: insertelement <4 x float> %{{.*}}, float %{{.*}}, i64 0
   return _mm_maskz_div_ss(__U,__A,__B);
 }
 __m128d test_mm_div_round_sd(__m128d __A, __m128d __B) {
@@ -3621,27 +3620,26 @@ __m128d test_mm_maskz_div_round_sd(__mmask8 __U, __m128d __A, __m128d __B) {
 }
 __m128d test_mm_mask_div_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B) {
   // CHECK-LABEL: @test_mm_mask_div_sd
-  // CHECK-NOT: @llvm.x86.avx512.mask.div.sd.round
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
   // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: fdiv double %{{.*}}, %{{.*}}
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: select i1 %{{.*}}, double %{{.*}}, double %{{.*}}
-  // CHECK: insertelement <2 x double> %{{.*}}, double %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: fdiv double %{{.*}}, %{{.*}}
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 %{{.*}}, double %{{.*}}, double %{{.*}}
+  // CHECK-NEXT: insertelement <2 x double> %{{.*}}, double %{{.*}}, i64 0
   return _mm_mask_div_sd(__W,__U,__A,__B);
 }
 __m128d test_mm_maskz_div_sd(__mmask8 __U, __m128d __A, __m128d __B) {
   // CHECK-LABEL: @test_mm_maskz_div_sd
-  // CHECK-NOT: @llvm.x86.avx512.mask.div.sd.round
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
   // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: fdiv double %{{.*}}, %{{.*}}
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: select i1 %{{.*}}, double %{{.*}}, double %{{.*}}
-  // CHECK: insertelement <2 x double> %{{.*}}, double %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: fdiv double %{{.*}}, %{{.*}}
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 %{{.*}}, double %{{.*}}, double %{{.*}}
+  // CHECK-NEXT: insertelement <2 x double> %{{.*}}, double %{{.*}}, i64 0
   return _mm_maskz_div_sd(__U,__A,__B);
 }
 __m128 test_mm_max_round_ss(__m128 __A, __m128 __B) {
@@ -5948,117 +5946,115 @@ __m512 test_mm512_maskz_shuffle_ps(__mmask16 __U, __m512 __M, __m512 __V) {
 __m128d test_mm_sqrt_round_sd(__m128d __A, __m128d __B) {
   // CHECK-LABEL: @test_mm_sqrt_round_sd
   // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: call double @llvm.sqrt.f64(double %{{.*}})
-  // CHECK: select i1 {{.*}}, double {{.*}}, double {{.*}}
-  // CHECK: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
+  // CHECK-NEXT: call double @llvm.sqrt.f64(double %{{.*}})
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
   return _mm_sqrt_round_sd(__A, __B, 4);
 }
 __m128d test_mm_mask_sqrt_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B){
   // CHECK-LABEL: @test_mm_mask_sqrt_sd
   // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call double @llvm.sqrt.f64(double %{{.*}})
-  // CHECK: select i1 {{.*}}, double {{.*}}, double {{.*}}
-  // CHECK: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
-  return _mm_mask_sqrt_sd(__W,__U,__A,__B);
+  // CHECK-NEXT: call double @llvm.sqrt.f64(double %{{.*}})
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, double {{.*}}, double {{.*}}
+  // CHECK-NEXT: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
+  return _mm_mask_sqrt_sd(__W,__U,__A,__B);
 }
 __m128d test_mm_mask_sqrt_round_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B){
   // CHECK-LABEL: @test_mm_mask_sqrt_round_sd
   // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call double @llvm.sqrt.f64(double %{{.*}})
-  // CHECK: select i1 {{.*}}, double {{.*}}, double {{.*}}
-  // CHECK: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
-  return _mm_mask_sqrt_round_sd(__W,__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
+  // CHECK-NEXT: call double @llvm.sqrt.f64(double %{{.*}})
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, double {{.*}}, double {{.*}}
+  // CHECK-NEXT: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
+  return _mm_mask_sqrt_round_sd(__W,__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
 }
 __m128d test_mm_maskz_sqrt_sd(__mmask8 __U, __m128d __A, __m128d __B){
   // CHECK-LABEL: @test_mm_maskz_sqrt_sd
   // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call double @llvm.sqrt.f64(double %{{.*}})
-  // CHECK: select i1 {{.*}}, double {{.*}}, double {{.*}}
-  // CHECK: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
-  return _mm_maskz_sqrt_sd(__U,__A,__B);
+  // CHECK-NEXT: call double @llvm.sqrt.f64(double %{{.*}})
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, double {{.*}}, double {{.*}}
+  // CHECK-NEXT: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
+  return _mm_maskz_sqrt_sd(__U,__A,__B);
 }
 __m128d test_mm_maskz_sqrt_round_sd(__mmask8 __U, __m128d __A, __m128d __B){
   // CHECK-LABEL: @test_mm_maskz_sqrt_round_sd
   // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: extractelement <2 x double> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call double @llvm.sqrt.f64(double %{{.*}})
-  // CHECK: select i1 {{.*}}, double {{.*}}, double {{.*}}
-  // CHECK: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
-  return _mm_maskz_sqrt_round_sd(__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
+  // CHECK-NEXT: call double @llvm.sqrt.f64(double %{{.*}})
+  // CHECK-NEXT: extractelement <2 x double> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, double {{.*}}, double {{.*}}
+  // CHECK-NEXT: insertelement <2 x double> %{{.*}}, double {{.*}}, i64 0
+  return _mm_maskz_sqrt_round_sd(__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
 }
 __m128 test_mm_sqrt_round_ss(__m128 __A, __m128 __B) {
   // CHECK-LABEL: @test_mm_sqrt_round_ss
   // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  // CHECK: select i1 {{.*}}, float {{.*}}, float {{.*}}
-  // CHECK: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
+  // CHECK-NEXT: call float @llvm.sqrt.f32(float %{{.*}})
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
   return _mm_sqrt_round_ss(__A, __B, 4);
 }
 __m128 test_mm_mask_sqrt_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B){
   // CHECK-LABEL: @test_mm_mask_sqrt_ss
   // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  // CHECK: select i1 {{.*}}, float {{.*}}, float {{.*}}
-  // CHECK: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
-  return _mm_mask_sqrt_ss(__W,__U,__A,__B);
+  // CHECK-NEXT: call float @llvm.sqrt.f32(float %{{.*}})
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, float {{.*}}, float {{.*}}
+  // CHECK-NEXT: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
+  return _mm_mask_sqrt_ss(__W,__U,__A,__B);
 }
 __m128 test_mm_mask_sqrt_round_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B){
   // CHECK-LABEL: @test_mm_mask_sqrt_round_ss
   // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  // CHECK: select i1 {{.*}}, float {{.*}}, float {{.*}}
-  // CHECK: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
-  return _mm_mask_sqrt_round_ss(__W,__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
+  // CHECK-NEXT: call float @llvm.sqrt.f32(float %{{.*}})
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, float {{.*}}, float {{.*}}
+  // CHECK-NEXT: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
+  return _mm_mask_sqrt_round_ss(__W,__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
 }
 __m128 test_mm_maskz_sqrt_ss(__mmask8 __U, __m128 __A, __m128 __B){
   // CHECK-LABEL: @test_mm_maskz_sqrt_ss
   // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  // CHECK: select i1 {{.*}}, float {{.*}}, float {{.*}}
-  // CHECK: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
-  return _mm_maskz_sqrt_ss(__U,__A,__B);
+  // CHECK-NEXT: call float @llvm.sqrt.f32(float %{{.*}})
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, float {{.*}}, float {{.*}}
+  // CHECK-NEXT: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
+  return _mm_maskz_sqrt_ss(__U,__A,__B);
 }
 __m128 test_mm_maskz_sqrt_round_ss(__mmask8 __U, __m128 __A, __m128 __B){
   // CHECK-LABEL: @test_mm_maskz_sqrt_round_ss
   // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: extractelement <4 x float> %{{.*}}, i64 0
-  // CHECK: bitcast i8 %{{.*}} to <8 x i1>
-  // CHECK: extractelement <8 x i1> %{{.*}}, i64 0
-  // CHECK: call float @llvm.sqrt.f32(float %{{.*}})
-  // CHECK: select i1 {{.*}}, float {{.*}}, float {{.*}}
-  // CHECK: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
-  return _mm_maskz_sqrt_round_ss(__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
+  // CHECK-NEXT: call float @llvm.sqrt.f32(float %{{.*}})
+  // CHECK-NEXT: extractelement <4 x float> %{{.*}}, i64 0
+  // CHECK-NEXT: bitcast i8 %{{.*}} to <8 x i1>
+  // CHECK-NEXT: extractelement <8 x i1> %{{.*}}, i64 0
+  // CHECK-NEXT: select i1 {{.*}}, float {{.*}}, float {{.*}}
+  // CHECK-NEXT: insertelement <4 x float> %{{.*}}, float {{.*}}, i64 0
+  return _mm_maskz_sqrt_round_ss(__U,__A,__B,_MM_FROUND_CUR_DIRECTION);
 }
 __m512 test_mm512_broadcast_f32x4(float const* __A) {
-- 
2.7.4
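
Note (not part of the patch): the "intrinsic headers" case the commit message
refers to comes from the unmasked scalar intrinsics, which avx512fintrin.h
defines in terms of the masked builtins with a constant all-ones mask. A
minimal C sketch of where that constant arises; the macro body below is
paraphrased from the header, not quoted, and sqrt_lane0 is an illustrative
name:

#include <immintrin.h>

/* The unmasked intrinsic expands (roughly) to the masked builtin with an
 * all-ones mask, along the lines of:
 *
 *   #define _mm_sqrt_round_sd(A, B, R)                                    \
 *     (__m128d)__builtin_ia32_sqrtsd_round_mask((__v2df)(__m128d)(A),     \
 *                                               (__v2df)(__m128d)(B),     \
 *                                               (__v2df)_mm_setzero_pd(), \
 *                                               (__mmask8)-1, (int)(R))
 *
 * Because the mask operand is the constant -1, EmitX86ScalarSelect now
 * returns the sqrt result directly, so the emitted IR is just
 * extract/sqrt/insert with no bitcast, <8 x i1> extract, or select, which
 * is what the updated test_mm_sqrt_round_sd CHECK lines verify. */
__m128d sqrt_lane0(__m128d a, __m128d b) {
  return _mm_sqrt_round_sd(a, b, _MM_FROUND_CUR_DIRECTION);
}

(Compile with -mavx512f to make the intrinsic available.)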
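
For the non-constant mask path, a small runtime model, plain C and purely
illustrative, of what the bitcast/extractelement/select sequence in the
CHECK lines computes for a scalar (lane-0) operation. Note the operation
itself is evaluated unconditionally before the select, matching the new
code order in EmitX86BuiltinExpr; mask_sqrt_lane0 is a hypothetical helper,
not compiler code:

#include <math.h>

/* Models the IR emitted for _mm_mask_sqrt_sd when the mask is not a
 * compile-time constant:
 *   %op = call double @llvm.sqrt.f64(double %a0)  ; computed unconditionally
 *   %v  = bitcast i8 %mask to <8 x i1>
 *   %b  = extractelement <8 x i1> %v, i64 0       ; only bit 0 participates
 *   %r  = select i1 %b, double %op, double %w0    ; else passthrough from __W
 */
static double mask_sqrt_lane0(unsigned char mask, double a0, double w0) {
  double op = sqrt(a0);          /* computed regardless of the mask */
  return (mask & 1) ? op : w0;   /* scalar select on mask bit 0 */
}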