[TF:XLA] Make softplus more accurate
authorDavid Majnemer <majnemer@google.com>
Wed, 16 May 2018 06:00:32 +0000 (23:00 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 16 May 2018 06:03:01 +0000 (23:03 -0700)
commit0fd579e6b7e17e207effb5d931128e10e2e0d6fc
tree36cd68526244d4e16919a8ad54b654368bc0c40c
parenta09c0c8858e3539a80bbb20677153b1950f64fb7
[TF:XLA] Make softplus more accurate

The softplus function computes log(exp(x) + 1).

We computed it this way but with special cases to handle underflow and
overflow.
This was done by comparing the input against a quantity with the
magnitude 13.94238515. Note that this quantity is not representable as a single
precision float and is instead rounded to 13.9423847.

If softplus would overflow, it will be approximated as x.
If softplus would underflow, it will be approximated as exp(x).

Unfortunately, this can provide inaccurate results for negative floats close to
the threshold.

For example: consider x = -13.9274826049805. softmax(x) is ~8.94068849e-7;
rounded to the nearest single precision float, this is 8.940689e-7.

In this case, x is quite close to the underflow threshold but not close enough
to be approximated by exp(x) == 8.94069273e-7.
Rather, it gets calculated using the canonical definition of softmax and comes
to 8.34464686e-7.

This result comes out to be wrong by 1,048,568 ULPs.

Instead, we can compute it the way one would compute LogSumExp(x, 0):
  max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0)))

When x is positive, this is:
  x + log(exp(0) + exp(-x))

When x is negative, this is:
  log(exp(x) + exp(0))

When x is 0, this is:
  log(exp(0) + exp(0))

exp(0) evaluates to 1 which gives us:
  if x is positive, x + log(1 + exp(-x))
  if x is negative, log(exp(x) + 1)
  if x is zero,     log(2)

These three cases can be combined like so:
  max(x, 0) + log(exp(-abs(x)) + 1)

Further, we can increase the fidelity of the log calculation by using log1p:
  max(x, 0) + log1p(exp(-abs(x)))

This computation naturally handles underflow and overflow while also providing
more numerically accurate results for a few small, positive, floating point
values.

PiperOrigin-RevId: 196782814
tensorflow/compiler/tests/unary_ops_test.py
tensorflow/compiler/tf2xla/kernels/unary_ops.cc