[RISCV] Add an implementation of isFMAFasterThanFMulAndFAdd
authorCraig Topper <craig.topper@sifive.com>
Wed, 25 Nov 2020 23:07:34 +0000 (15:07 -0800)
committerCraig Topper <craig.topper@sifive.com>
Wed, 25 Nov 2020 23:07:34 +0000 (15:07 -0800)
Start with an assumption that FMA is faster than Fmul+FAdd. If thats not true
on some particular implementation we can add a tuning parameter in the future.

I've update the fmuladd test cases and added new test cases for fast math flag
based contraction.

Differential Revision: https://reviews.llvm.org/D91987

llvm/lib/Target/RISCV/RISCVISelLowering.cpp
llvm/lib/Target/RISCV/RISCVISelLowering.h
llvm/test/CodeGen/RISCV/double-arith.ll
llvm/test/CodeGen/RISCV/double-intrinsics.ll
llvm/test/CodeGen/RISCV/float-arith.ll
llvm/test/CodeGen/RISCV/float-intrinsics.ll

index c182a34..e855b77 100644 (file)
@@ -3365,6 +3365,25 @@ Value *RISCVTargetLowering::emitMaskedAtomicCmpXchgIntrinsic(
   return Result;
 }
 
+bool RISCVTargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                                     EVT VT) const {
+  VT = VT.getScalarType();
+
+  if (!VT.isSimple())
+    return false;
+
+  switch (VT.getSimpleVT().SimpleTy) {
+  case MVT::f32:
+    return Subtarget.hasStdExtF();
+  case MVT::f64:
+    return Subtarget.hasStdExtD();
+  default:
+    break;
+  }
+
+  return false;
+}
+
 Register RISCVTargetLowering::getExceptionPointerRegister(
     const Constant *PersonalityFn) const {
   return RISCV::X10;
index 39aa360..065b358 100644 (file)
@@ -146,6 +146,9 @@ public:
   Instruction *emitTrailingFence(IRBuilder<> &Builder, Instruction *Inst,
                                  AtomicOrdering Ord) const override;
 
+  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                  EVT VT) const override;
+
   ISD::NodeType getExtendForAtomicOps() const override {
     return ISD::SIGN_EXTEND;
   }
index 5f9d350..e33d2e4 100644 (file)
@@ -644,3 +644,160 @@ define double @fnmsub_d_2(double %a, double %b, double %c) nounwind {
   %1 = call double @llvm.fma.f64(double %a, double %negb, double %c)
   ret double %1
 }
+
+define double @fmadd_d_contract(double %a, double %b, double %c) nounwind {
+; RV32IFD-LABEL: fmadd_d_contract:
+; RV32IFD:       # %bb.0:
+; RV32IFD-NEXT:    addi sp, sp, -16
+; RV32IFD-NEXT:    sw a4, 8(sp)
+; RV32IFD-NEXT:    sw a5, 12(sp)
+; RV32IFD-NEXT:    fld ft0, 8(sp)
+; RV32IFD-NEXT:    sw a2, 8(sp)
+; RV32IFD-NEXT:    sw a3, 12(sp)
+; RV32IFD-NEXT:    fld ft1, 8(sp)
+; RV32IFD-NEXT:    sw a0, 8(sp)
+; RV32IFD-NEXT:    sw a1, 12(sp)
+; RV32IFD-NEXT:    fld ft2, 8(sp)
+; RV32IFD-NEXT:    fmadd.d ft0, ft2, ft1, ft0
+; RV32IFD-NEXT:    fsd ft0, 8(sp)
+; RV32IFD-NEXT:    lw a0, 8(sp)
+; RV32IFD-NEXT:    lw a1, 12(sp)
+; RV32IFD-NEXT:    addi sp, sp, 16
+; RV32IFD-NEXT:    ret
+;
+; RV64IFD-LABEL: fmadd_d_contract:
+; RV64IFD:       # %bb.0:
+; RV64IFD-NEXT:    fmv.d.x ft0, a2
+; RV64IFD-NEXT:    fmv.d.x ft1, a1
+; RV64IFD-NEXT:    fmv.d.x ft2, a0
+; RV64IFD-NEXT:    fmadd.d ft0, ft2, ft1, ft0
+; RV64IFD-NEXT:    fmv.x.d a0, ft0
+; RV64IFD-NEXT:    ret
+  %1 = fmul contract double %a, %b
+  %2 = fadd contract double %1, %c
+  ret double %2
+}
+
+define double @fmsub_d_contract(double %a, double %b, double %c) nounwind {
+; RV32IFD-LABEL: fmsub_d_contract:
+; RV32IFD:       # %bb.0:
+; RV32IFD-NEXT:    addi sp, sp, -16
+; RV32IFD-NEXT:    sw a2, 8(sp)
+; RV32IFD-NEXT:    sw a3, 12(sp)
+; RV32IFD-NEXT:    fld ft0, 8(sp)
+; RV32IFD-NEXT:    sw a0, 8(sp)
+; RV32IFD-NEXT:    sw a1, 12(sp)
+; RV32IFD-NEXT:    fld ft1, 8(sp)
+; RV32IFD-NEXT:    sw a4, 8(sp)
+; RV32IFD-NEXT:    sw a5, 12(sp)
+; RV32IFD-NEXT:    fld ft2, 8(sp)
+; RV32IFD-NEXT:    fcvt.d.w ft3, zero
+; RV32IFD-NEXT:    fadd.d ft2, ft2, ft3
+; RV32IFD-NEXT:    fmsub.d ft0, ft1, ft0, ft2
+; RV32IFD-NEXT:    fsd ft0, 8(sp)
+; RV32IFD-NEXT:    lw a0, 8(sp)
+; RV32IFD-NEXT:    lw a1, 12(sp)
+; RV32IFD-NEXT:    addi sp, sp, 16
+; RV32IFD-NEXT:    ret
+;
+; RV64IFD-LABEL: fmsub_d_contract:
+; RV64IFD:       # %bb.0:
+; RV64IFD-NEXT:    fmv.d.x ft0, a1
+; RV64IFD-NEXT:    fmv.d.x ft1, a0
+; RV64IFD-NEXT:    fmv.d.x ft2, a2
+; RV64IFD-NEXT:    fmv.d.x ft3, zero
+; RV64IFD-NEXT:    fadd.d ft2, ft2, ft3
+; RV64IFD-NEXT:    fmsub.d ft0, ft1, ft0, ft2
+; RV64IFD-NEXT:    fmv.x.d a0, ft0
+; RV64IFD-NEXT:    ret
+  %c_ = fadd double 0.0, %c ; avoid negation using xor
+  %1 = fmul contract double %a, %b
+  %2 = fsub contract double %1, %c_
+  ret double %2
+}
+
+define double @fnmadd_d_contract(double %a, double %b, double %c) nounwind {
+; RV32IFD-LABEL: fnmadd_d_contract:
+; RV32IFD:       # %bb.0:
+; RV32IFD-NEXT:    addi sp, sp, -16
+; RV32IFD-NEXT:    sw a4, 8(sp)
+; RV32IFD-NEXT:    sw a5, 12(sp)
+; RV32IFD-NEXT:    fld ft0, 8(sp)
+; RV32IFD-NEXT:    sw a2, 8(sp)
+; RV32IFD-NEXT:    sw a3, 12(sp)
+; RV32IFD-NEXT:    fld ft1, 8(sp)
+; RV32IFD-NEXT:    sw a0, 8(sp)
+; RV32IFD-NEXT:    sw a1, 12(sp)
+; RV32IFD-NEXT:    fld ft2, 8(sp)
+; RV32IFD-NEXT:    fcvt.d.w ft3, zero
+; RV32IFD-NEXT:    fadd.d ft2, ft2, ft3
+; RV32IFD-NEXT:    fadd.d ft1, ft1, ft3
+; RV32IFD-NEXT:    fadd.d ft0, ft0, ft3
+; RV32IFD-NEXT:    fnmadd.d ft0, ft2, ft1, ft0
+; RV32IFD-NEXT:    fsd ft0, 8(sp)
+; RV32IFD-NEXT:    lw a0, 8(sp)
+; RV32IFD-NEXT:    lw a1, 12(sp)
+; RV32IFD-NEXT:    addi sp, sp, 16
+; RV32IFD-NEXT:    ret
+;
+; RV64IFD-LABEL: fnmadd_d_contract:
+; RV64IFD:       # %bb.0:
+; RV64IFD-NEXT:    fmv.d.x ft0, a2
+; RV64IFD-NEXT:    fmv.d.x ft1, a1
+; RV64IFD-NEXT:    fmv.d.x ft2, a0
+; RV64IFD-NEXT:    fmv.d.x ft3, zero
+; RV64IFD-NEXT:    fadd.d ft2, ft2, ft3
+; RV64IFD-NEXT:    fadd.d ft1, ft1, ft3
+; RV64IFD-NEXT:    fadd.d ft0, ft0, ft3
+; RV64IFD-NEXT:    fnmadd.d ft0, ft2, ft1, ft0
+; RV64IFD-NEXT:    fmv.x.d a0, ft0
+; RV64IFD-NEXT:    ret
+  %a_ = fadd double 0.0, %a ; avoid negation using xor
+  %b_ = fadd double 0.0, %b ; avoid negation using xor
+  %c_ = fadd double 0.0, %c ; avoid negation using xor
+  %1 = fmul contract double %a_, %b_
+  %2 = fneg double %1
+  %3 = fsub contract double %2, %c_
+  ret double %3
+}
+
+define double @fnmsub_d_contract(double %a, double %b, double %c) nounwind {
+; RV32IFD-LABEL: fnmsub_d_contract:
+; RV32IFD:       # %bb.0:
+; RV32IFD-NEXT:    addi sp, sp, -16
+; RV32IFD-NEXT:    sw a4, 8(sp)
+; RV32IFD-NEXT:    sw a5, 12(sp)
+; RV32IFD-NEXT:    fld ft0, 8(sp)
+; RV32IFD-NEXT:    sw a2, 8(sp)
+; RV32IFD-NEXT:    sw a3, 12(sp)
+; RV32IFD-NEXT:    fld ft1, 8(sp)
+; RV32IFD-NEXT:    sw a0, 8(sp)
+; RV32IFD-NEXT:    sw a1, 12(sp)
+; RV32IFD-NEXT:    fld ft2, 8(sp)
+; RV32IFD-NEXT:    fcvt.d.w ft3, zero
+; RV32IFD-NEXT:    fadd.d ft2, ft2, ft3
+; RV32IFD-NEXT:    fadd.d ft1, ft1, ft3
+; RV32IFD-NEXT:    fnmsub.d ft0, ft2, ft1, ft0
+; RV32IFD-NEXT:    fsd ft0, 8(sp)
+; RV32IFD-NEXT:    lw a0, 8(sp)
+; RV32IFD-NEXT:    lw a1, 12(sp)
+; RV32IFD-NEXT:    addi sp, sp, 16
+; RV32IFD-NEXT:    ret
+;
+; RV64IFD-LABEL: fnmsub_d_contract:
+; RV64IFD:       # %bb.0:
+; RV64IFD-NEXT:    fmv.d.x ft0, a2
+; RV64IFD-NEXT:    fmv.d.x ft1, a1
+; RV64IFD-NEXT:    fmv.d.x ft2, a0
+; RV64IFD-NEXT:    fmv.d.x ft3, zero
+; RV64IFD-NEXT:    fadd.d ft2, ft2, ft3
+; RV64IFD-NEXT:    fadd.d ft1, ft1, ft3
+; RV64IFD-NEXT:    fnmsub.d ft0, ft2, ft1, ft0
+; RV64IFD-NEXT:    fmv.x.d a0, ft0
+; RV64IFD-NEXT:    ret
+  %a_ = fadd double 0.0, %a ; avoid negation using xor
+  %b_ = fadd double 0.0, %b ; avoid negation using xor
+  %1 = fmul contract double %a_, %b_
+  %2 = fsub contract double %c, %1
+  ret double %2
+}
index b887388..4126c01 100644 (file)
@@ -341,7 +341,6 @@ define double @fma_f64(double %a, double %b, double %c) nounwind {
 declare double @llvm.fmuladd.f64(double, double, double)
 
 define double @fmuladd_f64(double %a, double %b, double %c) nounwind {
-; Use of fmadd depends on TargetLowering::isFMAFasterthanFMulAndFAdd
 ; RV32IFD-LABEL: fmuladd_f64:
 ; RV32IFD:       # %bb.0:
 ; RV32IFD-NEXT:    addi sp, sp, -16
@@ -354,8 +353,7 @@ define double @fmuladd_f64(double %a, double %b, double %c) nounwind {
 ; RV32IFD-NEXT:    sw a0, 8(sp)
 ; RV32IFD-NEXT:    sw a1, 12(sp)
 ; RV32IFD-NEXT:    fld ft2, 8(sp)
-; RV32IFD-NEXT:    fmul.d ft1, ft2, ft1
-; RV32IFD-NEXT:    fadd.d ft0, ft1, ft0
+; RV32IFD-NEXT:    fmadd.d ft0, ft2, ft1, ft0
 ; RV32IFD-NEXT:    fsd ft0, 8(sp)
 ; RV32IFD-NEXT:    lw a0, 8(sp)
 ; RV32IFD-NEXT:    lw a1, 12(sp)
@@ -367,8 +365,7 @@ define double @fmuladd_f64(double %a, double %b, double %c) nounwind {
 ; RV64IFD-NEXT:    fmv.d.x ft0, a2
 ; RV64IFD-NEXT:    fmv.d.x ft1, a1
 ; RV64IFD-NEXT:    fmv.d.x ft2, a0
-; RV64IFD-NEXT:    fmul.d ft1, ft2, ft1
-; RV64IFD-NEXT:    fadd.d ft0, ft1, ft0
+; RV64IFD-NEXT:    fmadd.d ft0, ft2, ft1, ft0
 ; RV64IFD-NEXT:    fmv.x.d a0, ft0
 ; RV64IFD-NEXT:    ret
   %1 = call double @llvm.fmuladd.f64(double %a, double %b, double %c)
index 9f25cde..d302dd5 100644 (file)
@@ -483,3 +483,120 @@ define float @fnmsub_s_2(float %a, float %b, float %c) nounwind {
   %1 = call float @llvm.fma.f32(float %a, float %negb, float %c)
   ret float %1
 }
+
+define float @fmadd_s_contract(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fmadd_s_contract:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fmv.w.x ft0, a2
+; RV32IF-NEXT:    fmv.w.x ft1, a1
+; RV32IF-NEXT:    fmv.w.x ft2, a0
+; RV32IF-NEXT:    fmadd.s ft0, ft2, ft1, ft0
+; RV32IF-NEXT:    fmv.x.w a0, ft0
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: fmadd_s_contract:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fmv.w.x ft0, a2
+; RV64IF-NEXT:    fmv.w.x ft1, a1
+; RV64IF-NEXT:    fmv.w.x ft2, a0
+; RV64IF-NEXT:    fmadd.s ft0, ft2, ft1, ft0
+; RV64IF-NEXT:    fmv.x.w a0, ft0
+; RV64IF-NEXT:    ret
+  %1 = fmul contract float %a, %b
+  %2 = fadd contract float %1, %c
+  ret float %2
+}
+
+define float @fmsub_s_contract(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fmsub_s_contract:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fmv.w.x ft0, a1
+; RV32IF-NEXT:    fmv.w.x ft1, a0
+; RV32IF-NEXT:    fmv.w.x ft2, a2
+; RV32IF-NEXT:    fmv.w.x ft3, zero
+; RV32IF-NEXT:    fadd.s ft2, ft2, ft3
+; RV32IF-NEXT:    fmsub.s ft0, ft1, ft0, ft2
+; RV32IF-NEXT:    fmv.x.w a0, ft0
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: fmsub_s_contract:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fmv.w.x ft0, a1
+; RV64IF-NEXT:    fmv.w.x ft1, a0
+; RV64IF-NEXT:    fmv.w.x ft2, a2
+; RV64IF-NEXT:    fmv.w.x ft3, zero
+; RV64IF-NEXT:    fadd.s ft2, ft2, ft3
+; RV64IF-NEXT:    fmsub.s ft0, ft1, ft0, ft2
+; RV64IF-NEXT:    fmv.x.w a0, ft0
+; RV64IF-NEXT:    ret
+  %c_ = fadd float 0.0, %c ; avoid negation using xor
+  %1 = fmul contract float %a, %b
+  %2 = fsub contract float %1, %c_
+  ret float %2
+}
+
+define float @fnmadd_s_contract(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fnmadd_s_contract:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fmv.w.x ft0, a2
+; RV32IF-NEXT:    fmv.w.x ft1, a1
+; RV32IF-NEXT:    fmv.w.x ft2, a0
+; RV32IF-NEXT:    fmv.w.x ft3, zero
+; RV32IF-NEXT:    fadd.s ft2, ft2, ft3
+; RV32IF-NEXT:    fadd.s ft1, ft1, ft3
+; RV32IF-NEXT:    fadd.s ft0, ft0, ft3
+; RV32IF-NEXT:    fnmadd.s ft0, ft2, ft1, ft0
+; RV32IF-NEXT:    fmv.x.w a0, ft0
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: fnmadd_s_contract:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fmv.w.x ft0, a2
+; RV64IF-NEXT:    fmv.w.x ft1, a1
+; RV64IF-NEXT:    fmv.w.x ft2, a0
+; RV64IF-NEXT:    fmv.w.x ft3, zero
+; RV64IF-NEXT:    fadd.s ft2, ft2, ft3
+; RV64IF-NEXT:    fadd.s ft1, ft1, ft3
+; RV64IF-NEXT:    fadd.s ft0, ft0, ft3
+; RV64IF-NEXT:    fnmadd.s ft0, ft2, ft1, ft0
+; RV64IF-NEXT:    fmv.x.w a0, ft0
+; RV64IF-NEXT:    ret
+  %a_ = fadd float 0.0, %a ; avoid negation using xor
+  %b_ = fadd float 0.0, %b ; avoid negation using xor
+  %c_ = fadd float 0.0, %c ; avoid negation using xor
+  %1 = fmul contract float %a_, %b_
+  %2 = fneg float %1
+  %3 = fsub contract float %2, %c_
+  ret float %3
+}
+
+define float @fnmsub_s_contract(float %a, float %b, float %c) nounwind {
+; RV32IF-LABEL: fnmsub_s_contract:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    fmv.w.x ft0, a2
+; RV32IF-NEXT:    fmv.w.x ft1, a1
+; RV32IF-NEXT:    fmv.w.x ft2, a0
+; RV32IF-NEXT:    fmv.w.x ft3, zero
+; RV32IF-NEXT:    fadd.s ft2, ft2, ft3
+; RV32IF-NEXT:    fadd.s ft1, ft1, ft3
+; RV32IF-NEXT:    fnmsub.s ft0, ft2, ft1, ft0
+; RV32IF-NEXT:    fmv.x.w a0, ft0
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: fnmsub_s_contract:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    fmv.w.x ft0, a2
+; RV64IF-NEXT:    fmv.w.x ft1, a1
+; RV64IF-NEXT:    fmv.w.x ft2, a0
+; RV64IF-NEXT:    fmv.w.x ft3, zero
+; RV64IF-NEXT:    fadd.s ft2, ft2, ft3
+; RV64IF-NEXT:    fadd.s ft1, ft1, ft3
+; RV64IF-NEXT:    fnmsub.s ft0, ft2, ft1, ft0
+; RV64IF-NEXT:    fmv.x.w a0, ft0
+; RV64IF-NEXT:    ret
+  %a_ = fadd float 0.0, %a ; avoid negation using xor
+  %b_ = fadd float 0.0, %b ; avoid negation using xor
+  %1 = fmul contract float %a_, %b_
+  %2 = fsub contract float %c, %1
+  ret float %2
+}
index 68bb95b..85d1564 100644 (file)
@@ -319,14 +319,12 @@ define float @fma_f32(float %a, float %b, float %c) nounwind {
 declare float @llvm.fmuladd.f32(float, float, float)
 
 define float @fmuladd_f32(float %a, float %b, float %c) nounwind {
-; Use of fmadd depends on TargetLowering::isFMAFasterthanFMulAndFAdd
 ; RV32IF-LABEL: fmuladd_f32:
 ; RV32IF:       # %bb.0:
 ; RV32IF-NEXT:    fmv.w.x ft0, a2
 ; RV32IF-NEXT:    fmv.w.x ft1, a1
 ; RV32IF-NEXT:    fmv.w.x ft2, a0
-; RV32IF-NEXT:    fmul.s ft1, ft2, ft1
-; RV32IF-NEXT:    fadd.s ft0, ft1, ft0
+; RV32IF-NEXT:    fmadd.s ft0, ft2, ft1, ft0
 ; RV32IF-NEXT:    fmv.x.w a0, ft0
 ; RV32IF-NEXT:    ret
 ;
@@ -335,8 +333,7 @@ define float @fmuladd_f32(float %a, float %b, float %c) nounwind {
 ; RV64IF-NEXT:    fmv.w.x ft0, a2
 ; RV64IF-NEXT:    fmv.w.x ft1, a1
 ; RV64IF-NEXT:    fmv.w.x ft2, a0
-; RV64IF-NEXT:    fmul.s ft1, ft2, ft1
-; RV64IF-NEXT:    fadd.s ft0, ft1, ft0
+; RV64IF-NEXT:    fmadd.s ft0, ft2, ft1, ft0
 ; RV64IF-NEXT:    fmv.x.w a0, ft0
 ; RV64IF-NEXT:    ret
   %1 = call float @llvm.fmuladd.f32(float %a, float %b, float %c)