[DAGCombiner] allow more folding of fadd + fmul into fma
authorSanjay Patel <spatel@rotateright.com>
Tue, 9 Jun 2020 14:04:16 +0000 (10:04 -0400)
committerSanjay Patel <spatel@rotateright.com>
Tue, 9 Jun 2020 14:41:27 +0000 (10:41 -0400)
If fmul and fadd are separated by an fma, we can fold them together
to save an instruction:
fadd (fma A, B, (fmul C, D)), N1 --> fma(A, B, fma(C, D, N1))

The fold implemented here is actually a specialization - we should
be able to peek through >1 fma to find this pattern. That's another
patch if we want to try that enhancement though.

This transform was guarded by the TLI hook enableAggressiveFMAFusion(),
so it was done for some in-tree targets like PowerPC, but not AArch64
or x86. The hook is protecting against forming a potentially more
expensive computation when fma takes longer to execute than a single
fadd. That hook may be needed for other transforms, but in this case,
we are replacing fmul+fadd with fma, and the fma should never take
longer than the 2 individual instructions.

'contract' FMF is all we need to allow this transform. That flag
corresponds to -ffp-contract=fast in Clang, so we are allowed to form
fma ops freely across expressions.

Differential Revision: https://reviews.llvm.org/D80801

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/fadd-combines.ll
llvm/test/CodeGen/X86/fma_patterns.ll

index e3275ca..1352a7d 100644 (file)
@@ -11919,6 +11919,29 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
                        N1.getOperand(0), N1.getOperand(1), N0, Flags);
   }
 
+  // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
+  // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
+  SDValue FMA, E;
+  if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
+      N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
+      N0.getOperand(2).hasOneUse()) {
+    FMA = N0;
+    E = N1;
+  } else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
+             N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
+             N1.getOperand(2).hasOneUse()) {
+    FMA = N1;
+    E = N0;
+  }
+  if (FMA && E) {
+    SDValue A = FMA.getOperand(0);
+    SDValue B = FMA.getOperand(1);
+    SDValue C = FMA.getOperand(2).getOperand(0);
+    SDValue D = FMA.getOperand(2).getOperand(1);
+    SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E, Flags);
+    return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE, Flags);
+  }
+
   // Look through FP_EXTEND nodes to do more combining.
 
   // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z)
@@ -11952,29 +11975,6 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
 
   // More folding opportunities when target permits.
   if (Aggressive) {
-    // fadd (fma A, B, (fmul C, D)), E --> fma A, B, (fma C, D, E)
-    // fadd E, (fma A, B, (fmul C, D)) --> fma A, B, (fma C, D, E)
-    SDValue FMA, E;
-    if (CanFuse && N0.getOpcode() == PreferredFusedOpcode &&
-        N0.getOperand(2).getOpcode() == ISD::FMUL && N0.hasOneUse() &&
-        N0.getOperand(2).hasOneUse()) {
-      FMA = N0;
-      E = N1;
-    } else if (CanFuse && N1.getOpcode() == PreferredFusedOpcode &&
-               N1.getOperand(2).getOpcode() == ISD::FMUL && N1.hasOneUse() &&
-               N1.getOperand(2).hasOneUse()) {
-      FMA = N1;
-      E = N0;
-    }
-    if (FMA && E) {
-      SDValue A = FMA.getOperand(0);
-      SDValue B = FMA.getOperand(1);
-      SDValue C = FMA.getOperand(2).getOperand(0);
-      SDValue D = FMA.getOperand(2).getOperand(1);
-      SDValue CDE = DAG.getNode(PreferredFusedOpcode, SL, VT, C, D, E, Flags);
-      return DAG.getNode(PreferredFusedOpcode, SL, VT, A, B, CDE, Flags);
-    }
-
     // fold (fadd (fma x, y, (fpext (fmul u, v))), z)
     //   -> (fma x, y, (fma (fpext u), (fpext v), z))
     auto FoldFAddFMAFPExtFMul = [&] (
index 3702cc5..61d0ecc 100644 (file)
@@ -197,9 +197,8 @@ define <2 x double> @fmul2_negated_vec(<2 x double> %a, <2 x double> %b, <2 x do
 define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, double %n1) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmul d2, d2, d3
+; CHECK-NEXT:    fmadd d2, d2, d3, d4
 ; CHECK-NEXT:    fmadd d0, d0, d1, d2
-; CHECK-NEXT:    fadd d0, d0, d4
 ; CHECK-NEXT:    ret
   %m1 = fmul fast double %a, %b
   %m2 = fmul fast double %c, %d
@@ -213,9 +212,8 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmul s2, s2, s3
+; CHECK-NEXT:    fmadd s2, s2, s3, s4
 ; CHECK-NEXT:    fmadd s0, s0, s1, s2
-; CHECK-NEXT:    fadd s0, s4, s0
 ; CHECK-NEXT:    ret
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d
@@ -230,10 +228,10 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
 ; CHECK-LABEL: fadd_fma_fmul_3:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmul v2.2d, v2.2d, v3.2d
-; CHECK-NEXT:    fmul v3.2d, v6.2d, v7.2d
 ; CHECK-NEXT:    fmla v2.2d, v1.2d, v0.2d
-; CHECK-NEXT:    fmla v3.2d, v5.2d, v4.2d
-; CHECK-NEXT:    fadd v0.2d, v2.2d, v3.2d
+; CHECK-NEXT:    fmla v2.2d, v7.2d, v6.2d
+; CHECK-NEXT:    fmla v2.2d, v5.2d, v4.2d
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
   %m1 = fmul fast <2 x double> %x1, %x2
   %m2 = fmul fast <2 x double> %x3, %x4
@@ -245,6 +243,8 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
   ret <2 x double> %a3
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_1:
 ; CHECK:       // %bb.0:
@@ -261,6 +261,8 @@ define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_2:
 ; CHECK:       // %bb.0:
@@ -277,6 +279,8 @@ define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_3(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; CHECK-LABEL: fadd_fma_fmul_extra_use_3:
 ; CHECK:       // %bb.0:
index 7b9d511..a11ddf1 100644 (file)
@@ -1799,23 +1799,20 @@ define <4 x double> @test_v4f64_fneg_fmul_no_nsz(<4 x double> %x, <4 x double> %
 define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, double %n1) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_1:
 ; FMA:       # %bb.0:
-; FMA-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vfmadd231sd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vaddsd %xmm4, %xmm2, %xmm0
+; FMA-NEXT:    vfmadd213sd {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; FMA-NEXT:    vfmadd213sd {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_1:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
+; FMA4-NEXT:    vfmaddsd {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
 ; FMA4-NEXT:    vfmaddsd {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vaddsd %xmm4, %xmm0, %xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_1:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmulsd %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd231sd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vaddsd %xmm4, %xmm2, %xmm0
+; AVX512-NEXT:    vfmadd213sd {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; AVX512-NEXT:    vfmadd213sd {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; AVX512-NEXT:    retq
   %m1 = fmul fast double %a, %b
   %m2 = fmul fast double %c, %d
@@ -1829,23 +1826,20 @@ define double @fadd_fma_fmul_1(double %a, double %b, double %c, double %d, doubl
 define float @fadd_fma_fmul_2(float %a, float %b, float %c, float %d, float %n0) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_2:
 ; FMA:       # %bb.0:
-; FMA-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vaddss %xmm2, %xmm4, %xmm0
+; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; FMA-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_2:
 ; FMA4:       # %bb.0:
-; FMA4-NEXT:    vmulss %xmm3, %xmm2, %xmm2
+; FMA4-NEXT:    vfmaddss {{.*#+}} xmm2 = (xmm2 * xmm3) + xmm4
 ; FMA4-NEXT:    vfmaddss {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vaddss %xmm0, %xmm4, %xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_2:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd231ss {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vaddss %xmm2, %xmm4, %xmm0
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + xmm4
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm0 = (xmm1 * xmm0) + xmm2
 ; AVX512-NEXT:    retq
   %m1 = fmul float %a, %b
   %m2 = fmul float %c, %d
@@ -1860,28 +1854,27 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
 ; FMA-LABEL: fadd_fma_fmul_3:
 ; FMA:       # %bb.0:
 ; FMA-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; FMA-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm3 = (xmm5 * xmm4) + xmm3
-; FMA-NEXT:    vaddpd %xmm3, %xmm2, %xmm0
+; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm7 * xmm6) + xmm2
+; FMA-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm5 * xmm4) + xmm2
+; FMA-NEXT:    vmovapd %xmm2, %xmm0
 ; FMA-NEXT:    retq
 ;
 ; FMA4-LABEL: fadd_fma_fmul_3:
 ; FMA4:       # %bb.0:
 ; FMA4-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; FMA4-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm0 * xmm1) + xmm2
-; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm1 = (xmm4 * xmm5) + xmm3
-; FMA4-NEXT:    vaddpd %xmm1, %xmm0, %xmm0
+; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm6 * xmm7) + xmm0
+; FMA4-NEXT:    vfmaddpd {{.*#+}} xmm0 = (xmm4 * xmm5) + xmm0
 ; FMA4-NEXT:    retq
 ;
 ; AVX512-LABEL: fadd_fma_fmul_3:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vmulpd %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vmulpd %xmm7, %xmm6, %xmm3
 ; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm1 * xmm0) + xmm2
-; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm3 = (xmm5 * xmm4) + xmm3
-; AVX512-NEXT:    vaddpd %xmm3, %xmm2, %xmm0
+; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm7 * xmm6) + xmm2
+; AVX512-NEXT:    vfmadd231pd {{.*#+}} xmm2 = (xmm5 * xmm4) + xmm2
+; AVX512-NEXT:    vmovapd %xmm2, %xmm0
 ; AVX512-NEXT:    retq
   %m1 = fmul fast <2 x double> %x1, %x2
   %m2 = fmul fast <2 x double> %x3, %x4
@@ -1893,6 +1886,8 @@ define <2 x double> @fadd_fma_fmul_3(<2 x double> %x1, <2 x double> %x2, <2 x do
   ret <2 x double> %a3
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_1:
 ; FMA:       # %bb.0:
@@ -1925,6 +1920,8 @@ define float @fadd_fma_fmul_extra_use_1(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_2:
 ; FMA:       # %bb.0:
@@ -1957,6 +1954,8 @@ define float @fadd_fma_fmul_extra_use_2(float %a, float %b, float %c, float %d,
   ret float %a2
 }
 
+; negative test
+
 define float @fadd_fma_fmul_extra_use_3(float %a, float %b, float %c, float %d, float %n0, float* %p) nounwind {
 ; FMA-LABEL: fadd_fma_fmul_extra_use_3:
 ; FMA:       # %bb.0: