[DAGCombine] Add hook to allow target specific test for sqrt input
authorQingShan Zhang <qshanz@cn.ibm.com>
Wed, 25 Nov 2020 05:37:15 +0000 (05:37 +0000)
committerQingShan Zhang <qshanz@cn.ibm.com>
Wed, 25 Nov 2020 05:37:15 +0000 (05:37 +0000)
PowerPC has instruction ftsqrt/xstsqrtdp etc to do the input test for software square root.
LLVM now tests it with smallest normalized value using abs + setcc. We should add hook to
target that has test instructions.

Reviewed By: Spatel, Chen Zheng, Qiu Chao Fang

Differential Revision: https://reviews.llvm.org/D80706

llvm/include/llvm/CodeGen/TargetLowering.h
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/Target/PowerPC/PPCISelLowering.cpp
llvm/lib/Target/PowerPC/PPCISelLowering.h
llvm/lib/Target/PowerPC/PPCInstrFormats.td
llvm/lib/Target/PowerPC/PPCInstrInfo.td
llvm/lib/Target/PowerPC/PPCInstrVSX.td
llvm/test/CodeGen/PowerPC/fma-mutate.ll
llvm/test/CodeGen/PowerPC/recipest.ll

index 164cbd710713206d2e191f0826e8e3cbc197fb56..16580a9160b9a1e7b00a28ddffb7f27328bf673a 100644 (file)
@@ -4277,6 +4277,15 @@ public:
     return SDValue();
   }
 
+  /// Return a target-dependent comparison result if the input operand is
+  /// suitable for use with a square root estimate calculation. For example, the
+  /// comparison may check if the operand is NAN, INF, zero, normal, etc. The
+  /// result should be used as the condition operand for a select or branch.
+  virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
+                                   const DenormalMode &Mode) const {
+    return SDValue();
+  }
+
   //===--------------------------------------------------------------------===//
   // Legalization utility functions
   //
index cae602d166d15e4c1e682b1384c3eec73bb586bd..4ac1743d2d3428f4ecb15a07d8ca480c10e8a42a 100644 (file)
@@ -22056,26 +22056,31 @@ SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags Flags,
         // possibly a denormal. Force the answer to 0.0 for those cases.
         SDLoc DL(Op);
         EVT CCVT = getSetCCResultType(VT);
-        ISD::NodeType SelOpcode = VT.isVector() ? ISD::VSELECT : ISD::SELECT;
+        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
         DenormalMode DenormMode = DAG.getDenormalMode(VT);
-        if (DenormMode.Input == DenormalMode::IEEE) {
-          // This is specifically a check for the handling of denormal inputs,
-          // not the result.
-
-          // fabs(X) < SmallestNormal ? 0.0 : Est
-          const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
-          APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
-          SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
-          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
-          SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
-          SDValue IsDenorm = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
-          Est = DAG.getNode(SelOpcode, DL, VT, IsDenorm, FPZero, Est);
-        } else {
-          // X == 0.0 ? 0.0 : Est
-          SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
-          SDValue IsZero = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
-          Est = DAG.getNode(SelOpcode, DL, VT, IsZero, FPZero, Est);
+        // Try the target specific test first.
+        SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode);
+        if (!Test) {
+          // If no test provided by target, testing it with denormal inputs to
+          // avoid wrong estimate.
+          if (DenormMode.Input == DenormalMode::IEEE) {
+            // This is specifically a check for the handling of denormal inputs,
+            // not the result.
+
+            // Test = fabs(X) < SmallestNormal
+            const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT);
+            APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem);
+            SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT);
+            SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op);
+            Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT);
+          } else
+            // Test = X == 0.0
+            Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
         }
+        // Test ? 0.0 : Est
+        Est = DAG.getNode(Test.getValueType().isVector() ? ISD::VSELECT
+                                                         : ISD::SELECT,
+                          DL, VT, Test, FPZero, Est);
       }
     }
     return Est;
index 10aecf97fcdf1a84491ba8dcc833de0603a6fd4d..d19fbd477d77e08cfd7f117cce699f5366126937 100644 (file)
@@ -1447,6 +1447,8 @@ const char *PPCTargetLowering::getTargetNodeName(unsigned Opcode) const {
                                 return "PPCISD::FP_TO_SINT_IN_VSR";
   case PPCISD::FRE:             return "PPCISD::FRE";
   case PPCISD::FRSQRTE:         return "PPCISD::FRSQRTE";
+  case PPCISD::FTSQRT:
+    return "PPCISD::FTSQRT";
   case PPCISD::STFIWX:          return "PPCISD::STFIWX";
   case PPCISD::VPERM:           return "PPCISD::VPERM";
   case PPCISD::XXSPLT:          return "PPCISD::XXSPLT";
@@ -12758,6 +12760,33 @@ static int getEstimateRefinementSteps(EVT VT, const PPCSubtarget &Subtarget) {
   return RefinementSteps;
 }
 
+SDValue PPCTargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
+                                            const DenormalMode &Mode) const {
+  // TODO - add support for v2f64/v4f32
+  EVT VT = Op.getValueType();
+  if (VT != MVT::f64)
+    return SDValue();
+
+  SDLoc DL(Op);
+  // The output register of FTSQRT is CR field.
+  SDValue FTSQRT = DAG.getNode(PPCISD::FTSQRT, DL, MVT::i32, Op);
+  // ftsqrt BF,FRB
+  // Let e_b be the unbiased exponent of the double-precision
+  // floating-point operand in register FRB.
+  // fe_flag is set to 1 if either of the following conditions occurs.
+  //   - The double-precision floating-point operand in register FRB is a zero,
+  //     a NaN, or an infinity, or a negative value.
+  //   - e_b is less than or equal to -970.
+  // Otherwise fe_flag is set to 0.
+  // Both VSX and non-VSX versions would set EQ bit in the CR if the number is
+  // not eligible for iteration. (zero/negative/infinity/nan or unbiased
+  // exponent is less than -970)
+  SDValue SRIdxVal = DAG.getTargetConstant(PPC::sub_eq, DL, MVT::i32);
+  return SDValue(DAG.getMachineNode(TargetOpcode::EXTRACT_SUBREG, DL, MVT::i1,
+                                    FTSQRT, SRIdxVal),
+                 0);
+}
+
 SDValue PPCTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
                                            int Enabled, int &RefinementSteps,
                                            bool &UseOneConstNR,
index 414a355264f834c8538c4b1853d72d14b966fd04..6c4899fae22cbf79c03445b724e25b659b03058e 100644 (file)
@@ -89,6 +89,9 @@ namespace llvm {
     FRE,
     FRSQRTE,
 
+    /// Test instruction for software square root.
+    FTSQRT,
+
     /// VPERM - The PPC VPERM Instruction.
     ///
     VPERM,
@@ -1283,6 +1286,8 @@ namespace llvm {
                             bool Reciprocal) const override;
     SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
                              int &RefinementSteps) const override;
+    SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG,
+                             const DenormalMode &Mode) const override;
     unsigned combineRepeatedFPDivisors() const override;
 
     SDValue
index 5ff5fc78326ba8bd8ccc1d7dfecf02aca02982c7..646efe64a22c7cc53241e5c9342e179845eede6b 100644 (file)
@@ -637,9 +637,10 @@ class XForm_17<bits<6> opcode, bits<10> xo, dag OOL, dag IOL, string asmstr,
 }
 
 class XForm_17a<bits<6> opcode, bits<10> xo, dag OOL, dag IOL, string asmstr,
-               InstrItinClass itin>
+               InstrItinClass itin, list<dag> pattern>
   : XForm_17<opcode, xo, OOL, IOL, asmstr, itin > {
   let FRA = 0;
+  let Pattern = pattern;
 }
 
 class XForm_18<bits<6> opcode, bits<10> xo, dag OOL, dag IOL, string asmstr,
index 2e77d04d4a79e202ec0dcc037dac0d3b56fd4dbb..de9ae99adac731a0e64eb7cc134bfec531a8dcd4 100644 (file)
@@ -74,6 +74,9 @@ def SDT_PPCcondbr : SDTypeProfile<0, 3, [
   SDTCisVT<0, i32>, SDTCisVT<2, OtherVT>
 ]>;
 
+def SDT_PPCFtsqrt : SDTypeProfile<1, 1, [
+  SDTCisVT<0, i32>]>;
+
 def SDT_PPClbrx : SDTypeProfile<1, 2, [
   SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisVT<2, OtherVT>
 ]>;
@@ -124,6 +127,7 @@ def SDT_PPCFPMinMax : SDTypeProfile<1, 2, [
 
 def PPCfre    : SDNode<"PPCISD::FRE",     SDTFPUnaryOp, []>;
 def PPCfrsqrte: SDNode<"PPCISD::FRSQRTE", SDTFPUnaryOp, []>;
+def PPCftsqrt : SDNode<"PPCISD::FTSQRT",  SDT_PPCFtsqrt,[]>;
 
 def PPCfcfid  : SDNode<"PPCISD::FCFID",   SDTFPUnaryOp, []>;
 def PPCfcfidu : SDNode<"PPCISD::FCFIDU",  SDTFPUnaryOp, []>;
@@ -2643,7 +2647,8 @@ let isCompare = 1, mayRaiseFPException = 1, hasSideEffects = 0 in {
 def FTDIV: XForm_17<63, 128, (outs crrc:$crD), (ins f8rc:$fA, f8rc:$fB),
                       "ftdiv $crD, $fA, $fB", IIC_FPCompare>;
 def FTSQRT: XForm_17a<63, 160, (outs crrc:$crD), (ins f8rc:$fB),
-                      "ftsqrt $crD, $fB", IIC_FPCompare>;
+                      "ftsqrt $crD, $fB", IIC_FPCompare,
+                      [(set i32:$crD, (PPCftsqrt f64:$fB))]>;
 
 let mayRaiseFPException = 1, hasSideEffects = 0 in {
   let Interpretation64Bit = 1, isCodeGenOnly = 1 in
index 1ffbd405d87aa2c80dd1dadb467877227938d2ae..b023c0596063975db48d42b0721157b8cf91bbd7 100644 (file)
@@ -629,7 +629,8 @@ let hasSideEffects = 0 in {
                          "xstdivdp $crD, $XA, $XB", IIC_FPCompare, []>;
   def XSTSQRTDP : XX2Form_1<60, 106,
                           (outs crrc:$crD), (ins vsfrc:$XB),
-                          "xstsqrtdp $crD, $XB", IIC_FPCompare, []>;
+                          "xstsqrtdp $crD, $XB", IIC_FPCompare,
+                          [(set i32:$crD, (PPCftsqrt f64:$XB))]>;
   def XVTDIVDP : XX3Form_1<60, 125,
                          (outs crrc:$crD), (ins vsrc:$XA, vsrc:$XB),
                          "xvtdivdp $crD, $XA, $XB", IIC_FPCompare, []>;
index a1e3473edf222a68bac15120663562ef28fdfacd..62cce7362c682ebed1fb31ae956786b6bf43cf1b 100644 (file)
@@ -9,12 +9,9 @@ declare double @llvm.sqrt.f64(double)
 define double @foo3_fmf(double %a) nounwind {
 ; CHECK-LABEL: foo3_fmf:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    xsabsdp 0, 1
-; CHECK-NEXT:    addis 3, 2, .LCPI0_2@toc@ha
-; CHECK-NEXT:    lfd 2, .LCPI0_2@toc@l(3)
-; CHECK-NEXT:    xscmpudp 0, 0, 2
+; CHECK-NEXT:    xstsqrtdp 0, 1
 ; CHECK-NEXT:    xxlxor 0, 0, 0
-; CHECK-NEXT:    blt 0, .LBB0_2
+; CHECK-NEXT:    bc 12, 2, .LBB0_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    xsrsqrtedp 0, 1
 ; CHECK-NEXT:    addis 3, 2, .LCPI0_0@toc@ha
index e3894bcd23f5a3a4ee9b4d3914a7ee4a10af9237..cd8520b35ffad555aeb4ce33dc5b1fb29a4f1550 100644 (file)
@@ -749,11 +749,8 @@ define <4 x float> @hoo2_safe(<4 x float> %a, <4 x float> %b) nounwind {
 define double @foo3_fmf(double %a) nounwind {
 ; CHECK-P7-LABEL: foo3_fmf:
 ; CHECK-P7:       # %bb.0:
-; CHECK-P7-NEXT:    fabs 0, 1
-; CHECK-P7-NEXT:    addis 3, 2, .LCPI20_2@toc@ha
-; CHECK-P7-NEXT:    lfd 2, .LCPI20_2@toc@l(3)
-; CHECK-P7-NEXT:    fcmpu 0, 0, 2
-; CHECK-P7-NEXT:    blt 0, .LBB20_2
+; CHECK-P7-NEXT:    ftsqrt 0, 1
+; CHECK-P7-NEXT:    bc 12, 2, .LBB20_2
 ; CHECK-P7-NEXT:  # %bb.1:
 ; CHECK-P7-NEXT:    frsqrte 0, 1
 ; CHECK-P7-NEXT:    addis 3, 2, .LCPI20_0@toc@ha
@@ -770,18 +767,15 @@ define double @foo3_fmf(double %a) nounwind {
 ; CHECK-P7-NEXT:    fmul 1, 1, 0
 ; CHECK-P7-NEXT:    blr
 ; CHECK-P7-NEXT:  .LBB20_2:
-; CHECK-P7-NEXT:    addis 3, 2, .LCPI20_3@toc@ha
-; CHECK-P7-NEXT:    lfs 1, .LCPI20_3@toc@l(3)
+; CHECK-P7-NEXT:    addis 3, 2, .LCPI20_2@toc@ha
+; CHECK-P7-NEXT:    lfs 1, .LCPI20_2@toc@l(3)
 ; CHECK-P7-NEXT:    blr
 ;
 ; CHECK-P8-LABEL: foo3_fmf:
 ; CHECK-P8:       # %bb.0:
-; CHECK-P8-NEXT:    xsabsdp 0, 1
-; CHECK-P8-NEXT:    addis 3, 2, .LCPI20_2@toc@ha
-; CHECK-P8-NEXT:    lfd 2, .LCPI20_2@toc@l(3)
-; CHECK-P8-NEXT:    xscmpudp 0, 0, 2
+; CHECK-P8-NEXT:    xstsqrtdp 0, 1
 ; CHECK-P8-NEXT:    xxlxor 0, 0, 0
-; CHECK-P8-NEXT:    blt 0, .LBB20_2
+; CHECK-P8-NEXT:    bc 12, 2, .LBB20_2
 ; CHECK-P8-NEXT:  # %bb.1:
 ; CHECK-P8-NEXT:    xsrsqrtedp 0, 1
 ; CHECK-P8-NEXT:    addis 3, 2, .LCPI20_0@toc@ha
@@ -803,12 +797,9 @@ define double @foo3_fmf(double %a) nounwind {
 ;
 ; CHECK-P9-LABEL: foo3_fmf:
 ; CHECK-P9:       # %bb.0:
-; CHECK-P9-NEXT:    addis 3, 2, .LCPI20_2@toc@ha
-; CHECK-P9-NEXT:    xsabsdp 0, 1
-; CHECK-P9-NEXT:    lfd 2, .LCPI20_2@toc@l(3)
-; CHECK-P9-NEXT:    xscmpudp 0, 0, 2
+; CHECK-P9-NEXT:    xstsqrtdp 0, 1
 ; CHECK-P9-NEXT:    xxlxor 0, 0, 0
-; CHECK-P9-NEXT:    blt 0, .LBB20_2
+; CHECK-P9-NEXT:    bc 12, 2, .LBB20_2
 ; CHECK-P9-NEXT:  # %bb.1:
 ; CHECK-P9-NEXT:    xsrsqrtedp 0, 1
 ; CHECK-P9-NEXT:    addis 3, 2, .LCPI20_0@toc@ha
@@ -1038,18 +1029,18 @@ define <2 x double> @hoo4_fmf(<2 x double> %a) #1 {
 ; CHECK-P7-LABEL: hoo4_fmf:
 ; CHECK-P7:       # %bb.0:
 ; CHECK-P7-NEXT:    addis 3, 2, .LCPI26_2@toc@ha
+; CHECK-P7-NEXT:    ftsqrt 0, 1
 ; CHECK-P7-NEXT:    fmr 3, 1
-; CHECK-P7-NEXT:    addis 4, 2, .LCPI26_1@toc@ha
+; CHECK-P7-NEXT:    addis 4, 2, .LCPI26_0@toc@ha
 ; CHECK-P7-NEXT:    lfs 0, .LCPI26_2@toc@l(3)
-; CHECK-P7-NEXT:    addis 3, 2, .LCPI26_0@toc@ha
-; CHECK-P7-NEXT:    lfs 4, .LCPI26_1@toc@l(4)
-; CHECK-P7-NEXT:    lfs 5, .LCPI26_0@toc@l(3)
-; CHECK-P7-NEXT:    fcmpu 0, 1, 0
+; CHECK-P7-NEXT:    addis 3, 2, .LCPI26_1@toc@ha
+; CHECK-P7-NEXT:    lfs 5, .LCPI26_0@toc@l(4)
+; CHECK-P7-NEXT:    lfs 4, .LCPI26_1@toc@l(3)
 ; CHECK-P7-NEXT:    fmr 1, 0
-; CHECK-P7-NEXT:    bne 0, .LBB26_3
+; CHECK-P7-NEXT:    bc 4, 2, .LBB26_3
 ; CHECK-P7-NEXT:  # %bb.1:
-; CHECK-P7-NEXT:    fcmpu 0, 2, 0
-; CHECK-P7-NEXT:    bne 0, .LBB26_4
+; CHECK-P7-NEXT:    ftsqrt 0, 2
+; CHECK-P7-NEXT:    bc 4, 2, .LBB26_4
 ; CHECK-P7-NEXT:  .LBB26_2:
 ; CHECK-P7-NEXT:    fmr 2, 0
 ; CHECK-P7-NEXT:    blr
@@ -1063,8 +1054,8 @@ define <2 x double> @hoo4_fmf(<2 x double> %a) #1 {
 ; CHECK-P7-NEXT:    fmadd 1, 3, 1, 5
 ; CHECK-P7-NEXT:    fmul 3, 3, 4
 ; CHECK-P7-NEXT:    fmul 1, 3, 1
-; CHECK-P7-NEXT:    fcmpu 0, 2, 0
-; CHECK-P7-NEXT:    beq 0, .LBB26_2
+; CHECK-P7-NEXT:    ftsqrt 0, 2
+; CHECK-P7-NEXT:    bc 12, 2, .LBB26_2
 ; CHECK-P7-NEXT:  .LBB26_4:
 ; CHECK-P7-NEXT:    frsqrte 0, 2
 ; CHECK-P7-NEXT:    fmul 3, 2, 0