From 0fd579e6b7e17e207effb5d931128e10e2e0d6fc Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Tue, 15 May 2018 23:00:32 -0700 Subject: [PATCH] [TF:XLA] Make softplus more accurate The softplus function computes log(exp(x) + 1). We computed it this way but with special cases to handle underflow and overflow. This was done by comparing the input against a quantity with the magnitude 13.94238515. Note that this quantity is not representable as a single precision float and is instead rounded to 13.9423847. If softplus would overflow, it will be approximated as x. If softplus would underflow, it will be approximated as exp(x). Unfortunately, this can provide inaccurate results for negative floats close to the threshold. For example: consider x = -13.9274826049805. softmax(x) is ~8.94068849e-7; rounded to the nearest single precision float, this is 8.940689e-7. In this case, x is quite close to the underflow threshold but not close enough to be approximated by exp(x) == 8.94069273e-7. Rather, it gets calculated using the canonical definition of softmax and comes to 8.34464686e-7. This result comes out to be wrong by 1,048,568 ULPs. Instead, we can compute it the way one would compute LogSumExp(x, 0): max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0))) When x is positive, this is: x + log(exp(0) + exp(-x)) When x is negative, this is: log(exp(x) + exp(0)) When x is 0, this is: log(exp(0) + exp(0)) exp(0) evaluates to 1 which gives us: if x is positive, x + log(1 + exp(-x)) if x is negative, log(exp(x) + 1) if x is zero, log(2) These three cases can be combined like so: max(x, 0) + log(exp(-abs(x)) + 1) Further, we can increase the fidelity of the log calculation by using log1p: max(x, 0) + log1p(exp(-abs(x))) This computation naturally handles underflow and overflow while also providing more numerically accurate results for a few small, positive, floating point values. PiperOrigin-RevId: 196782814 --- tensorflow/compiler/tests/unary_ops_test.py | 4 +++- tensorflow/compiler/tf2xla/kernels/unary_ops.cc | 29 ++++++++++--------------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 57a1d9b..52633f6 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -793,7 +793,9 @@ class UnaryOpsTest(XLATestCase): zero = np.asarray(0).astype(dtype) expected = np.logaddexp(zero, features) self._assertOpOutputMatchesExpected( - nn_ops.softplus, features, expected=expected) + nn_ops.softplus, features, expected=expected, + rtol=1e-6, + atol=9.1e-6) def testSoftplus(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 3f6e218..71a9fd0 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -158,24 +158,17 @@ XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -static xla::XlaOp Softplus(xla::XlaBuilder* b, DataType dtype, - const xla::XlaOp& features) { - xla::XlaOp threshold = b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), - XlaHelpers::FloatLiteral(b, dtype, 2.0)); - // Value above which exp(x) may overflow, but softplus(x) == x - // is within machine epsilon. - xla::XlaOp too_large = b->Gt(features, b->Neg(threshold)); - // Value below which exp(x) may underflow, but softplus(x) == exp(x) - // is within machine epsilon. - xla::XlaOp too_small = b->Lt(features, threshold); - xla::XlaOp features_exp = b->Exp(features); - xla::XlaOp output = b->Select( - too_large, features, - b->Select(too_small, features_exp, - b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); - return output; -} -XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); +// softplus(x) = log(1 + exp(x)) +// +// This is not numerically stable when x is large, it can easily overflow. +// However, we can compute it as LogSumExp(x, 0): +// max(x, 0) + log(exp(x - max(x, 0)) + exp(0 - max(x, 0))) +// +// This is equivalent to: +// max(x, 0) + log1p(exp(-abs(x))) +XLAJIT_MAKE_UNARY(Softplus, + b->Add(b->Max(x, XlaHelpers::Zero(b, input_type(0))), + b->Log1p(b->Exp(b->Neg(b->Abs(x)))))); // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, -- 2.7.4