[AMDGPU] Move V_FMA_MIX pattern matching into tablegen. NFC
authorJustin Bogner <mail@justinbogner.com>
Thu, 23 Feb 2023 01:45:58 +0000 (17:45 -0800)
committerJustin Bogner <mail@justinbogner.com>
Thu, 23 Feb 2023 18:23:34 +0000 (10:23 -0800)
The matching for V_FMA_MIX was partially implemented with a C++
matcher (for fmas with 32 bit results and 16 bit inputs) and partially
in tablegen (for fmas with 16 bit results). Move the C++ matcher logic
into tablegen to make this more consistent and so we can remove the
duplication between SDAG and GISel.

Differential Revision: https://reviews.llvm.org/D144612

llvm/lib/Target/AMDGPU/AMDGPUGISel.td
llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.h
llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp
llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h
llvm/lib/Target/AMDGPU/SIInstrInfo.td
llvm/lib/Target/AMDGPU/VOP3PInstructions.td

index 9ef59a4..0f3e3c0 100644 (file)
@@ -153,6 +153,10 @@ def gi_vop3_mad_mix_mods :
     GIComplexOperandMatcher<s64, "selectVOP3PMadMixMods">,
     GIComplexPatternEquiv<VOP3PMadMixMods>;
 
+def gi_vop3_mad_mix_mods_ext :
+    GIComplexOperandMatcher<s64, "selectVOP3PMadMixModsExt">,
+    GIComplexPatternEquiv<VOP3PMadMixModsExt>;
+
 // Separate load nodes are defined to glue m0 initialization in
 // SelectionDAG. The GISel selector can just insert m0 initialization
 // directly before selecting a glue-less load, so hide this
index 382eeeb..28c26b2 100644 (file)
@@ -663,10 +663,6 @@ void AMDGPUDAGToDAGISel::Select(SDNode *N) {
   case ISD::BRCOND:
     SelectBRCOND(N);
     return;
-  case ISD::FMAD:
-  case ISD::FMA:
-    SelectFMAD_FMA(N);
-    return;
   case AMDGPUISD::CVT_PKRTZ_F16_F32:
   case AMDGPUISD::CVT_PKNORM_I16_F32:
   case AMDGPUISD::CVT_PKNORM_U16_F32:
@@ -2283,52 +2279,6 @@ void AMDGPUDAGToDAGISel::SelectBRCOND(SDNode *N) {
                        VCC.getValue(0));
 }
 
-void AMDGPUDAGToDAGISel::SelectFMAD_FMA(SDNode *N) {
-  MVT VT = N->getSimpleValueType(0);
-  bool IsFMA = N->getOpcode() == ISD::FMA;
-  if (VT != MVT::f32 || (!Subtarget->hasMadMixInsts() &&
-                         !Subtarget->hasFmaMixInsts()) ||
-      ((IsFMA && Subtarget->hasMadMixInsts()) ||
-       (!IsFMA && Subtarget->hasFmaMixInsts()))) {
-    SelectCode(N);
-    return;
-  }
-
-  SDValue Src0 = N->getOperand(0);
-  SDValue Src1 = N->getOperand(1);
-  SDValue Src2 = N->getOperand(2);
-  unsigned Src0Mods, Src1Mods, Src2Mods;
-
-  // Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand
-  // using the conversion from f16.
-  bool Sel0 = SelectVOP3PMadMixModsImpl(Src0, Src0, Src0Mods);
-  bool Sel1 = SelectVOP3PMadMixModsImpl(Src1, Src1, Src1Mods);
-  bool Sel2 = SelectVOP3PMadMixModsImpl(Src2, Src2, Src2Mods);
-
-  assert((IsFMA || !Mode.allFP32Denormals()) &&
-         "fmad selected with denormals enabled");
-  // TODO: We can select this with f32 denormals enabled if all the sources are
-  // converted from f16 (in which case fmad isn't legal).
-
-  if (Sel0 || Sel1 || Sel2) {
-    // For dummy operands.
-    SDValue Zero = CurDAG->getTargetConstant(0, SDLoc(), MVT::i32);
-    SDValue Ops[] = {
-      CurDAG->getTargetConstant(Src0Mods, SDLoc(), MVT::i32), Src0,
-      CurDAG->getTargetConstant(Src1Mods, SDLoc(), MVT::i32), Src1,
-      CurDAG->getTargetConstant(Src2Mods, SDLoc(), MVT::i32), Src2,
-      CurDAG->getTargetConstant(0, SDLoc(), MVT::i1),
-      Zero, Zero
-    };
-
-    CurDAG->SelectNodeTo(N,
-                         IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32,
-                         MVT::f32, Ops);
-  } else {
-    SelectCode(N);
-  }
-}
-
 void AMDGPUDAGToDAGISel::SelectDSAppendConsume(SDNode *N, unsigned IntrID) {
   // The address is assumed to be uniform, so if it ends up in a VGPR, it will
   // be copied to an SGPR with readfirstlane.
@@ -2883,6 +2833,15 @@ bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
   return false;
 }
 
+bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
+                                                  SDValue &SrcMods) const {
+  unsigned Mods = 0;
+  if (!SelectVOP3PMadMixModsImpl(In, Src, Mods))
+    return false;
+  SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
+  return true;
+}
+
 bool AMDGPUDAGToDAGISel::SelectVOP3PMadMixMods(SDValue In, SDValue &Src,
                                                SDValue &SrcMods) const {
   unsigned Mods = 0;
index 8c4e378..12912b7 100644 (file)
@@ -248,6 +248,8 @@ private:
   bool SelectVOP3OpSelMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
   bool SelectVOP3PMadMixModsImpl(SDValue In, SDValue &Src,
                                  unsigned &Mods) const;
+  bool SelectVOP3PMadMixModsExt(SDValue In, SDValue &Src,
+                                SDValue &SrcMods) const;
   bool SelectVOP3PMadMixMods(SDValue In, SDValue &Src, SDValue &SrcMods) const;
 
   SDValue getHi16Elt(SDValue In) const;
index 1749070..7d3536d 100644 (file)
@@ -523,60 +523,6 @@ bool AMDGPUInstructionSelector::selectG_EXTRACT(MachineInstr &I) const {
   return true;
 }
 
-bool AMDGPUInstructionSelector::selectG_FMA_FMAD(MachineInstr &I) const {
-  assert(I.getOpcode() == AMDGPU::G_FMA || I.getOpcode() == AMDGPU::G_FMAD);
-
-  // Try to manually select MAD_MIX/FMA_MIX.
-  Register Dst = I.getOperand(0).getReg();
-  LLT ResultTy = MRI->getType(Dst);
-  bool IsFMA = I.getOpcode() == AMDGPU::G_FMA;
-  if (ResultTy != LLT::scalar(32) ||
-      (IsFMA ? !Subtarget->hasFmaMixInsts() : !Subtarget->hasMadMixInsts()))
-    return false;
-
-  // Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand
-  // using the conversion from f16.
-  bool MatchedSrc0, MatchedSrc1, MatchedSrc2;
-  auto [Src0, Src0Mods] =
-      selectVOP3PMadMixModsImpl(I.getOperand(1), MatchedSrc0);
-  auto [Src1, Src1Mods] =
-      selectVOP3PMadMixModsImpl(I.getOperand(2), MatchedSrc1);
-  auto [Src2, Src2Mods] =
-      selectVOP3PMadMixModsImpl(I.getOperand(3), MatchedSrc2);
-
-#ifndef NDEBUG
-  const SIMachineFunctionInfo *MFI =
-      I.getMF()->getInfo<SIMachineFunctionInfo>();
-  SIModeRegisterDefaults Mode = MFI->getMode();
-  assert((IsFMA || !Mode.allFP32Denormals()) &&
-         "fmad selected with denormals enabled");
-#endif
-
-  // TODO: We can select this with f32 denormals enabled if all the sources are
-  // converted from f16 (in which case fmad isn't legal).
-  if (!MatchedSrc0 && !MatchedSrc1 && !MatchedSrc2)
-    return false;
-
-  const unsigned OpC = IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32;
-  MachineInstr *MixInst =
-      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(OpC), Dst)
-          .addImm(Src0Mods)
-          .addReg(copyToVGPRIfSrcFolded(Src0, Src0Mods, I.getOperand(1), &I))
-          .addImm(Src1Mods)
-          .addReg(copyToVGPRIfSrcFolded(Src1, Src1Mods, I.getOperand(2), &I))
-          .addImm(Src2Mods)
-          .addReg(copyToVGPRIfSrcFolded(Src2, Src2Mods, I.getOperand(3), &I))
-          .addImm(0)
-          .addImm(0)
-          .addImm(0);
-
-  if (!constrainSelectedInstRegOperands(*MixInst, TII, TRI, RBI))
-    return false;
-
-  I.eraseFromParent();
-  return true;
-}
-
 bool AMDGPUInstructionSelector::selectG_MERGE_VALUES(MachineInstr &MI) const {
   MachineBasicBlock *BB = MI.getParent();
   Register DstReg = MI.getOperand(0).getReg();
@@ -3405,11 +3351,6 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) {
     return selectG_FABS(I);
   case TargetOpcode::G_EXTRACT:
     return selectG_EXTRACT(I);
-  case TargetOpcode::G_FMA:
-  case TargetOpcode::G_FMAD:
-    if (selectG_FMA_FMAD(I))
-      return true;
-    return selectImpl(I, *CoverageInfo);
   case TargetOpcode::G_MERGE_VALUES:
   case TargetOpcode::G_CONCAT_VECTORS:
     return selectG_MERGE_VALUES(I);
@@ -4988,6 +4929,22 @@ AMDGPUInstructionSelector::selectVOP3PMadMixModsImpl(MachineOperand &Root,
 }
 
 InstructionSelector::ComplexRendererFns
+AMDGPUInstructionSelector::selectVOP3PMadMixModsExt(
+    MachineOperand &Root) const {
+  Register Src;
+  unsigned Mods;
+  bool Matched;
+  std::tie(Src, Mods) = selectVOP3PMadMixModsImpl(Root, Matched);
+  if (!Matched)
+    return {};
+
+  return {{
+      [=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
+      [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
+  }};
+}
+
+InstructionSelector::ComplexRendererFns
 AMDGPUInstructionSelector::selectVOP3PMadMixMods(MachineOperand &Root) const {
   Register Src;
   unsigned Mods;
index 99af9dd..0ccf02b 100644 (file)
@@ -297,6 +297,7 @@ private:
 
   std::pair<Register, unsigned> selectVOP3PMadMixModsImpl(MachineOperand &Root,
                                                           bool &Matched) const;
+  ComplexRendererFns selectVOP3PMadMixModsExt(MachineOperand &Root) const;
   ComplexRendererFns selectVOP3PMadMixMods(MachineOperand &Root) const;
 
   void renderTruncImm32(MachineInstrBuilder &MIB, const MachineInstr &MI,
index 8253641..e0fea7d 100644 (file)
@@ -1511,7 +1511,8 @@ def VOP3OpSel  : ComplexPattern<untyped, 2, "SelectVOP3OpSel">;
 
 def VOP3OpSelMods  : ComplexPattern<untyped, 2, "SelectVOP3OpSelMods">;
 
-def VOP3PMadMixMods  : ComplexPattern<untyped, 2, "SelectVOP3PMadMixMods">;
+def VOP3PMadMixModsExt : ComplexPattern<untyped, 2, "SelectVOP3PMadMixModsExt">;
+def VOP3PMadMixMods : ComplexPattern<untyped, 2, "SelectVOP3PMadMixMods">;
 
 def VINTERPMods  : ComplexPattern<untyped, 2, "SelectVINTERPMods">;
 def VINTERPModsHi  : ComplexPattern<untyped, 2, "SelectVINTERPModsHi">;
index 2c7888e..8f8c448 100644 (file)
@@ -142,9 +142,34 @@ def : VOP3PSatPat<usubsat, V_PK_SUB_U16>;
 def : VOP3PSatPat<ssubsat, V_PK_SUB_I16>;
 } // End SubtargetPredicate = HasVOP3PInsts
 
+// TODO: Make sure we're doing the right thing with denormals. Note
+// that FMA and MAD will differ.
 multiclass MadFmaMixPats<SDPatternOperator fma_like,
+                         Instruction mix_inst,
                          Instruction mixlo_inst,
                          Instruction mixhi_inst> {
+  // At least one of the operands needs to be an fpextend of an f16
+  // for this to be worthwhile, so we need three patterns here.
+  // TODO: Could we use a predicate to inspect src1/2/3 instead?
+  def : GCNPat <
+    (f32 (fma_like (f32 (VOP3PMadMixModsExt f16:$src0, i32:$src0_mods)),
+                   (f32 (VOP3PMadMixMods f16:$src1, i32:$src1_mods)),
+                   (f32 (VOP3PMadMixMods f16:$src2, i32:$src2_mods)))),
+    (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+              DSTCLAMP.NONE)>;
+  def : GCNPat <
+    (f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)),
+                   (f32 (VOP3PMadMixModsExt f16:$src1, i32:$src1_mods)),
+                   (f32 (VOP3PMadMixMods f32:$src2, i32:$src2_mods)))),
+    (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+              DSTCLAMP.NONE)>;
+  def : GCNPat <
+    (f32 (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_mods)),
+                   (f32 (VOP3PMadMixMods f32:$src1, i32:$src1_mods)),
+                   (f32 (VOP3PMadMixModsExt f16:$src2, i32:$src2_mods)))),
+    (mix_inst $src0_mods, $src0, $src1_mods, $src1, $src2_mods, $src2,
+              DSTCLAMP.NONE)>;
+
   def : GCNPat <
     (f16 (fpround (fma_like (f32 (VOP3PMadMixMods f16:$src0, i32:$src0_modifiers)),
                             (f32 (VOP3PMadMixMods f16:$src1, i32:$src1_modifiers)),
@@ -222,7 +247,7 @@ defm V_MAD_MIXHI_F16 : VOP3_VOP3PInst<"v_mad_mixhi_f16", VOP3P_Mix_Profile<VOP_F
 } // End FPDPRounding = 1
 }
 
-defm : MadFmaMixPats<fmad, V_MAD_MIXLO_F16, V_MAD_MIXHI_F16>;
+defm : MadFmaMixPats<fmad, V_MAD_MIX_F32, V_MAD_MIXLO_F16, V_MAD_MIXHI_F16>;
 } // End SubtargetPredicate = HasMadMixInsts
 
 
@@ -243,7 +268,7 @@ defm V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3P_Mix_Profile<VOP_F
 } // End FPDPRounding = 1
 }
 
-defm : MadFmaMixPats<fma, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
+defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
 }
 
 // Defines patterns that extract signed 4bit from each Idx[0].