Improve type safety around float constants
authorSanjoy Das <sanjoy@google.com>
Wed, 14 Feb 2018 00:55:54 +0000 (16:55 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 00:59:28 +0000 (16:59 -0800)
Instead of passing floating point constants to the vector support library as
compiler-side floats, pass them as APFloats instead.  This reduces the duration
during which these constants are semantically represented as floats on the host
side and are subject to fast-math-like behavior.  This is especially important
in cases where the exact bit representation of the floating point constant is
significant, but also makes progress towards ensuring that e.g. build XLA with
-ffast-math does not change the IR we generate.

PiperOrigin-RevId: 185611301

tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
tensorflow/compiler/xla/service/cpu/vector_support_library.cc
tensorflow/compiler/xla/service/cpu/vector_support_library.h

index ee213e05d1576317430656409cc4cf5f6e38fea4..2e5cc96098241415b82f225afc81981f3e1069e0 100644 (file)
@@ -63,7 +63,8 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
   CHECK_EQ(input->getType(), vsl.vector_type());
 
   // This implements the same rational interpolant as implemented in Eigen3.
-  llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/-9.0, /*high=*/9.0);
+  llvm::Value* input_clamped =
+      vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0));
 
   std::array<float, 7> numerator_coeffs{
       -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
@@ -75,16 +76,18 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
       4.89352518554385e-03f};
 
   llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped);
-  llvm::Value* numerator = vsl.SplatFloat(numerator_coeffs[0]);
+  llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0]));
   for (int i = 1; i < numerator_coeffs.size(); i++) {
-    numerator = vsl.MulAdd(input_squared, numerator, numerator_coeffs[i]);
+    numerator =
+        vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i]));
   }
 
   numerator = vsl.Mul(input_clamped, numerator);
 
-  llvm::Value* denominator = vsl.SplatFloat(denominator_coeffs[0]);
+  llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0]));
   for (int i = 1; i < denominator_coeffs.size(); i++) {
-    denominator = vsl.MulAdd(input_squared, denominator, denominator_coeffs[i]);
+    denominator = vsl.MulAdd(input_squared, denominator,
+                             GetIeeeF32(denominator_coeffs[i]));
   }
 
   llvm::Value* result = vsl.Div(numerator, denominator);
@@ -119,24 +122,27 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
 
   // This implements the same polynomial approximation as implemented in Eigen3.
 
-  const float exp_hi = 88.3762626647950;
-  const float exp_lo = -88.3762626647949;
+  const llvm::APFloat half = GetIeeeF32(0.5);
+  const llvm::APFloat one = GetIeeeF32(1.0);
+
+  const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950);
+  const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949);
 
-  const float cephes_LOG2EF = 1.44269504088896341;
-  const float cephes_exp_C1 = 0.693359375;
-  const float cephes_exp_C2 = -2.12194440e-4;
+  const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);
+  const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
+  const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4);
 
-  const float cephes_exp_p0 = 1.9875691500E-4;
-  const float cephes_exp_p1 = 1.3981999507E-3;
-  const float cephes_exp_p2 = 8.3334519073E-3;
-  const float cephes_exp_p3 = 4.1665795894E-2;
-  const float cephes_exp_p4 = 1.6666665459E-1;
-  const float cephes_exp_p5 = 5.0000001201E-1;
+  const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4);
+  const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3);
+  const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3);
+  const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2);
+  const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
+  const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);
 
   llvm::Value* input = &*vector_exp_function->arg_begin();
   llvm::Value* input_clamped =
       vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
-  llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, 0.5));
+  llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
   llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx);
   llvm::Value* z = vsl.Mul(cephes_exp_C2, fx);
   llvm::Value* x = vsl.Sub(input_clamped, tmp);
@@ -149,7 +155,7 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
   y = vsl.MulAdd(y, x, cephes_exp_p4);
   y = vsl.MulAdd(y, x, cephes_exp_p5);
   y = vsl.MulAdd(y, z, x);
-  y = vsl.Add(1.0f, y);
+  y = vsl.Add(one, y);
 
   // VectorSupportLibrary (intentionally) can't juggle more than one type at a
   // time so drop down to IRBuilder for this bit.
@@ -198,32 +204,28 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
   llvm::Value* input = &*vector_log_function->arg_begin();
   VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32");
 
-  const float half = 0.5;
+  const llvm::APFloat half = GetIeeeF32(0.5);
+  const llvm::APFloat one = GetIeeeF32(1.0);
 
   // This implements the same polynomial approximation as implemented in Eigen3.
   // Returns NaN for x < 0, -INF for x = 0
-  const float cephes_SQRTHF = 0.707106781186547524;
-  const float cephes_log_p0 = 7.0376836292E-2;
-  const float cephes_log_p1 = -1.1514610310E-1;
-  const float cephes_log_p2 = 1.1676998740E-1;
-  const float cephes_log_p3 = -1.2420140846E-1;
-  const float cephes_log_p4 = +1.4249322787E-1;
-  const float cephes_log_p5 = -1.6668057665E-1;
-  const float cephes_log_p6 = +2.0000714765E-1;
-  const float cephes_log_p7 = -2.4999993993E-1;
-  const float cephes_log_p8 = +3.3333331174E-1;
-  const float cephes_log_q1 = -2.12194440e-4;
-  const float cephes_log_q2 = 0.693359375;
+  const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524);
+  const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2);
+  const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1);
+  const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1);
+  const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1);
+  const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1);
+  const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1);
+  const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1);
+  const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1);
+  const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1);
+  const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4);
+  const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375);
 
   // The smallest non denormalized float number.
-  const float min_norm_pos = tensorflow::bit_cast<float, int32>(0x00800000);
-  const float minus_inf = tensorflow::bit_cast<float, int32>(0xff800000);
-
-  // NB! This number is denormal and since TF sets the denormals-are-zero flag
-  // (and if TF didn't, -ffast-math would) trying to operate on this float using
-  // C++ operations (including, for instance, implicit conversion to double)
-  // will coerce this to zero.
-  const float inv_mant_mask = tensorflow::bit_cast<float, int32>(~0x7f800000);
+  const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
+  const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
+  const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);
 
   // invalid_mask is set if x is negative or NaN (and therefore output
   // must be NaN).
@@ -251,7 +253,7 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
 
   emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f);
   llvm::Value* e =
-      vsl.Add(1.0f, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
+      vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
 
   // part2:
   //   if( x < SQRTHF ) {
@@ -260,8 +262,8 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
   //   } else { x = x - 1.0; }
   llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF);
   llvm::Value* tmp = vsl.FloatAnd(input, mask);
-  input = vsl.Sub(input, 1.0);
-  e = vsl.Sub(e, vsl.FloatAnd(mask, 1.0));
+  input = vsl.Sub(input, one);
+  e = vsl.Sub(e, vsl.FloatAnd(mask, one));
   input = vsl.Add(input, tmp);
 
   llvm::Value* x2 = vsl.Mul(input, input);
index 0596e80df48ef2e926e0de1626fb7e4f22c50bab..150db1cb6edec1af6724a8bca6a5f6272f1a7416 100644 (file)
@@ -103,11 +103,12 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
   }
 }
 
-llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, float low,
-                                         float high) {
+llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
+                                         const llvm::APFloat& low,
+                                         const llvm::APFloat& high) {
   AssertCorrectTypes({a});
   llvm::Type* type = a->getType();
-  CHECK_LT(low, high);
+  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
   CHECK(scalar_type_->isFloatingPointTy());
   return llvm_ir::EmitFloatMin(
       llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_),
index 010c82f0cf67eafe5ef7b114de97aaaf1546d4f5..6479bf76aab581ae3ec2923d98dab53720cab203 100644 (file)
@@ -26,6 +26,16 @@ limitations under the License.
 
 namespace xla {
 namespace cpu {
+
+// Simple wrappers around llvm::APFloat::APFloat to make the calling code more
+// obvious.
+
+inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
+inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
+  return llvm::APFloat(llvm::APFloat::IEEEsingle(),
+                       llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
+}
+
 // A thin wrapper around llvm_util.h to make code generating vector math flow
 // more readable.
 class VectorSupportLibrary {
@@ -41,24 +51,34 @@ class VectorSupportLibrary {
   llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
     return Mul(ir_builder()->getInt64(lhs), rhs);
   }
-  llvm::Value* Mul(float lhs, llvm::Value* rhs) {
+  llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
     return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
   }
 
+  // If your call resolved to these then you probably wanted the versions taking
+  // APFloat.
+  llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
+  llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;
+
   llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
   llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
     return Add(ir_builder()->getInt64(lhs), rhs);
   }
-  llvm::Value* Add(float lhs, llvm::Value* rhs) {
+  llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
     return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
   }
 
+  // If your call resolved to these then you probably wanted the versions taking
+  // APFloat.
+  llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
+  llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;
+
   llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
-  llvm::Value* Sub(llvm::Value* lhs, float rhs) {
+  llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
     return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
   }
   llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
-  llvm::Value* Max(float lhs, llvm::Value* rhs) {
+  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
     return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
   }
   llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
@@ -67,19 +87,21 @@ class VectorSupportLibrary {
     return Add(c, Mul(a, b));
   }
 
-  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, float c) {
+  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
     return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
   }
 
-  llvm::Value* MulAdd(llvm::Value* a, float b, float c) {
+  llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
+                      const llvm::APFloat& c) {
     return Add(GetConstantFloat(a->getType(), c),
                Mul(a, GetConstantFloat(a->getType(), b)));
   }
 
   llvm::Value* Floor(llvm::Value* a);
 
-  llvm::Value* Clamp(llvm::Value* a, float low, float high);
-  llvm::Value* SplatFloat(float d) {
+  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
+                     const llvm::APFloat& high);
+  llvm::Value* SplatFloat(const llvm::APFloat& d) {
     return GetConstantFloat(vector_type(), d);
   }
 
@@ -93,7 +115,7 @@ class VectorSupportLibrary {
   llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
   llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
   llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
-  llvm::Value* FCmpOLTMask(llvm::Value* lhs, float rhs) {
+  llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
     return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
   }
 
@@ -102,11 +124,11 @@ class VectorSupportLibrary {
   // generating predicates above this type system oddity makes the kernel IR
   // generation code less cluttered.
   llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
-  llvm::Value* FloatAnd(llvm::Value* lhs, float rhs) {
+  llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
     return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
   }
   llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
-  llvm::Value* FloatOr(llvm::Value* lhs, float rhs) {
+  llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
     return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
   }
   llvm::Value* FloatNot(llvm::Value* lhs);
@@ -115,7 +137,7 @@ class VectorSupportLibrary {
   }
 
   llvm::Value* BroadcastScalar(llvm::Value* x);
-  llvm::Value* BroadcastScalar(float d) {
+  llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
     return BroadcastScalar(GetConstantFloat(scalar_type(), d));
   }
 
@@ -238,9 +260,8 @@ class VectorSupportLibrary {
 
   llvm::Type* IntegerTypeForFloatSize(bool vector);
   llvm::Value* I1ToFloat(llvm::Value* i1);
-  llvm::Value* GetConstantFloat(llvm::Type* type, float f) {
-    llvm::Constant* scalar_value =
-        llvm::ConstantFP::get(type->getContext(), llvm::APFloat(f));
+  llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
+    llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
     if (llvm::isa<llvm::VectorType>(type)) {
       return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
     }