From 58fa724494b6d3d85fa1e5f6b6c309f3f3ee69d4 Mon Sep 17 00:00:00 2001
From: Dehao Chen
Date: Fri, 7 Apr 2017 15:41:52 +0000
Subject: [PATCH] Use PMADDWD to expand reduction in a loop

Summary:
PMADDWD can help improve the performance of 8/16-bit integer multiply-add
operations for cases like:

  for (int i = 0; i < count; i++)
    a += x[i] * y[i];

Reviewers: wmi, davidxl, hfinkel, RKSimon, zvi, mkuper

Reviewed By: mkuper

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D31679

llvm-svn: 299776
---
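Notes: the combine added below targets exactly the kind of reduction loop
shown in the summary. For reference, the same multiply-add reduction can be
written by hand with SSE2 intrinsics; this is a minimal illustrative sketch
(the helper name madd_reduce and the scalar-tail handling are assumptions of
this note, not part of the patch):

  #include <emmintrin.h> /* SSE2 intrinsics */
  #include <stdint.h>

  /* PMADDWD (_mm_madd_epi16) multiplies eight pairs of signed 16-bit
   * elements and adds adjacent 32-bit products, yielding four i32 sums. */
  static int madd_reduce(const int16_t *x, const int16_t *y, int count) {
    __m128i acc = _mm_setzero_si128();
    int i = 0;
    for (; i + 8 <= count; i += 8) {
      __m128i vx = _mm_loadu_si128((const __m128i *)(x + i));
      __m128i vy = _mm_loadu_si128((const __m128i *)(y + i));
      /* pmaddwd + paddd: the instruction pair this combine is after. */
      acc = _mm_add_epi32(acc, _mm_madd_epi16(vx, vy));
    }
    /* Horizontal reduction of the four 32-bit lanes, mirroring the
     * shufflevector/add sequence in middle.block of the tests below. */
    acc = _mm_add_epi32(acc, _mm_shuffle_epi32(acc, _MM_SHUFFLE(1, 0, 3, 2)));
    acc = _mm_add_epi32(acc, _mm_shuffle_epi32(acc, _MM_SHUFFLE(2, 3, 0, 1)));
    int sum = _mm_cvtsi128_si32(acc);
    for (; i < count; i++) /* scalar tail */
      sum += x[i] * y[i];
    return sum;
  }

This is the shape combineLoopMAddPattern produces via X86ISD::VPMADDWD; the
transform only fires when the mul's operands can be shrunk to 16 bits, which
is what the canReduceVMulWidth check guarantees.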
 llvm/lib/Target/X86/X86ISelLowering.cpp |  47 +++++++++++++++
 llvm/test/CodeGen/X86/madd.ll           | 103 ++++++++++++++++++++++++++++++++
 2 files changed, 150 insertions(+)
 create mode 100644 llvm/test/CodeGen/X86/madd.ll

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c7b01fa..fb31ce6 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -34618,6 +34618,51 @@ static SDValue combineAddOrSubToADCOrSBB(SDNode *N, SelectionDAG &DAG) {
                      DAG.getConstant(0, DL, VT), NewCmp);
 }
 
+static SDValue combineLoopMAddPattern(SDNode *N, SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
+  SDValue MulOp = N->getOperand(0);
+  SDValue Phi = N->getOperand(1);
+
+  if (MulOp.getOpcode() != ISD::MUL)
+    std::swap(MulOp, Phi);
+  if (MulOp.getOpcode() != ISD::MUL)
+    return SDValue();
+
+  ShrinkMode Mode;
+  if (!canReduceVMulWidth(MulOp.getNode(), DAG, Mode))
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  unsigned RegSize = 128;
+  if (Subtarget.hasBWI())
+    RegSize = 512;
+  else if (Subtarget.hasAVX2())
+    RegSize = 256;
+  unsigned VectorSize = VT.getVectorNumElements() * 16;
+  // If the vector size is less than 128, or greater than the supported
+  // RegSize, do not use PMADD.
+  if (VectorSize < 128 || VectorSize > RegSize)
+    return SDValue();
+
+  SDLoc DL(N);
+  EVT ReducedVT = EVT::getVectorVT(*DAG.getContext(), MVT::i16,
+                                   VT.getVectorNumElements());
+  EVT MAddVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+                                VT.getVectorNumElements() / 2);
+
+  // Shrink the operands of the mul.
+  SDValue N0 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(0));
+  SDValue N1 = DAG.getNode(ISD::TRUNCATE, DL, ReducedVT, MulOp->getOperand(1));
+
+  // The Madd vector is half the width of the original vector.
+  SDValue Madd = DAG.getNode(X86ISD::VPMADDWD, DL, MAddVT, N0, N1);
+  // Fill the rest of the output with 0.
+  SDValue Zero = getZeroVector(Madd.getSimpleValueType(), Subtarget, DAG, DL);
+  SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Madd, Zero);
+  return DAG.getNode(ISD::ADD, DL, VT, Concat, Phi);
+}
+
 static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
   SDLoc DL(N);
@@ -34695,6 +34740,8 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
   if (Flags->hasVectorReduction()) {
     if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget))
       return Sad;
+    if (SDValue MAdd = combineLoopMAddPattern(N, DAG, Subtarget))
+      return MAdd;
   }
   EVT VT = N->getValueType(0);
   SDValue Op0 = N->getOperand(0);
diff --git a/llvm/test/CodeGen/X86/madd.ll b/llvm/test/CodeGen/X86/madd.ll
new file mode 100644
index 0000000..fdc5ace
--- /dev/null
+++ b/llvm/test/CodeGen/X86/madd.ll
@@ -0,0 +1,103 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+sse2 | FileCheck %s --check-prefix=SSE2
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx2 | FileCheck %s --check-prefix=AVX2
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512f | FileCheck %s --check-prefix=AVX512
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+avx512bw | FileCheck %s --check-prefix=AVX512
+
+;SSE2-LABEL: @_Z10test_shortPsS_i
+;SSE2: movdqu
+;SSE2-NEXT: movdqu
+;SSE2-NEXT: pmaddwd
+;SSE2-NEXT: paddd
+
+;AVX2-LABEL: @_Z10test_shortPsS_i
+;AVX2: vmovdqu
+;AVX2-NEXT: vpmaddwd
+;AVX2-NEXT: vinserti128
+;AVX2-NEXT: vpaddd
+
+;AVX512-LABEL: @_Z10test_shortPsS_i
+;AVX512: vmovdqu
+;AVX512-NEXT: vpmaddwd
+;AVX512-NEXT: vinserti128
+;AVX512-NEXT: vpaddd
+
+define i32 @_Z10test_shortPsS_i(i16* nocapture readonly, i16* nocapture readonly, i32) local_unnamed_addr {
+entry:
+  %3 = zext i32 %2 to i64
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ %index.next, %vector.body ], [ 0, %entry ]
+  %vec.phi = phi <8 x i32> [ %11, %vector.body ], [ zeroinitializer, %entry ]
+  %4 = getelementptr inbounds i16, i16* %0, i64 %index
+  %5 = bitcast i16* %4 to <8 x i16>*
+  %wide.load = load <8 x i16>, <8 x i16>* %5, align 2
+  %6 = sext <8 x i16> %wide.load to <8 x i32>
+  %7 = getelementptr inbounds i16, i16* %1, i64 %index
+  %8 = bitcast i16* %7 to <8 x i16>*
+  %wide.load14 = load <8 x i16>, <8 x i16>* %8, align 2
+  %9 = sext <8 x i16> %wide.load14 to <8 x i32>
+  %10 = mul nsw <8 x i32> %9, %6
+  %11 = add nsw <8 x i32> %10, %vec.phi
+  %index.next = add i64 %index, 8
+  %12 = icmp eq i64 %index.next, %3
+  br i1 %12, label %middle.block, label %vector.body
+
+middle.block:
+  %rdx.shuf = shufflevector <8 x i32> %11, <8 x i32> undef, <8 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx = add <8 x i32> %11, %rdx.shuf
+  %rdx.shuf15 = shufflevector <8 x i32> %bin.rdx, <8 x i32> undef, <8 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx16 = add <8 x i32> %bin.rdx, %rdx.shuf15
+  %rdx.shuf17 = shufflevector <8 x i32> %bin.rdx16, <8 x i32> undef, <8 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx18 = add <8 x i32> %bin.rdx16, %rdx.shuf17
+  %13 = extractelement <8 x i32> %bin.rdx18, i32 0
+  ret i32 %13
+}
+
+;AVX2-LABEL: @_Z9test_charPcS_i
+;AVX2: vpmovsxbw
+;AVX2-NEXT: vpmovsxbw
+;AVX2-NEXT: vpmaddwd
+;AVX2-NEXT: vpaddd
+
+;AVX512-LABEL: @_Z9test_charPcS_i
+;AVX512: vpmovsxbw
+;AVX512-NEXT: vpmovsxbw
+;AVX512-NEXT: vpmaddwd
+;AVX512-NEXT: vinserti64x4
+;AVX512-NEXT: vpaddd
+
+define i32 @_Z9test_charPcS_i(i8* nocapture readonly, i8* nocapture readonly, i32) local_unnamed_addr {
+entry:
+  %3 = zext i32 %2 to i64
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ %index.next, %vector.body ], [ 0, %entry ]
+  %vec.phi = phi <16 x i32> [ %11, %vector.body ], [ zeroinitializer, %entry ]
+  %4 = getelementptr inbounds i8, i8* %0, i64 %index
+  %5 = bitcast i8* %4 to <16 x i8>*
+  %wide.load = load <16 x i8>, <16 x i8>* %5, align 1
+  %6 = sext <16 x i8> %wide.load to <16 x i32>
+  %7 = getelementptr inbounds i8, i8* %1, i64 %index
+  %8 = bitcast i8* %7 to <16 x i8>*
+  %wide.load14 = load <16 x i8>, <16 x i8>* %8, align 1
+  %9 = sext <16 x i8> %wide.load14 to <16 x i32>
+  %10 = mul nsw <16 x i32> %9, %6
+  %11 = add nsw <16 x i32> %10, %vec.phi
+  %index.next = add i64 %index, 16
+  %12 = icmp eq i64 %index.next, %3
+  br i1 %12, label %middle.block, label %vector.body
+
+middle.block:
+  %rdx.shuf = shufflevector <16 x i32> %11, <16 x i32> undef, <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx = add <16 x i32> %11, %rdx.shuf
+  %rdx.shuf15 = shufflevector <16 x i32> %bin.rdx, <16 x i32> undef, <16 x i32> <i32 4, i32 5, i32 6, i32 7, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx16 = add <16 x i32> %bin.rdx, %rdx.shuf15
+  %rdx.shuf17 = shufflevector <16 x i32> %bin.rdx16, <16 x i32> undef, <16 x i32> <i32 2, i32 3, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx18 = add <16 x i32> %bin.rdx16, %rdx.shuf17
+  %rdx.shuf19 = shufflevector <16 x i32> %bin.rdx18, <16 x i32> undef, <16 x i32> <i32 1, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef, i32 undef>
+  %bin.rdx20 = add <16 x i32> %bin.rdx18, %rdx.shuf19
+  %13 = extractelement <16 x i32> %bin.rdx20, i32 0
+  ret i32 %13
+}
--
2.7.4