CHECK_EQ(input->getType(), vsl.vector_type());
// This implements the same rational interpolant as implemented in Eigen3.
- llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/-9.0, /*high=*/9.0);
+ llvm::Value* input_clamped =
+ vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0));
std::array<float, 7> numerator_coeffs{
-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
4.89352518554385e-03f};
llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped);
- llvm::Value* numerator = vsl.SplatFloat(numerator_coeffs[0]);
+ llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0]));
for (int i = 1; i < numerator_coeffs.size(); i++) {
- numerator = vsl.MulAdd(input_squared, numerator, numerator_coeffs[i]);
+ numerator =
+ vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i]));
}
numerator = vsl.Mul(input_clamped, numerator);
- llvm::Value* denominator = vsl.SplatFloat(denominator_coeffs[0]);
+ llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0]));
for (int i = 1; i < denominator_coeffs.size(); i++) {
- denominator = vsl.MulAdd(input_squared, denominator, denominator_coeffs[i]);
+ denominator = vsl.MulAdd(input_squared, denominator,
+ GetIeeeF32(denominator_coeffs[i]));
}
llvm::Value* result = vsl.Div(numerator, denominator);
// This implements the same polynomial approximation as implemented in Eigen3.
- const float exp_hi = 88.3762626647950;
- const float exp_lo = -88.3762626647949;
+ const llvm::APFloat half = GetIeeeF32(0.5);
+ const llvm::APFloat one = GetIeeeF32(1.0);
+
+ const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950);
+ const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949);
- const float cephes_LOG2EF = 1.44269504088896341;
- const float cephes_exp_C1 = 0.693359375;
- const float cephes_exp_C2 = -2.12194440e-4;
+ const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341);
+ const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375);
+ const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4);
- const float cephes_exp_p0 = 1.9875691500E-4;
- const float cephes_exp_p1 = 1.3981999507E-3;
- const float cephes_exp_p2 = 8.3334519073E-3;
- const float cephes_exp_p3 = 4.1665795894E-2;
- const float cephes_exp_p4 = 1.6666665459E-1;
- const float cephes_exp_p5 = 5.0000001201E-1;
+ const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4);
+ const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3);
+ const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3);
+ const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2);
+ const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
+ const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);
llvm::Value* input = &*vector_exp_function->arg_begin();
llvm::Value* input_clamped =
vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
- llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, 0.5));
+ llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx);
llvm::Value* z = vsl.Mul(cephes_exp_C2, fx);
llvm::Value* x = vsl.Sub(input_clamped, tmp);
y = vsl.MulAdd(y, x, cephes_exp_p4);
y = vsl.MulAdd(y, x, cephes_exp_p5);
y = vsl.MulAdd(y, z, x);
- y = vsl.Add(1.0f, y);
+ y = vsl.Add(one, y);
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
llvm::Value* input = &*vector_log_function->arg_begin();
VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32");
- const float half = 0.5;
+ const llvm::APFloat half = GetIeeeF32(0.5);
+ const llvm::APFloat one = GetIeeeF32(1.0);
// This implements the same polynomial approximation as implemented in Eigen3.
// Returns NaN for x < 0, -INF for x = 0
- const float cephes_SQRTHF = 0.707106781186547524;
- const float cephes_log_p0 = 7.0376836292E-2;
- const float cephes_log_p1 = -1.1514610310E-1;
- const float cephes_log_p2 = 1.1676998740E-1;
- const float cephes_log_p3 = -1.2420140846E-1;
- const float cephes_log_p4 = +1.4249322787E-1;
- const float cephes_log_p5 = -1.6668057665E-1;
- const float cephes_log_p6 = +2.0000714765E-1;
- const float cephes_log_p7 = -2.4999993993E-1;
- const float cephes_log_p8 = +3.3333331174E-1;
- const float cephes_log_q1 = -2.12194440e-4;
- const float cephes_log_q2 = 0.693359375;
+ const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524);
+ const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2);
+ const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1);
+ const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1);
+ const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1);
+ const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1);
+ const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1);
+ const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1);
+ const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1);
+ const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1);
+ const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4);
+ const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375);
// The smallest non denormalized float number.
- const float min_norm_pos = tensorflow::bit_cast<float, int32>(0x00800000);
- const float minus_inf = tensorflow::bit_cast<float, int32>(0xff800000);
-
- // NB! This number is denormal and since TF sets the denormals-are-zero flag
- // (and if TF didn't, -ffast-math would) trying to operate on this float using
- // C++ operations (including, for instance, implicit conversion to double)
- // will coerce this to zero.
- const float inv_mant_mask = tensorflow::bit_cast<float, int32>(~0x7f800000);
+ const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
+ const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
+ const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);
// invalid_mask is set if x is negative or NaN (and therefore output
// must be NaN).
emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f);
llvm::Value* e =
- vsl.Add(1.0f, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
+ vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
// part2:
// if( x < SQRTHF ) {
// } else { x = x - 1.0; }
llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF);
llvm::Value* tmp = vsl.FloatAnd(input, mask);
- input = vsl.Sub(input, 1.0);
- e = vsl.Sub(e, vsl.FloatAnd(mask, 1.0));
+ input = vsl.Sub(input, one);
+ e = vsl.Sub(e, vsl.FloatAnd(mask, one));
input = vsl.Add(input, tmp);
llvm::Value* x2 = vsl.Mul(input, input);
namespace xla {
namespace cpu {
+
+// Simple wrappers around llvm::APFloat::APFloat to make the calling code more
+// obvious.
+
+inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
+inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
+ return llvm::APFloat(llvm::APFloat::IEEEsingle(),
+ llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
+}
+
// A thin wrapper around llvm_util.h to make code generating vector math flow
// more readable.
class VectorSupportLibrary {
llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
return Mul(ir_builder()->getInt64(lhs), rhs);
}
- llvm::Value* Mul(float lhs, llvm::Value* rhs) {
+ llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
}
+ // If your call resolved to these then you probably wanted the versions taking
+ // APFloat.
+ llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
+ llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;
+
llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
return Add(ir_builder()->getInt64(lhs), rhs);
}
- llvm::Value* Add(float lhs, llvm::Value* rhs) {
+ llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
}
+ // If your call resolved to these then you probably wanted the versions taking
+ // APFloat.
+ llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
+ llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;
+
llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
- llvm::Value* Sub(llvm::Value* lhs, float rhs) {
+ llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
}
llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
- llvm::Value* Max(float lhs, llvm::Value* rhs) {
+ llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
}
llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
return Add(c, Mul(a, b));
}
- llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, float c) {
+ llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
}
- llvm::Value* MulAdd(llvm::Value* a, float b, float c) {
+ llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
+ const llvm::APFloat& c) {
return Add(GetConstantFloat(a->getType(), c),
Mul(a, GetConstantFloat(a->getType(), b)));
}
llvm::Value* Floor(llvm::Value* a);
- llvm::Value* Clamp(llvm::Value* a, float low, float high);
- llvm::Value* SplatFloat(float d) {
+ llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
+ const llvm::APFloat& high);
+ llvm::Value* SplatFloat(const llvm::APFloat& d) {
return GetConstantFloat(vector_type(), d);
}
llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
- llvm::Value* FCmpOLTMask(llvm::Value* lhs, float rhs) {
+ llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
}
// generating predicates above this type system oddity makes the kernel IR
// generation code less cluttered.
llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
- llvm::Value* FloatAnd(llvm::Value* lhs, float rhs) {
+ llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
}
llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
- llvm::Value* FloatOr(llvm::Value* lhs, float rhs) {
+ llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
}
llvm::Value* FloatNot(llvm::Value* lhs);
}
llvm::Value* BroadcastScalar(llvm::Value* x);
- llvm::Value* BroadcastScalar(float d) {
+ llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
return BroadcastScalar(GetConstantFloat(scalar_type(), d));
}
llvm::Type* IntegerTypeForFloatSize(bool vector);
llvm::Value* I1ToFloat(llvm::Value* i1);
- llvm::Value* GetConstantFloat(llvm::Type* type, float f) {
- llvm::Constant* scalar_value =
- llvm::ConstantFP::get(type->getContext(), llvm::APFloat(f));
+ llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
+ llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
if (llvm::isa<llvm::VectorType>(type)) {
return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
}