[ARM] Fold (fadd x, (vselect c, y, -1.0)) into (vselect c, (fadd x, y), x)
authorDavid Green <david.green@arm.com>
Wed, 24 Nov 2021 10:41:00 +0000 (10:41 +0000)
committerDavid Green <david.green@arm.com>
Wed, 24 Nov 2021 10:41:00 +0000 (10:41 +0000)
This is similar to D113574, but as a DAG combine, not tablegen patterns.
Doing the fold as a DAG combine allows the fadd to be folded with a
fmul, finally producing a predicated vfma. It performs the same fold of
fadd(x, vselect(p, y, -0.0)) to vselect p, (fadd x, y), x) using -0.0 as
the identity value of a fadd.

Differential Revision: https://reviews.llvm.org/D113584

llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/test/CodeGen/Thumb2/mve-pred-selectop3.ll

index 239e2270966f2d99dfe645f8b6596a0f039b2e42..23c1a6e8cf2122f44292fb9cfa59f7d4b40f2258 100644 (file)
@@ -1017,6 +1017,9 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::SELECT);
     setTargetDAGCombine(ISD::SELECT_CC);
   }
+  if (Subtarget->hasMVEFloatOps()) {
+    setTargetDAGCombine(ISD::FADD);
+  }
 
   if (!Subtarget->hasFP64()) {
     // When targeting a floating-point unit with only single-precision
@@ -16407,6 +16410,42 @@ static SDValue PerformVCVTCombine(SDNode *N, SelectionDAG &DAG,
   return FixConv;
 }
 
+static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
+                                         const ARMSubtarget *Subtarget) {
+  if (!Subtarget->hasMVEFloatOps())
+    return SDValue();
+
+  // Turn (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)
+  // The second form can be more easily turned into a predicated vadd, and
+  // possibly combined into a fma to become a predicated vfma.
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  // The identity element for a fadd is -0.0, which these VMOV's represent.
+  auto isNegativeZeroSplat = [&](SDValue Op) {
+    if (Op.getOpcode() != ISD::BITCAST ||
+        Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM)
+      return false;
+    if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664)
+      return true;
+    if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688)
+      return true;
+    return false;
+  };
+
+  if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT)
+    std::swap(Op0, Op1);
+
+  if (Op1.getOpcode() != ISD::VSELECT ||
+      !isNegativeZeroSplat(Op1.getOperand(2)))
+    return SDValue();
+  SDValue FAdd =
+      DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags());
+  return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0);
+}
+
 /// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
 /// can replace combinations of VCVT (integer to floating-point) and VDIV
 /// when the VDIV has a constant operand that is a power of 2.
@@ -18201,6 +18240,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_TO_SINT:
   case ISD::FP_TO_UINT:
     return PerformVCVTCombine(N, DCI.DAG, Subtarget);
+  case ISD::FADD:
+    return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
   case ISD::FDIV:
     return PerformVDIVCombine(N, DCI.DAG, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
index 238133c21e20762cd31a0ea21da0528299fa1b83..e3e23f6524ba094c99cc16fe15df24fd35b86e66 100644 (file)
@@ -470,10 +470,9 @@ entry:
 define arm_aapcs_vfpcc <4 x float> @fma_v4f32_x(<4 x float> %x, <4 x float> %y, <4 x float> %z, i32 %n) {
 ; CHECK-LABEL: fma_v4f32_x:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.f32 q1, q1, q2
 ; CHECK-NEXT:    vctp.32 r0
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f32 q0, q0, q1
+; CHECK-NEXT:    vfmat.f32 q0, q1, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
@@ -486,10 +485,9 @@ entry:
 define arm_aapcs_vfpcc <8 x half> @fma_v8f16_x(<8 x half> %x, <8 x half> %y, <8 x half> %z, i32 %n) {
 ; CHECK-LABEL: fma_v8f16_x:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    vmul.f16 q1, q1, q2
 ; CHECK-NEXT:    vctp.16 r0
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f16 q0, q0, q1
+; CHECK-NEXT:    vfmat.f16 q0, q1, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
@@ -2422,7 +2420,7 @@ define arm_aapcs_vfpcc <4 x float> @faddqr_v4f32_y(<4 x float> %x, float %y, i32
 ; CHECK-NEXT:    vctp.32 r0
 ; CHECK-NEXT:    vdup.32 q1, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f32 q1, q1, q0
+; CHECK-NEXT:    vaddt.f32 q1, q0, r1
 ; CHECK-NEXT:    vmov q0, q1
 ; CHECK-NEXT:    bx lr
 entry:
@@ -2441,7 +2439,7 @@ define arm_aapcs_vfpcc <8 x half> @faddqr_v8f16_y(<8 x half> %x, half %y, i32 %n
 ; CHECK-NEXT:    vctp.16 r0
 ; CHECK-NEXT:    vdup.16 q1, r1
 ; CHECK-NEXT:    vpst
-; CHECK-NEXT:    vaddt.f16 q1, q1, q0
+; CHECK-NEXT:    vaddt.f16 q1, q0, r1
 ; CHECK-NEXT:    vmov q0, q1
 ; CHECK-NEXT:    bx lr
 entry: