From 0baace53799424bd1e3cfe9068ee254fae0ca677 Mon Sep 17 00:00:00 2001 From: Abinav Puthan Purayil Date: Mon, 2 Aug 2021 16:42:23 +0530 Subject: [PATCH] [DAGCombine] Add node level checks for fp-contract and fp-ninf in visitFMULForFMADistributiveCombine(). Differential Revision: https://reviews.llvm.org/D107551 --- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 19 +++++++++++++++++-- llvm/test/CodeGen/AMDGPU/fma.ll | 10 ++++++++++ llvm/test/CodeGen/X86/fma-scalar-combine.ll | 13 +++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index aae9eea..9c8febc 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -13015,6 +13015,20 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { return DAG.getBuildVector(VT, DL, Ops); } +// Returns true if floating point contraction is allowed on the FMUL-SDValue +// `N` +static bool isContractableFMUL(const TargetOptions &Options, SDValue N) { + assert(N.getOpcode() == ISD::FMUL); + + return Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath || + N->getFlags().hasAllowContract(); +} + +// Return true if `N` can assume no infinities involved in it's computation. +static bool hasNoInfs(const TargetOptions &Options, SDValue N) { + return Options.NoInfsFPMath || N.getNode()->getFlags().hasNoInfs(); +} + /// Try to perform FMA combining on a given FADD node. SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); @@ -13557,12 +13571,13 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) { // The transforms below are incorrect when x == 0 and y == inf, because the // intermediate multiplication produces a nan. - if (!Options.NoInfsFPMath) + SDValue FAdd = N0.getOpcode() == ISD::FADD ? N0 : N1; + if (!hasNoInfs(Options, FAdd)) return SDValue(); // Floating-point multiply-add without intermediate rounding. bool HasFMA = - (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) && + isContractableFMUL(Options, SDValue(N, 0)) && TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) && (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); diff --git a/llvm/test/CodeGen/AMDGPU/fma.ll b/llvm/test/CodeGen/AMDGPU/fma.ll index 6fadd91..dcf3254 100644 --- a/llvm/test/CodeGen/AMDGPU/fma.ll +++ b/llvm/test/CodeGen/AMDGPU/fma.ll @@ -144,3 +144,13 @@ bb: store float %tmp10, float addrspace(1)* %gep.out ret void } + +; Fold (fmul (fadd x, 1.0), y) -> (fma x, y, y) without FP specific command-line +; options. +; FUNC-LABEL: {{^}}fold_fmul_distributive: +; GFX906: v_fmac_f32_e32 v0, v1, v0 +define float @fold_fmul_distributive(float %x, float %y) { + %fadd = fadd ninf float %y, 1.0 + %fmul = fmul contract float %fadd, %x + ret float %fmul +} diff --git a/llvm/test/CodeGen/X86/fma-scalar-combine.ll b/llvm/test/CodeGen/X86/fma-scalar-combine.ll index 08ae343..7f7e21b 100644 --- a/llvm/test/CodeGen/X86/fma-scalar-combine.ll +++ b/llvm/test/CodeGen/X86/fma-scalar-combine.ll @@ -558,3 +558,16 @@ define float @fma_const_fmul(float %x) { %add1 = fadd contract float %mul1, %mul2 ret float %add1 } + +; Fold (fmul (fadd x, 1.0), y) -> (fma x, y, y) without FP specific command-line +; options. +define float @combine_fmul_distributive(float %x, float %y) { +; CHECK-LABEL: combine_fmul_distributive: +; CHECK: # %bb.0: +; CHECK-NEXT: vfmadd231ss %xmm0, %xmm1, %xmm0 # EVEX TO VEX Compression encoding: [0xc4,0xe2,0x71,0xb9,0xc0] +; CHECK-NEXT: # xmm0 = (xmm1 * xmm0) + xmm0 +; CHECK-NEXT: retq # encoding: [0xc3] + %fadd = fadd ninf float %y, 1.0 + %fmul = fmul contract float %fadd, %x + ret float %fmul +} -- 2.7.4