DAG: Add function context to isFMAFasterThanFMulAndFAdd
authorMatt Arsenault <Matthew.Arsenault@amd.com>
Tue, 29 Oct 2019 00:38:44 +0000 (17:38 -0700)
committerMatt Arsenault <arsenm2@gmail.com>
Tue, 19 Nov 2019 13:55:26 +0000 (19:25 +0530)
AMDGPU needs to know the FP mode for the function to answer this
correctly when this is removed from the subtarget.

AArch64 had to make this more complicated by using this from an IR
hook, so add an IR typed overload.

19 files changed:
llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AMDGPU/SIISelLowering.cpp
llvm/lib/Target/AMDGPU/SIISelLowering.h
llvm/lib/Target/ARM/ARMISelLowering.cpp
llvm/lib/Target/ARM/ARMISelLowering.h
llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
llvm/lib/Target/Hexagon/HexagonISelLowering.h
llvm/lib/Target/NVPTX/NVPTXISelLowering.h
llvm/lib/Target/PowerPC/PPCISelLowering.cpp
llvm/lib/Target/PowerPC/PPCISelLowering.h
llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
llvm/lib/Target/SystemZ/SystemZISelLowering.h
llvm/lib/Target/X86/X86ISelLowering.cpp
llvm/lib/Target/X86/X86ISelLowering.h

index 7fe8ffb..e4adbf4 100644 (file)
@@ -2528,7 +2528,13 @@ public:
   /// not legal, but should return true if those types will eventually legalize
   /// to types that support FMAs. After legalization, it will only be called on
   /// types that support FMAs (via Legal or Custom actions)
-  virtual bool isFMAFasterThanFMulAndFAdd(EVT) const {
+  virtual bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                          EVT) const {
+    return false;
+  }
+
+  /// IR version
+  virtual bool isFMAFasterThanFMulAndFAdd(const Function &F, Type *) const {
     return false;
   }
 
@@ -3763,7 +3769,7 @@ public:
   /// Should SelectionDAG lower an atomic store of the given kind as a normal
   /// StoreSDNode (as opposed to an AtomicSDNode)?  NOTE: The intention is to
   /// eventually migrate all targets to the using StoreSDNodes, but porting is
-  /// being done target at a time.  
+  /// being done target at a time.
   virtual bool lowerAtomicStoreAsStoreSDNode(const StoreInst &SI) const {
     assert(SI.isAtomic() && "violated precondition");
     return false;
index 3f2826c..5b0ad08 100644 (file)
@@ -1417,7 +1417,8 @@ bool IRTranslator::translateKnownIntrinsic(const CallInst &CI, Intrinsic::ID ID,
     Register Op1 = getOrCreateVReg(*CI.getArgOperand(1));
     Register Op2 = getOrCreateVReg(*CI.getArgOperand(2));
     if (TM.Options.AllowFPOpFusion != FPOpFusion::Strict &&
-        TLI.isFMAFasterThanFMulAndFAdd(TLI.getValueType(*DL, CI.getType()))) {
+        TLI.isFMAFasterThanFMulAndFAdd(*MF,
+                                       TLI.getValueType(*DL, CI.getType()))) {
       // TODO: Revisit this to see if we should move this part of the
       // lowering to the combiner.
       MIRBuilder.buildInstr(TargetOpcode::G_FMA, {Dst}, {Op0, Op1, Op2},
index 9780b69..2db3027 100644 (file)
@@ -11337,7 +11337,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) {
 
   // Floating-point multiply-add without intermediate rounding.
   bool HasFMA =
-      TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
@@ -11554,7 +11554,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) {
 
   // Floating-point multiply-add without intermediate rounding.
   bool HasFMA =
-      TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // No valid opcode, do not combine.
@@ -11860,7 +11860,7 @@ SDValue DAGCombiner::visitFMULForFMADistributiveCombine(SDNode *N) {
   // Floating-point multiply-add without intermediate rounding.
   bool HasFMA =
       (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
-      TLI.isFMAFasterThanFMulAndFAdd(VT) &&
+      TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT));
 
   // Floating-point multiply-add with intermediate rounding. This can result
index 1a42150..3f41a24 100644 (file)
@@ -6169,7 +6169,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   case Intrinsic::fmuladd: {
     EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType());
     if (TM.Options.AllowFPOpFusion != FPOpFusion::Strict &&
-        TLI.isFMAFasterThanFMulAndFAdd(VT)) {
+        TLI.isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT)) {
       setValue(&I, DAG.getNode(ISD::FMA, sdl,
                                getValue(I.getArgOperand(0)).getValueType(),
                                getValue(I.getArgOperand(0)),
index a9471a7..9e8df33 100644 (file)
@@ -8546,11 +8546,12 @@ bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const {
     return true;
 
   const TargetOptions &Options = getTargetMachine().Options;
-  const DataLayout &DL = I->getModule()->getDataLayout();
-  EVT VT = getValueType(DL, User->getOperand(0)->getType());
+  const Function *F = I->getFunction();
+  const DataLayout &DL = F->getParent()->getDataLayout();
+  Type *Ty = User->getOperand(0)->getType();
 
-  return !(isFMAFasterThanFMulAndFAdd(VT) &&
-           isOperationLegalOrCustom(ISD::FMA, VT) &&
+  return !(isFMAFasterThanFMulAndFAdd(*F, Ty) &&
+           isOperationLegalOrCustom(ISD::FMA, getValueType(DL, Ty)) &&
            (Options.AllowFPOpFusion == FPOpFusion::Fast ||
             Options.UnsafeFPMath));
 }
@@ -9207,7 +9208,8 @@ int AArch64TargetLowering::getScalingFactorCost(const DataLayout &DL,
   return -1;
 }
 
-bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
+bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
+    const MachineFunction &MF, EVT VT) const {
   VT = VT.getScalarType();
 
   if (!VT.isSimple())
@@ -9224,6 +9226,17 @@ bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
   return false;
 }
 
+bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F,
+                                                       Type *Ty) const {
+  switch (Ty->getScalarType()->getTypeID()) {
+  case Type::FloatTyID:
+  case Type::DoubleTyID:
+    return true;
+  default:
+    return false;
+  }
+}
+
 const MCPhysReg *
 AArch64TargetLowering::getScratchRegisters(CallingConv::ID) const {
   // LR is a callee-save register, but we must treat it as clobbered by any call
index 5a76f0c..384c7b4 100644 (file)
@@ -396,7 +396,9 @@ public:
   /// Return true if an FMA operation is faster than a pair of fmul and fadd
   /// instructions. fmuladd intrinsics will be expanded to FMAs when this method
   /// returns true, otherwise fmuladd is expanded to fmul + fadd.
-  bool isFMAFasterThanFMulAndFAdd(EVT VT) const override;
+  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                  EVT VT) const override;
+  bool isFMAFasterThanFMulAndFAdd(const Function &F, Type *Ty) const override;
 
   const MCPhysReg *getScratchRegisters(CallingConv::ID CC) const override;
 
index 85af397..1a02037 100644 (file)
@@ -3920,7 +3920,8 @@ MVT SITargetLowering::getScalarShiftAmountTy(const DataLayout &, EVT VT) const {
 // however does not support denormals, so we do report fma as faster if we have
 // a fast fma device and require denormals.
 //
-bool SITargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
+bool SITargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                                  EVT VT) const {
   VT = VT.getScalarType();
 
   switch (VT.getSimpleVT().SimpleTy) {
@@ -9461,7 +9462,7 @@ unsigned SITargetLowering::getFusedOpcode(const SelectionDAG &DAG,
   if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath ||
        (N0->getFlags().hasAllowContract() &&
         N1->getFlags().hasAllowContract())) &&
-      isFMAFasterThanFMulAndFAdd(VT)) {
+      isFMAFasterThanFMulAndFAdd(DAG.getMachineFunction(), VT)) {
     return ISD::FMA;
   }
 
index c99904c..b2c2e40 100644 (file)
@@ -349,7 +349,8 @@ public:
   EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Context,
                          EVT VT) const override;
   MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override;
-  bool isFMAFasterThanFMulAndFAdd(EVT VT) const override;
+  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                  EVT VT) const override;
   bool isFMADLegalForFAddFSub(const SelectionDAG &DAG,
                               const SDNode *N) const override;
 
index c931400..a33535e 100644 (file)
@@ -14826,7 +14826,8 @@ int ARMTargetLowering::getScalingFactorCost(const DataLayout &DL,
 ///
 /// For MVE, we set this to true as it helps simplify the need for some
 /// patterns (and we don't have the non-fused floating point instruction).
-bool ARMTargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
+bool ARMTargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                                   EVT VT) const {
   if (!Subtarget->hasMVEFloatOps())
     return false;
 
index 0aee61f..367a40b 100644 (file)
@@ -738,7 +738,8 @@ class VectorType;
     SDValue BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
                           SmallVectorImpl<SDNode *> &Created) const override;
 
-    bool isFMAFasterThanFMulAndFAdd(EVT VT) const override;
+    bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                    EVT VT) const override;
 
     SDValue ReconstructShuffle(SDValue Op, SelectionDAG &DAG) const;
 
index 7345100..1d7aa2c 100644 (file)
@@ -1909,7 +1909,8 @@ bool HexagonTargetLowering::isTruncateFree(EVT VT1, EVT VT2) const {
   return VT1.getSimpleVT() == MVT::i64 && VT2.getSimpleVT() == MVT::i32;
 }
 
-bool HexagonTargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
+bool HexagonTargetLowering::isFMAFasterThanFMulAndFAdd(
+    const MachineFunction &MF, EVT VT) const {
   return isOperationLegalOrCustom(ISD::FMA, VT);
 }
 
index 75f553b..ed207a7 100644 (file)
@@ -137,7 +137,8 @@ namespace HexagonISD {
     /// instructions. fmuladd intrinsics will be expanded to FMAs when this
     /// method returns true (and FMAs are legal), otherwise fmuladd is
     /// expanded to mul + add.
-    bool isFMAFasterThanFMulAndFAdd(EVT) const override;
+    bool isFMAFasterThanFMulAndFAdd(const MachineFunction &,
+                                    EVT) const override;
 
     // Should we expand the build vector with shuffles?
     bool shouldExpandBuildVectorWithShuffles(EVT VT,
index ef645fc..546fe49 100644 (file)
@@ -538,7 +538,10 @@ public:
   bool allowFMA(MachineFunction &MF, CodeGenOpt::Level OptLevel) const;
   bool allowUnsafeFPMath(MachineFunction &MF) const;
 
-  bool isFMAFasterThanFMulAndFAdd(EVT) const override { return true; }
+  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                  EVT) const override {
+    return true;
+  }
 
   bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
 
index a2ffe9e..313d6b8 100644 (file)
@@ -14948,7 +14948,8 @@ bool PPCTargetLowering::allowsMisalignedMemoryAccesses(EVT VT,
   return true;
 }
 
-bool PPCTargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
+bool PPCTargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                                   EVT VT) const {
   VT = VT.getScalarType();
 
   if (!VT.isSimple())
index a011343..77b19b2 100644 (file)
@@ -907,7 +907,8 @@ namespace llvm {
     /// than a pair of fmul and fadd instructions. fmuladd intrinsics will be
     /// expanded to FMAs when this method returns true, otherwise fmuladd is
     /// expanded to fmul + fadd.
-    bool isFMAFasterThanFMulAndFAdd(EVT VT) const override;
+    bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                    EVT VT) const override;
 
     const MCPhysReg *getScratchRegisters(CallingConv::ID CC) const override;
 
index daef108..42c1880 100644 (file)
@@ -643,7 +643,8 @@ EVT SystemZTargetLowering::getSetCCResultType(const DataLayout &DL,
   return VT.changeVectorElementTypeToInteger();
 }
 
-bool SystemZTargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
+bool SystemZTargetLowering::isFMAFasterThanFMulAndFAdd(
+    const MachineFunction &MF, EVT VT) const {
   VT = VT.getScalarType();
 
   if (!VT.isSimple())
index 7391365..f774b8a 100644 (file)
@@ -404,7 +404,8 @@ public:
   bool isCheapToSpeculateCtlz() const override { return true; }
   EVT getSetCCResultType(const DataLayout &DL, LLVMContext &,
                          EVT) const override;
-  bool isFMAFasterThanFMulAndFAdd(EVT VT) const override;
+  bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                  EVT VT) const override;
   bool isFPImmLegal(const APFloat &Imm, EVT VT,
                     bool ForCodeSize) const override;
   bool isLegalICmpImmediate(int64_t Imm) const override;
index 6bb2d1e..bcb091e 100644 (file)
@@ -29115,8 +29115,8 @@ bool X86TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
   return true;
 }
 
-bool
-X86TargetLowering::isFMAFasterThanFMulAndFAdd(EVT VT) const {
+bool X86TargetLowering::isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                                   EVT VT) const {
   if (!Subtarget.hasAnyFMA())
     return false;
 
index 6f7e900..184983d 100644 (file)
@@ -1056,7 +1056,8 @@ namespace llvm {
     /// Return true if an FMA operation is faster than a pair of fmul and fadd
     /// instructions. fmuladd intrinsics will be expanded to FMAs when this
     /// method returns true, otherwise fmuladd is expanded to fmul + fadd.
-    bool isFMAFasterThanFMulAndFAdd(EVT VT) const override;
+    bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
+                                    EVT VT) const override;
 
     /// Return true if it's profitable to narrow
     /// operations of type VT1 to VT2. e.g. on x86, it's profitable to narrow