setTargetDAGCombine(ISD::SELECT);
setTargetDAGCombine(ISD::SELECT_CC);
}
+ if (Subtarget->hasMVEFloatOps()) {
+ setTargetDAGCombine(ISD::FADD);
+ }
if (!Subtarget->hasFP64()) {
// When targeting a floating-point unit with only single-precision
return FixConv;
}
+static SDValue PerformFAddVSelectCombine(SDNode *N, SelectionDAG &DAG,
+ const ARMSubtarget *Subtarget) {
+ if (!Subtarget->hasMVEFloatOps())
+ return SDValue();
+
+ // Turn (fadd x, (vselect c, y, -0.0)) into (vselect c, (fadd x, y), x)
+ // The second form can be more easily turned into a predicated vadd, and
+ // possibly combined into a fma to become a predicated vfma.
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ EVT VT = N->getValueType(0);
+ SDLoc DL(N);
+
+ // The identity element for a fadd is -0.0, which these VMOV's represent.
+ auto isNegativeZeroSplat = [&](SDValue Op) {
+ if (Op.getOpcode() != ISD::BITCAST ||
+ Op.getOperand(0).getOpcode() != ARMISD::VMOVIMM)
+ return false;
+ if (VT == MVT::v4f32 && Op.getOperand(0).getConstantOperandVal(0) == 1664)
+ return true;
+ if (VT == MVT::v8f16 && Op.getOperand(0).getConstantOperandVal(0) == 2688)
+ return true;
+ return false;
+ };
+
+ if (Op0.getOpcode() == ISD::VSELECT && Op1.getOpcode() != ISD::VSELECT)
+ std::swap(Op0, Op1);
+
+ if (Op1.getOpcode() != ISD::VSELECT ||
+ !isNegativeZeroSplat(Op1.getOperand(2)))
+ return SDValue();
+ SDValue FAdd =
+ DAG.getNode(ISD::FADD, DL, VT, Op0, Op1.getOperand(1), N->getFlags());
+ return DAG.getNode(ISD::VSELECT, DL, VT, Op1.getOperand(0), FAdd, Op0);
+}
+
/// PerformVDIVCombine - VCVT (fixed-point to floating-point, Advanced SIMD)
/// can replace combinations of VCVT (integer to floating-point) and VDIV
/// when the VDIV has a constant operand that is a power of 2.
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
return PerformVCVTCombine(N, DCI.DAG, Subtarget);
+ case ISD::FADD:
+ return PerformFAddVSelectCombine(N, DCI.DAG, Subtarget);
case ISD::FDIV:
return PerformVDIVCombine(N, DCI.DAG, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
define arm_aapcs_vfpcc <4 x float> @fma_v4f32_x(<4 x float> %x, <4 x float> %y, <4 x float> %z, i32 %n) {
; CHECK-LABEL: fma_v4f32_x:
; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmul.f32 q1, q1, q2
; CHECK-NEXT: vctp.32 r0
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f32 q0, q0, q1
+; CHECK-NEXT: vfmat.f32 q0, q1, q2
; CHECK-NEXT: bx lr
entry:
%c = call <4 x i1> @llvm.arm.mve.vctp32(i32 %n)
define arm_aapcs_vfpcc <8 x half> @fma_v8f16_x(<8 x half> %x, <8 x half> %y, <8 x half> %z, i32 %n) {
; CHECK-LABEL: fma_v8f16_x:
; CHECK: @ %bb.0: @ %entry
-; CHECK-NEXT: vmul.f16 q1, q1, q2
; CHECK-NEXT: vctp.16 r0
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f16 q0, q0, q1
+; CHECK-NEXT: vfmat.f16 q0, q1, q2
; CHECK-NEXT: bx lr
entry:
%c = call <8 x i1> @llvm.arm.mve.vctp16(i32 %n)
; CHECK-NEXT: vctp.32 r0
; CHECK-NEXT: vdup.32 q1, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f32 q1, q1, q0
+; CHECK-NEXT: vaddt.f32 q1, q0, r1
; CHECK-NEXT: vmov q0, q1
; CHECK-NEXT: bx lr
entry:
; CHECK-NEXT: vctp.16 r0
; CHECK-NEXT: vdup.16 q1, r1
; CHECK-NEXT: vpst
-; CHECK-NEXT: vaddt.f16 q1, q1, q0
+; CHECK-NEXT: vaddt.f16 q1, q0, r1
; CHECK-NEXT: vmov q0, q1
; CHECK-NEXT: bx lr
entry: