[TF:XLA] Make softplus more accurate
The softplus function computes log(exp(x) + 1). Previously, we
computed it directly from this definition, with special cases to
handle underflow and overflow.
This was done by comparing the input against a threshold with the
magnitude 13.94238515. Note that this quantity is not representable
as a single-precision float and is instead rounded to 13.9423847.
If softplus would overflow, it was approximated as x.
If softplus would underflow, it was approximated as exp(x).
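A rough sketch of that old scheme, in Python/NumPy for illustration
(the actual implementation lives in XLA; the helper name here is
hypothetical):

```python
import numpy as np

# Threshold from the old code path; 13.94238515 is not representable
# in single precision and rounds to 13.9423847.
THRESHOLD = np.float32(13.94238515)

def softplus_old(x):
    """Hypothetical sketch of the previous thresholded softplus."""
    x = np.float32(x)
    if x > THRESHOLD:
        return x              # would overflow: softplus(x) ~ x
    if x < -THRESHOLD:
        return np.exp(x)      # would underflow: softplus(x) ~ exp(x)
    return np.log(np.exp(x) + np.float32(1))  # canonical definition
```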
Unfortunately, this could produce inaccurate results for negative
floats close to the threshold.
For example, consider x = -13.9274826049805. softplus(x) is
~8.94068849e-7; rounded to the nearest single-precision float, this
is 8.940689e-7.
In this case, x is quite close to the underflow threshold but not
close enough to be approximated by exp(x) == 8.94069273e-7.
Instead, it got computed via the canonical definition of softplus and
came out to 8.34464686e-7.
This result is wrong by 1,048,568 ULPs.
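The loss of precision can be reproduced outside XLA; a minimal sketch
in Python/NumPy, assuming single-precision arithmetic throughout the
naive path:

```python
import numpy as np

x = np.float32(-13.9274826049805)

# Naive single-precision path: exp(x) is ~8.94e-7, and adding it to 1
# discards most of its significand (float32 eps is ~1.19e-7).
naive = np.log(np.exp(x) + np.float32(1))

# Reference value computed in double precision.
reference = np.log1p(np.exp(np.float64(x)))

# naive lands near 8.34e-7, well below the true ~8.94e-7.
```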
Instead, we can compute it the way one would compute LogSumExp(x, 0):
max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0)))
When x is positive, this is:
x + log(exp(0) + exp(-x))
When x is negative, this is:
log(exp(x) + exp(0))
When x is 0, this is:
log(exp(0) + exp(0))
exp(0) evaluates to 1 which gives us:
if x is positive, x + log(1 + exp(-x))
if x is negative, log(exp(x) + 1)
if x is zero, log(2)
These three cases can be combined like so:
max(x, 0) + log(exp(-abs(x)) + 1)
Further, we can increase the fidelity of the log calculation by using log1p:
max(x, 0) + log1p(exp(-abs(x)))
This computation naturally handles underflow and overflow while also providing
more numerically accurate results for a few small, positive, floating point
values.
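The new formulation, sketched in Python/NumPy (the actual change is
to the XLA computation, not this code):

```python
import numpy as np

def softplus_stable(x):
    """Softplus as max(x, 0) + log1p(exp(-|x|)).

    Large positive x: exp(-|x|) underflows to 0, so the result is ~x.
    Large negative x: max(x, 0) is 0 and log1p(exp(x)) is ~exp(x).
    No explicit threshold comparisons are needed.
    """
    x = np.asarray(x)
    return np.maximum(x, 0) + np.log1p(np.exp(-np.abs(x)))

# In float32, softplus_stable(np.float32(-13.9274826049805)) stays
# close to the true value ~8.94068849e-7, unlike log(exp(x) + 1).
```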
PiperOrigin-RevId: 196782814