[AArch64] Compute the Newton series for reciprocals natively
authorEvandro Menezes <e.menezes@samsung.com>
Mon, 14 Nov 2016 23:29:01 +0000 (23:29 +0000)
committerEvandro Menezes <e.menezes@samsung.com>
Mon, 14 Nov 2016 23:29:01 +0000 (23:29 +0000)
Implement the Newton series for square root, its reciprocal and reciprocal
natively using the specialized instructions in AArch64 to perform each
series iteration.

Differential revision: https://reviews.llvm.org/D26518

llvm-svn: 286907

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64InstrInfo.td
llvm/test/CodeGen/AArch64/recp-fastmath.ll
llvm/test/CodeGen/AArch64/sqrt-fastmath.ll

index 803481e..403021e 100644 (file)
@@ -959,8 +959,10 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
   case AArch64ISD::ST4LANEpost:       return "AArch64ISD::ST4LANEpost";
   case AArch64ISD::SMULL:             return "AArch64ISD::SMULL";
   case AArch64ISD::UMULL:             return "AArch64ISD::UMULL";
-  case AArch64ISD::FRSQRTE:           return "AArch64ISD::FRSQRTE";
   case AArch64ISD::FRECPE:            return "AArch64ISD::FRECPE";
+  case AArch64ISD::FRECPS:            return "AArch64ISD::FRECPS";
+  case AArch64ISD::FRSQRTE:           return "AArch64ISD::FRSQRTE";
+  case AArch64ISD::FRSQRTS:           return "AArch64ISD::FRSQRTS";
   }
   return nullptr;
 }
@@ -4653,7 +4655,34 @@ SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
       (Enabled == ReciprocalEstimate::Unspecified && Subtarget->useRSqrt()))
     if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRSQRTE, Operand,
                                        DAG, ExtraSteps)) {
-      UseOneConst = true;
+      SDLoc DL(Operand);
+      EVT VT = Operand.getValueType();
+
+      SDNodeFlags Flags;
+      Flags.setUnsafeAlgebra(true);
+
+      // Newton reciprocal square root iteration: E * 0.5 * (3 - X * E^2)
+      // AArch64 reciprocal square root iteration instruction: 0.5 * (3 - M * N)
+      for (int i = ExtraSteps; i > 0; --i) {
+        SDValue Step = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Estimate,
+                                   &Flags);
+        Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, &Flags);
+        Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, &Flags);
+      }
+
+      if (!Reciprocal) {
+        EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
+                                      VT);
+        SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+        SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ);
+
+        Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, &Flags);
+        // Correct the result if the operand is 0.0.
+        Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL,
+                               VT, Eq, Operand, Estimate);
+      }
+
+      ExtraSteps = 0;
       return Estimate;
     }
 
@@ -4665,8 +4694,24 @@ SDValue AArch64TargetLowering::getRecipEstimate(SDValue Operand,
                                                 int &ExtraSteps) const {
   if (Enabled == ReciprocalEstimate::Enabled)
     if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRECPE, Operand,
-                                       DAG, ExtraSteps))
+                                       DAG, ExtraSteps)) {
+      SDLoc DL(Operand);
+      EVT VT = Operand.getValueType();
+
+      SDNodeFlags Flags;
+      Flags.setUnsafeAlgebra(true);
+
+      // Newton reciprocal iteration: E * (2 - X * E)
+      // AArch64 reciprocal iteration instruction: (2 - M * N)
+      for (int i = ExtraSteps; i > 0; --i) {
+        SDValue Step = DAG.getNode(AArch64ISD::FRECPS, DL, VT, Operand,
+                                   Estimate, &Flags);
+        Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, &Flags);
+      }
+
+      ExtraSteps = 0;
       return Estimate;
+    }
 
   return SDValue();
 }
index 7b317d6..7867d7c 100644 (file)
@@ -187,9 +187,9 @@ enum NodeType : unsigned {
   SMULL,
   UMULL,
 
-  // Reciprocal estimates.
-  FRECPE,
-  FRSQRTE,
+  // Reciprocal estimates and steps.
+  FRECPE, FRECPS,
+  FRSQRTE, FRSQRTS,
 
   // NEON Load/Store with post-increment base updates
   LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE,
index e8386a6..3bed500 100644 (file)
@@ -287,7 +287,9 @@ def AArch64smull    : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull>;
 def AArch64umull    : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull>;
 
 def AArch64frecpe   : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>;
+def AArch64frecps   : SDNode<"AArch64ISD::FRECPS", SDTFPBinOp>;
 def AArch64frsqrte  : SDNode<"AArch64ISD::FRSQRTE", SDTFPUnaryOp>;
+def AArch64frsqrts  : SDNode<"AArch64ISD::FRSQRTS", SDTFPBinOp>;
 
 def AArch64saddv    : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
 def AArch64uaddv    : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>;
@@ -3422,6 +3424,17 @@ def : Pat<(v1f64 (AArch64frecpe (v1f64 FPR64:$Rn))),
 def : Pat<(v2f64 (AArch64frecpe (v2f64 FPR128:$Rn))),
           (FRECPEv2f64 FPR128:$Rn)>;
 
+def : Pat<(f32 (AArch64frecps (f32 FPR32:$Rn), (f32 FPR32:$Rm))),
+          (FRECPS32 FPR32:$Rn, FPR32:$Rm)>;
+def : Pat<(v2f32 (AArch64frecps (v2f32 V64:$Rn), (v2f32 V64:$Rm))),
+          (FRECPSv2f32 V64:$Rn, V64:$Rm)>;
+def : Pat<(v4f32 (AArch64frecps (v4f32 FPR128:$Rn), (v4f32 FPR128:$Rm))),
+          (FRECPSv4f32 FPR128:$Rn, FPR128:$Rm)>;
+def : Pat<(f64 (AArch64frecps (f64 FPR64:$Rn), (f64 FPR64:$Rm))),
+          (FRECPS64 FPR64:$Rn, FPR64:$Rm)>;
+def : Pat<(v2f64 (AArch64frecps (v2f64 FPR128:$Rn), (v2f64 FPR128:$Rm))),
+          (FRECPSv2f64 FPR128:$Rn, FPR128:$Rm)>;
+
 def : Pat<(f32 (int_aarch64_neon_frecpx (f32 FPR32:$Rn))),
           (FRECPXv1i32 FPR32:$Rn)>;
 def : Pat<(f64 (int_aarch64_neon_frecpx (f64 FPR64:$Rn))),
@@ -3447,6 +3460,17 @@ def : Pat<(v1f64 (AArch64frsqrte (v1f64 FPR64:$Rn))),
 def : Pat<(v2f64 (AArch64frsqrte (v2f64 FPR128:$Rn))),
           (FRSQRTEv2f64 FPR128:$Rn)>;
 
+def : Pat<(f32 (AArch64frsqrts (f32 FPR32:$Rn), (f32 FPR32:$Rm))),
+          (FRSQRTS32 FPR32:$Rn, FPR32:$Rm)>;
+def : Pat<(v2f32 (AArch64frsqrts (v2f32 V64:$Rn), (v2f32 V64:$Rm))),
+          (FRSQRTSv2f32 V64:$Rn, V64:$Rm)>;
+def : Pat<(v4f32 (AArch64frsqrts (v4f32 FPR128:$Rn), (v4f32 FPR128:$Rm))),
+          (FRSQRTSv4f32 FPR128:$Rn, FPR128:$Rm)>;
+def : Pat<(f64 (AArch64frsqrts (f64 FPR64:$Rn), (f64 FPR64:$Rm))),
+          (FRSQRTS64 FPR64:$Rn, FPR64:$Rm)>;
+def : Pat<(v2f64 (AArch64frsqrts (v2f64 FPR128:$Rn), (v2f64 FPR128:$Rm))),
+          (FRSQRTSv2f64 FPR128:$Rn, FPR128:$Rm)>;
+
 // If an integer is about to be converted to a floating point value,
 // just load it on the floating point unit.
 // Here are the patterns for 8 and 16-bits to float.
index 280ef75..38e0fb3 100644 (file)
@@ -16,8 +16,8 @@ define float @frecp1(float %x) #1 {
 
 ; CHECK-LABEL: frecp1:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: frecpe
-; CHECK-NEXT: fmov
+; CHECK-NEXT: frecpe [[R:s[0-7]]]
+; CHECK-NEXT: frecps {{s[0-7](, s[0-7])?}}, [[R]]
 }
 
 define <2 x float> @f2recp0(<2 x float> %x) #0 {
@@ -36,8 +36,8 @@ define <2 x float> @f2recp1(<2 x float> %x) #1 {
 
 ; CHECK-LABEL: f2recp1:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frecpe
+; CHECK-NEXT: frecpe [[R:v[0-7]\.2s]]
+; CHECK-NEXT: frecps {{v[0-7]\.2s(, v[0-7].2s)?}}, [[R]]
 }
 
 define <4 x float> @f4recp0(<4 x float> %x) #0 {
@@ -56,8 +56,8 @@ define <4 x float> @f4recp1(<4 x float> %x) #1 {
 
 ; CHECK-LABEL: f4recp1:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frecpe
+; CHECK-NEXT: frecpe [[R:v[0-7]\.4s]]
+; CHECK-NEXT: frecps {{v[0-7]\.4s(, v[0-7].4s)?}}, [[R]]
 }
 
 define <8 x float> @f8recp0(<8 x float> %x) #0 {
@@ -77,9 +77,10 @@ define <8 x float> @f8recp1(<8 x float> %x) #1 {
 
 ; CHECK-LABEL: f8recp1:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frecpe
-; CHECK: frecpe
+; CHECK-NEXT: frecpe [[RA:v[0-7]\.4s]]
+; CHECK-NEXT: frecpe [[RB:v[0-7]\.4s]]
+; CHECK-NEXT: frecps {{v[0-7]\.4s(, v[0-7].4s)?}}, [[RA]]
+; CHECK: frecps {{v[0-7]\.4s(, v[0-7].4s)?}}, [[RB]]
 }
 
 define double @drecp0(double %x) #0 {
@@ -98,8 +99,8 @@ define double @drecp1(double %x) #1 {
 
 ; CHECK-LABEL: drecp1:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: frecpe
-; CHECK-NEXT: fmov
+; CHECK-NEXT: frecpe [[R:d[0-7]]]
+; CHECK-NEXT: frecps {{d[0-7](, d[0-7])?}}, [[R]]
 }
 
 define <2 x double> @d2recp0(<2 x double> %x) #0 {
@@ -118,8 +119,8 @@ define <2 x double> @d2recp1(<2 x double> %x) #1 {
 
 ; CHECK-LABEL: d2recp1:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frecpe
+; CHECK-NEXT: frecpe [[R:v[0-7]\.2d]]
+; CHECK-NEXT: frecps {{v[0-7]\.2d(, v[0-7].2d)?}}, [[R]]
 }
 
 define <4 x double> @d4recp0(<4 x double> %x) #0 {
@@ -139,9 +140,10 @@ define <4 x double> @d4recp1(<4 x double> %x) #1 {
 
 ; CHECK-LABEL: d4recp1:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frecpe
-; CHECK: frecpe
+; CHECK-NEXT: frecpe [[RA:v[0-7]\.2d]]
+; CHECK-NEXT: frecpe [[RB:v[0-7]\.2d]]
+; CHECK-NEXT: frecps {{v[0-7]\.2d(, v[0-7].2d)?}}, [[RA]]
+; CHECK: frecps {{v[0-7]\.2d(, v[0-7].2d)?}}, [[RB]]
 }
 
 attributes #0 = { nounwind "unsafe-fp-math"="true" }
index cd8200b..079562c 100644 (file)
@@ -19,8 +19,10 @@ define float @fsqrt(float %a) #0 {
 
 ; CHECK-LABEL: fsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:s[0-7]]]
+; CHECK-NEXT: fmul [[RB:s[0-7]]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{s[0-7](, s[0-7])?}}, [[RB]]
+; CHECK: fcmp s0, #0
 }
 
 define <2 x float> @f2sqrt(<2 x float> %a) #0 {
@@ -33,9 +35,10 @@ define <2 x float> @f2sqrt(<2 x float> %a) #0 {
 
 ; CHECK-LABEL: f2sqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: mov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2s]]
+; CHECK-NEXT: fmul [[RB:v[0-7]\.2s]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{v[0-7]\.2s(, v[0-7]\.2s)?}}, [[RB]]
+; CHECK: fcmeq {{v[0-7]\.2s, v0\.2s}}, #0
 }
 
 define <4 x float> @f4sqrt(<4 x float> %a) #0 {
@@ -48,9 +51,10 @@ define <4 x float> @f4sqrt(<4 x float> %a) #0 {
 
 ; CHECK-LABEL: f4sqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: mov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]]
+; CHECK-NEXT: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]]
+; CHECK: fcmeq {{v[0-7]\.4s, v0\.4s}}, #0
 }
 
 define <8 x float> @f8sqrt(<8 x float> %a) #0 {
@@ -64,10 +68,10 @@ define <8 x float> @f8sqrt(<8 x float> %a) #0 {
 
 ; CHECK-LABEL: f8sqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: mov
-; CHECK-NEXT: frsqrte
-; CHECK: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]]
+; CHECK: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]]
+; CHECK: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]]
+; CHECK: fcmeq {{v[0-7]\.4s, v[0-1]\.4s}}, #0
 }
 
 define double @dsqrt(double %a) #0 {
@@ -80,8 +84,10 @@ define double @dsqrt(double %a) #0 {
 
 ; CHECK-LABEL: dsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:d[0-7]]]
+; CHECK-NEXT: fmul [[RB:d[0-7]]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{d[0-7](, d[0-7])?}}, [[RB]]
+; CHECK: fcmp d0, #0
 }
 
 define <2 x double> @d2sqrt(<2 x double> %a) #0 {
@@ -94,9 +100,10 @@ define <2 x double> @d2sqrt(<2 x double> %a) #0 {
 
 ; CHECK-LABEL: d2sqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: mov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]]
+; CHECK-NEXT: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]]
+; CHECK: fcmeq {{v[0-7]\.2d, v0\.2d}}, #0
 }
 
 define <4 x double> @d4sqrt(<4 x double> %a) #0 {
@@ -110,10 +117,10 @@ define <4 x double> @d4sqrt(<4 x double> %a) #0 {
 
 ; CHECK-LABEL: d4sqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: mov
-; CHECK-NEXT: frsqrte
-; CHECK: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]]
+; CHECK: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]]
+; CHECK: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]]
+; CHECK: fcmeq {{v[0-7]\.2d, v[0-1]\.2d}}, #0
 }
 
 define float @frsqrt(float %a) #0 {
@@ -127,8 +134,10 @@ define float @frsqrt(float %a) #0 {
 
 ; CHECK-LABEL: frsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:s[0-7]]]
+; CHECK-NEXT: fmul [[RB:s[0-7]]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{s[0-7](, s[0-7])?}}, [[RB]]
+; CHECK-NOT: fcmp {{s[0-7]}}, #0
 }
 
 define <2 x float> @f2rsqrt(<2 x float> %a) #0 {
@@ -142,8 +151,10 @@ define <2 x float> @f2rsqrt(<2 x float> %a) #0 {
 
 ; CHECK-LABEL: f2rsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2s]]
+; CHECK-NEXT: fmul [[RB:v[0-7]\.2s]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{v[0-7]\.2s(, v[0-7]\.2s)?}}, [[RB]]
+; CHECK-NOT: fcmeq {{v[0-7]\.2s, v0\.2s}}, #0
 }
 
 define <4 x float> @f4rsqrt(<4 x float> %a) #0 {
@@ -157,8 +168,10 @@ define <4 x float> @f4rsqrt(<4 x float> %a) #0 {
 
 ; CHECK-LABEL: f4rsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]]
+; CHECK-NEXT: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]]
+; CHECK-NOT: fcmeq {{v[0-7]\.4s, v0\.4s}}, #0
 }
 
 define <8 x float> @f8rsqrt(<8 x float> %a) #0 {
@@ -173,9 +186,10 @@ define <8 x float> @f8rsqrt(<8 x float> %a) #0 {
 
 ; CHECK-LABEL: f8rsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
-; CHECK: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.4s]]
+; CHECK: fmul [[RB:v[0-7]\.4s]], [[RA]], [[RA]]
+; CHECK: frsqrts {{v[0-7]\.4s(, v[0-7]\.4s)?}}, [[RB]]
+; CHECK-NOT: fcmeq {{v[0-7]\.4s, v0\.4s}}, #0
 }
 
 define double @drsqrt(double %a) #0 {
@@ -189,8 +203,10 @@ define double @drsqrt(double %a) #0 {
 
 ; CHECK-LABEL: drsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:d[0-7]]]
+; CHECK-NEXT: fmul [[RB:d[0-7]]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{d[0-7](, d[0-7])?}}, [[RB]]
+; CHECK-NOT: fcmp d0, #0
 }
 
 define <2 x double> @d2rsqrt(<2 x double> %a) #0 {
@@ -204,8 +220,10 @@ define <2 x double> @d2rsqrt(<2 x double> %a) #0 {
 
 ; CHECK-LABEL: d2rsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]]
+; CHECK-NEXT: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]]
+; CHECK-NEXT: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]]
+; CHECK-NOT: fcmeq {{v[0-7]\.2d, v0\.2d}}, #0
 }
 
 define <4 x double> @d4rsqrt(<4 x double> %a) #0 {
@@ -220,9 +238,10 @@ define <4 x double> @d4rsqrt(<4 x double> %a) #0 {
 
 ; CHECK-LABEL: d4rsqrt:
 ; CHECK-NEXT: BB#0
-; CHECK-NEXT: fmov
-; CHECK-NEXT: frsqrte
-; CHECK: frsqrte
+; CHECK-NEXT: frsqrte [[RA:v[0-7]\.2d]]
+; CHECK: fmul [[RB:v[0-7]\.2d]], [[RA]], [[RA]]
+; CHECK: frsqrts {{v[0-7]\.2d(, v[0-7]\.2d)?}}, [[RB]]
+; CHECK-NOT: fcmeq {{v[0-7]\.2d, v0\.2d}}, #0
 }
 
 attributes #0 = { nounwind "unsafe-fp-math"="true" }