[DAG] Fix and expand fmin/fmax reassociation fold.
authorDavid Green <david.green@arm.com>
Fri, 23 Jun 2023 13:45:14 +0000 (14:45 +0100)
committerDavid Green <david.green@arm.com>
Fri, 23 Jun 2023 13:45:14 +0000 (14:45 +0100)
This call to reassociateReduction is used by both fminnum/fmaxnum and
fminimum/fmaximum. In adding support for fminimum/fmaximum we appear to be
fixing the use of an incorrect reduction type, which should have only applied
to minnum/maxnum.

I also believe that it doesn't need nsz and reassoc to perform the
reassociation. For float min/max it should always be valid.

Differential Revision: https://reviews.llvm.org/D153247

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/double_reduct.ll
llvm/test/CodeGen/AArch64/sve-doublereduct.ll
llvm/test/CodeGen/RISCV/double_reduct.ll
llvm/test/CodeGen/Thumb2/mve-doublereduct.ll

index 357b0c6..77b05a3 100644 (file)
@@ -17232,13 +17232,12 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
     }
   }
 
-  const TargetOptions &Options = DAG.getTarget().Options;
-  if ((Options.UnsafeFPMath && Options.NoSignedZerosFPMath) ||
-      (Flags.hasAllowReassociation() && Flags.hasNoSignedZeros()))
-    if (SDValue SD = reassociateReduction(IsMin ? ISD::VECREDUCE_FMIN
-                                                : ISD::VECREDUCE_FMAX,
-                                          Opc, SDLoc(N), VT, N0, N1, Flags))
-      return SD;
+  if (SDValue SD = reassociateReduction(
+          PropagatesNaN
+              ? (IsMin ? ISD::VECREDUCE_FMINIMUM : ISD::VECREDUCE_FMAXIMUM)
+              : (IsMin ? ISD::VECREDUCE_FMIN : ISD::VECREDUCE_FMAX),
+          Opc, SDLoc(N), VT, N0, N1, Flags))
+    return SD;
 
   return SDValue();
 }
index 1bde11d..cb2e7a3 100644 (file)
@@ -34,9 +34,8 @@ define float @fmin_f32(<8 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fmin_f32:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fminnm v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    fminnmv s2, v2.4s
+; CHECK-NEXT:    fminnm v0.4s, v0.4s, v2.4s
 ; CHECK-NEXT:    fminnmv s0, v0.4s
-; CHECK-NEXT:    fminnm s0, s0, s2
 ; CHECK-NEXT:    ret
   %r1 = call float @llvm.vector.reduce.fmin.v8f32(<8 x float> %a)
   %r2 = call float @llvm.vector.reduce.fmin.v4f32(<4 x float> %b)
@@ -48,9 +47,8 @@ define float @fmax_f32(<8 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fmax_f32:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmaxnm v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    fmaxnmv s2, v2.4s
+; CHECK-NEXT:    fmaxnm v0.4s, v0.4s, v2.4s
 ; CHECK-NEXT:    fmaxnmv s0, v0.4s
-; CHECK-NEXT:    fmaxnm s0, s0, s2
 ; CHECK-NEXT:    ret
   %r1 = call float @llvm.vector.reduce.fmax.v8f32(<8 x float> %a)
   %r2 = call float @llvm.vector.reduce.fmax.v4f32(<4 x float> %b)
@@ -62,9 +60,8 @@ define float @fminimum_f32(<8 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fminimum_f32:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmin v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    fminv s2, v2.4s
+; CHECK-NEXT:    fmin v0.4s, v0.4s, v2.4s
 ; CHECK-NEXT:    fminv s0, v0.4s
-; CHECK-NEXT:    fmin s0, s0, s2
 ; CHECK-NEXT:    ret
   %r1 = call float @llvm.vector.reduce.fminimum.v8f32(<8 x float> %a)
   %r2 = call float @llvm.vector.reduce.fminimum.v4f32(<4 x float> %b)
@@ -76,9 +73,8 @@ define float @fmaximum_f32(<8 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fmaximum_f32:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmax v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    fmaxv s2, v2.4s
+; CHECK-NEXT:    fmax v0.4s, v0.4s, v2.4s
 ; CHECK-NEXT:    fmaxv s0, v0.4s
-; CHECK-NEXT:    fmax s0, s0, s2
 ; CHECK-NEXT:    ret
   %r1 = call float @llvm.vector.reduce.fmaximum.v8f32(<8 x float> %a)
   %r2 = call float @llvm.vector.reduce.fmaximum.v4f32(<4 x float> %b)
index 6a06d38..bfb296b 100644 (file)
@@ -28,9 +28,9 @@ define float @fmin_f32(<vscale x 8 x float> %a, <vscale x 4 x float> %b) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    fminnm z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    fminnmv s2, p0, z2.s
+; CHECK-NEXT:    fminnm z0.s, p0/m, z0.s, z2.s
 ; CHECK-NEXT:    fminnmv s0, p0, z0.s
-; CHECK-NEXT:    fminnm s0, s0, s2
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
 ; CHECK-NEXT:    ret
   %r1 = call fast float @llvm.vector.reduce.fmin.nxv8f32(<vscale x 8 x float> %a)
   %r2 = call fast float @llvm.vector.reduce.fmin.nxv4f32(<vscale x 4 x float> %b)
@@ -43,9 +43,9 @@ define float @fmax_f32(<vscale x 8 x float> %a, <vscale x 4 x float> %b) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    fmaxnm z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    fmaxnmv s2, p0, z2.s
+; CHECK-NEXT:    fmaxnm z0.s, p0/m, z0.s, z2.s
 ; CHECK-NEXT:    fmaxnmv s0, p0, z0.s
-; CHECK-NEXT:    fmaxnm s0, s0, s2
+; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $z0
 ; CHECK-NEXT:    ret
   %r1 = call fast float @llvm.vector.reduce.fmax.nxv8f32(<vscale x 8 x float> %a)
   %r2 = call fast float @llvm.vector.reduce.fmax.nxv4f32(<vscale x 4 x float> %b)
index 40d1180..2de827f 100644 (file)
@@ -44,15 +44,12 @@ define float @fmul_f32(<4 x float> %a, <4 x float> %b) {
 define float @fmin_f32(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fmin_f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    lui a0, %hi(.LCPI2_0)
-; CHECK-NEXT:    flw fa5, %lo(.LCPI2_0)(a0)
+; CHECK-NEXT:    lui a0, 523264
 ; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; CHECK-NEXT:    vfmv.s.f v10, fa5
+; CHECK-NEXT:    vmv.s.x v10, a0
+; CHECK-NEXT:    vfmin.vv v8, v8, v9
 ; CHECK-NEXT:    vfredmin.vs v8, v8, v10
-; CHECK-NEXT:    vfmv.f.s fa5, v8
-; CHECK-NEXT:    vfredmin.vs v8, v9, v10
-; CHECK-NEXT:    vfmv.f.s fa4, v8
-; CHECK-NEXT:    fmin.s fa0, fa5, fa4
+; CHECK-NEXT:    vfmv.f.s fa0, v8
 ; CHECK-NEXT:    ret
   %r1 = call fast float @llvm.vector.reduce.fmin.v4f32(<4 x float> %a)
   %r2 = call fast float @llvm.vector.reduce.fmin.v4f32(<4 x float> %b)
@@ -63,15 +60,12 @@ define float @fmin_f32(<4 x float> %a, <4 x float> %b) {
 define float @fmax_f32(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fmax_f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    lui a0, %hi(.LCPI3_0)
-; CHECK-NEXT:    flw fa5, %lo(.LCPI3_0)(a0)
+; CHECK-NEXT:    lui a0, 1047552
 ; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; CHECK-NEXT:    vfmv.s.f v10, fa5
+; CHECK-NEXT:    vmv.s.x v10, a0
+; CHECK-NEXT:    vfmax.vv v8, v8, v9
 ; CHECK-NEXT:    vfredmax.vs v8, v8, v10
-; CHECK-NEXT:    vfmv.f.s fa5, v8
-; CHECK-NEXT:    vfredmax.vs v8, v9, v10
-; CHECK-NEXT:    vfmv.f.s fa4, v8
-; CHECK-NEXT:    fmax.s fa0, fa5, fa4
+; CHECK-NEXT:    vfmv.f.s fa0, v8
 ; CHECK-NEXT:    ret
   %r1 = call fast float @llvm.vector.reduce.fmax.v4f32(<4 x float> %a)
   %r2 = call fast float @llvm.vector.reduce.fmax.v4f32(<4 x float> %b)
index fad110d..3b85a4f 100644 (file)
@@ -35,13 +35,10 @@ define float @fmin_f32(<8 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fmin_f32:
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vminnm.f32 q0, q0, q1
-; CHECK-NEXT:    vminnm.f32 s4, s8, s9
+; CHECK-NEXT:    vminnm.f32 q0, q0, q2
 ; CHECK-NEXT:    vminnm.f32 s2, s2, s3
 ; CHECK-NEXT:    vminnm.f32 s0, s0, s1
 ; CHECK-NEXT:    vminnm.f32 s0, s0, s2
-; CHECK-NEXT:    vminnm.f32 s2, s10, s11
-; CHECK-NEXT:    vminnm.f32 s2, s4, s2
-; CHECK-NEXT:    vminnm.f32 s0, s0, s2
 ; CHECK-NEXT:    bx lr
   %r1 = call fast float @llvm.vector.reduce.fmin.v8f32(<8 x float> %a)
   %r2 = call fast float @llvm.vector.reduce.fmin.v4f32(<4 x float> %b)
@@ -53,13 +50,10 @@ define float @fmax_f32(<8 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: fmax_f32:
 ; CHECK:       @ %bb.0:
 ; CHECK-NEXT:    vmaxnm.f32 q0, q0, q1
-; CHECK-NEXT:    vmaxnm.f32 s4, s8, s9
+; CHECK-NEXT:    vmaxnm.f32 q0, q0, q2
 ; CHECK-NEXT:    vmaxnm.f32 s2, s2, s3
 ; CHECK-NEXT:    vmaxnm.f32 s0, s0, s1
 ; CHECK-NEXT:    vmaxnm.f32 s0, s0, s2
-; CHECK-NEXT:    vmaxnm.f32 s2, s10, s11
-; CHECK-NEXT:    vmaxnm.f32 s2, s4, s2
-; CHECK-NEXT:    vmaxnm.f32 s0, s0, s2
 ; CHECK-NEXT:    bx lr
   %r1 = call fast float @llvm.vector.reduce.fmax.v8f32(<8 x float> %a)
   %r2 = call fast float @llvm.vector.reduce.fmax.v4f32(<4 x float> %b)