From 2b0b9b1148c205dfd73c70d195f51ef9895e2307 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Thu, 14 May 2020 10:06:07 -0700 Subject: [PATCH] [X86] Fix a regression caused by moving combineLoopMAddPattern to IR When I moved combineLoopMAddPattern to an IR pass. I didn't match the behavior of canReduceVMulWidth that was used in the SelectionDAG version. canReduceVMulWidth just calls computeSignBits and assumes a truncate is always profitable. The version I put in IR just looks for constants and zext/sext. Though I neglected to check the number of bits in input of the zext/sext. This patch adds a check for the number of input bits to the sext/zext. And it adds a special case for add/sub with zext/sext inputs which can be handled by combineTruncatedArithmetic. Match the original SelectionDAG behavior appears to be a regression in some cases if the truncate isn't removed and becomes pack and permq. So enabling only this specific case is the conservative approach. Differential Revision: https://reviews.llvm.org/D79909 --- llvm/lib/Target/X86/X86PartialReduction.cpp | 32 ++++++++++++---- llvm/test/CodeGen/X86/madd.ll | 57 +++++++++++------------------ 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/llvm/lib/Target/X86/X86PartialReduction.cpp b/llvm/lib/Target/X86/X86PartialReduction.cpp index 5cc9281..4b3ba20 100644 --- a/llvm/lib/Target/X86/X86PartialReduction.cpp +++ b/llvm/lib/Target/X86/X86PartialReduction.cpp @@ -216,13 +216,31 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) { } } - auto canShrinkOp = [&](Value *Op) { - if (isa(Op) && ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) + auto CanShrinkOp = [&](Value *Op) { + auto IsFreeTruncation = [&](Value *Op) { + if (auto *Cast = dyn_cast(Op)) { + if (Cast->getParent() == BB && + (Cast->getOpcode() == Instruction::SExt || + Cast->getOpcode() == Instruction::ZExt) && + Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16) + return true; + } + + return isa(Op); + }; + + // If the operation can be freely truncated and has enough sign bits we + // can shrink. + if (IsFreeTruncation(Op) && + ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) return true; - if (auto *Cast = dyn_cast(Op)) { - if (Cast->getParent() == BB && - (Cast->getOpcode() == Instruction::SExt || - Cast->getOpcode() == Instruction::ZExt) && + + // SelectionDAG has limited support for truncating through an add or sub if + // the inputs are freely truncatable. + if (auto *BO = dyn_cast(Op)) { + if (BO->getParent() == BB && + IsFreeTruncation(BO->getOperand(0)) && + IsFreeTruncation(BO->getOperand(1)) && ComputeNumSignBits(Op, *DL, 0, nullptr, BO) > 16) return true; } @@ -231,7 +249,7 @@ bool X86PartialReduction::tryMAddReplacement(Value *Op, BinaryOperator *Add) { }; // Both Ops need to be shrinkable. - if (!canShrinkOp(LHS) && !canShrinkOp(RHS)) + if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS)) return false; IRBuilder<> Builder(Add); diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll index 3f221d4..d6d04d9 100644 --- a/llvm/test/CodeGen/X86/madd.ll +++ b/llvm/test/CodeGen/X86/madd.ll @@ -2935,36 +2935,29 @@ define i32 @sum_of_square_differences(i8* %a, i8* %b, i32 %n) { ; SSE2-LABEL: sum_of_square_differences: ; SSE2: # %bb.0: # %entry ; SSE2-NEXT: movl %edx, %eax -; SSE2-NEXT: pxor %xmm1, %xmm1 -; SSE2-NEXT: xorl %ecx, %ecx ; SSE2-NEXT: pxor %xmm0, %xmm0 -; SSE2-NEXT: pxor %xmm2, %xmm2 +; SSE2-NEXT: xorl %ecx, %ecx +; SSE2-NEXT: pxor %xmm1, %xmm1 ; SSE2-NEXT: .p2align 4, 0x90 ; SSE2-NEXT: .LBB34_1: # %vector.body ; SSE2-NEXT: # =>This Inner Loop Header: Depth=1 +; SSE2-NEXT: movq {{.*#+}} xmm2 = mem[0],zero ; SSE2-NEXT: movq {{.*#+}} xmm3 = mem[0],zero -; SSE2-NEXT: movq {{.*#+}} xmm4 = mem[0],zero -; SSE2-NEXT: punpcklbw {{.*#+}} xmm3 = xmm3[0],xmm1[0],xmm3[1],xmm1[1],xmm3[2],xmm1[2],xmm3[3],xmm1[3],xmm3[4],xmm1[4],xmm3[5],xmm1[5],xmm3[6],xmm1[6],xmm3[7],xmm1[7] -; SSE2-NEXT: punpcklbw {{.*#+}} xmm4 = xmm4[0],xmm1[0],xmm4[1],xmm1[1],xmm4[2],xmm1[2],xmm4[3],xmm1[3],xmm4[4],xmm1[4],xmm4[5],xmm1[5],xmm4[6],xmm1[6],xmm4[7],xmm1[7] -; SSE2-NEXT: psubw %xmm3, %xmm4 -; SSE2-NEXT: movdqa %xmm4, %xmm3 -; SSE2-NEXT: pmulhw %xmm4, %xmm3 -; SSE2-NEXT: pmullw %xmm4, %xmm4 -; SSE2-NEXT: movdqa %xmm4, %xmm5 -; SSE2-NEXT: punpcklwd {{.*#+}} xmm5 = xmm5[0],xmm3[0],xmm5[1],xmm3[1],xmm5[2],xmm3[2],xmm5[3],xmm3[3] -; SSE2-NEXT: paddd %xmm5, %xmm0 -; SSE2-NEXT: punpckhwd {{.*#+}} xmm4 = xmm4[4],xmm3[4],xmm4[5],xmm3[5],xmm4[6],xmm3[6],xmm4[7],xmm3[7] -; SSE2-NEXT: paddd %xmm4, %xmm2 +; SSE2-NEXT: punpcklbw {{.*#+}} xmm2 = xmm2[0],xmm0[0],xmm2[1],xmm0[1],xmm2[2],xmm0[2],xmm2[3],xmm0[3],xmm2[4],xmm0[4],xmm2[5],xmm0[5],xmm2[6],xmm0[6],xmm2[7],xmm0[7] +; SSE2-NEXT: punpcklbw {{.*#+}} xmm3 = xmm3[0],xmm0[0],xmm3[1],xmm0[1],xmm3[2],xmm0[2],xmm3[3],xmm0[3],xmm3[4],xmm0[4],xmm3[5],xmm0[5],xmm3[6],xmm0[6],xmm3[7],xmm0[7] +; SSE2-NEXT: psubw %xmm2, %xmm3 +; SSE2-NEXT: pmaddwd %xmm3, %xmm3 +; SSE2-NEXT: paddd %xmm3, %xmm1 ; SSE2-NEXT: addq $8, %rcx ; SSE2-NEXT: cmpq %rcx, %rax ; SSE2-NEXT: jne .LBB34_1 ; SSE2-NEXT: # %bb.2: # %middle.block -; SSE2-NEXT: paddd %xmm2, %xmm0 -; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm0[2,3,0,1] ; SSE2-NEXT: paddd %xmm0, %xmm1 -; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[1,1,2,3] +; SSE2-NEXT: pshufd {{.*#+}} xmm0 = xmm1[2,3,0,1] ; SSE2-NEXT: paddd %xmm1, %xmm0 -; SSE2-NEXT: movd %xmm0, %eax +; SSE2-NEXT: pshufd {{.*#+}} xmm1 = xmm0[1,1,2,3] +; SSE2-NEXT: paddd %xmm0, %xmm1 +; SSE2-NEXT: movd %xmm1, %eax ; SSE2-NEXT: retq ; ; AVX1-LABEL: sum_of_square_differences: @@ -2975,18 +2968,12 @@ define i32 @sum_of_square_differences(i8* %a, i8* %b, i32 %n) { ; AVX1-NEXT: .p2align 4, 0x90 ; AVX1-NEXT: .LBB34_1: # %vector.body ; AVX1-NEXT: # =>This Inner Loop Header: Depth=1 -; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero -; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero -; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm3 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero -; AVX1-NEXT: vpsubd %xmm1, %xmm3, %xmm1 -; AVX1-NEXT: vpmovzxbd {{.*#+}} xmm3 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero -; AVX1-NEXT: vpsubd %xmm2, %xmm3, %xmm2 -; AVX1-NEXT: vpmulld %xmm1, %xmm1, %xmm1 -; AVX1-NEXT: vpmulld %xmm2, %xmm2, %xmm2 -; AVX1-NEXT: vextractf128 $1, %ymm0, %xmm3 -; AVX1-NEXT: vpaddd %xmm3, %xmm2, %xmm2 -; AVX1-NEXT: vpaddd %xmm0, %xmm1, %xmm0 -; AVX1-NEXT: vinsertf128 $1, %xmm2, %ymm0, %ymm0 +; AVX1-NEXT: vpmovzxbw {{.*#+}} xmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero +; AVX1-NEXT: vpmovzxbw {{.*#+}} xmm2 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero +; AVX1-NEXT: vpsubw %xmm1, %xmm2, %xmm1 +; AVX1-NEXT: vpmaddwd %xmm1, %xmm1, %xmm1 +; AVX1-NEXT: vpaddd %xmm0, %xmm1, %xmm1 +; AVX1-NEXT: vblendps {{.*#+}} ymm0 = ymm1[0,1,2,3],ymm0[4,5,6,7] ; AVX1-NEXT: addq $8, %rcx ; AVX1-NEXT: cmpq %rcx, %rax ; AVX1-NEXT: jne .LBB34_1 @@ -3009,10 +2996,10 @@ define i32 @sum_of_square_differences(i8* %a, i8* %b, i32 %n) { ; AVX256-NEXT: .p2align 4, 0x90 ; AVX256-NEXT: .LBB34_1: # %vector.body ; AVX256-NEXT: # =>This Inner Loop Header: Depth=1 -; AVX256-NEXT: vpmovzxbd {{.*#+}} ymm1 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero -; AVX256-NEXT: vpmovzxbd {{.*#+}} ymm2 = mem[0],zero,zero,zero,mem[1],zero,zero,zero,mem[2],zero,zero,zero,mem[3],zero,zero,zero,mem[4],zero,zero,zero,mem[5],zero,zero,zero,mem[6],zero,zero,zero,mem[7],zero,zero,zero -; AVX256-NEXT: vpsubd %ymm1, %ymm2, %ymm1 -; AVX256-NEXT: vpmulld %ymm1, %ymm1, %ymm1 +; AVX256-NEXT: vpmovzxbw {{.*#+}} xmm1 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero +; AVX256-NEXT: vpmovzxbw {{.*#+}} xmm2 = mem[0],zero,mem[1],zero,mem[2],zero,mem[3],zero,mem[4],zero,mem[5],zero,mem[6],zero,mem[7],zero +; AVX256-NEXT: vpsubw %xmm1, %xmm2, %xmm1 +; AVX256-NEXT: vpmaddwd %xmm1, %xmm1, %xmm1 ; AVX256-NEXT: vpaddd %ymm0, %ymm1, %ymm0 ; AVX256-NEXT: addq $8, %rcx ; AVX256-NEXT: cmpq %rcx, %rax -- 2.7.4