From 0cde87e83883c1b98b28b41ed175922b845e650b Mon Sep 17 00:00:00 2001
From: David Majnemer
Date: Thu, 17 May 2018 15:28:33 -0700
Subject: [PATCH] [XLA] Use Expm1 in Elu/Selu

exp(x) - 1 is best executed using the composed Expm1 operation as it is
better behaved when exp(x) is near 1.

PiperOrigin-RevId: 197061826
---
 tensorflow/compiler/tests/unary_ops_test.py  | 17 +++++++++++------
 tensorflow/compiler/tf2xla/kernels/elu_op.cc |  6 ++----
 2 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 52633f6..689a4a1 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -210,8 +210,7 @@ class UnaryOpsTest(XLATestCase):
           math_ops.expm1,
           np.array([[-1, 1]], dtype=dtype),
           expected=np.array([[-0.63212056, 1.71828183]], dtype=dtype),
-          rtol=1e-5,
-          atol=1e-6)
+          rtol=1e-5)
 
       self._assertOpOutputMatchesExpected(
           math_ops.floor,
@@ -335,13 +334,19 @@ class UnaryOpsTest(XLATestCase):
 
       self._assertOpOutputMatchesExpected(
           nn_ops.elu,
-          np.array([[-1, 0, 1]], dtype=dtype),
-          expected=np.array([[-0.63212056, 0, 1]], dtype=dtype))
+          np.array([[-1, 0, 1, -1e-6]], dtype=dtype),
+          expected=np.array([[-0.63212056, 0, 1, -9.999995e-07]], dtype=dtype),
+          rtol=1e-5,
+          atol=1e-6)
 
       self._assertOpOutputMatchesExpected(
           nn_ops.selu,
-          np.array([[-1, 0, 1]], dtype=dtype),
-          expected=np.array([[-1.11133074, 0., 1.05070099]], dtype=dtype))
+          np.array([[-1, 0, 1, -1e-5]], dtype=dtype),
+          expected=np.array(
+              [[-1.11133074, 0., 1.05070099, -1.758090550379974e-05]],
+              dtype=dtype),
+          rtol=1e-5,
+          atol=1e-6)
 
       self._assertOpOutputMatchesExpected(
           nn_ops.relu,
diff --git a/tensorflow/compiler/tf2xla/kernels/elu_op.cc b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
index ed7462c..493781a 100644
--- a/tensorflow/compiler/tf2xla/kernels/elu_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/elu_op.cc
@@ -34,9 +34,8 @@ class EluOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     xla::XlaBuilder* b = ctx->builder();
     const auto zero = XlaHelpers::Zero(b, input_type(0));
-    const auto one = XlaHelpers::One(b, input_type(0));
     const auto pred = b->Gt(ctx->Input(0), zero);
-    const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
+    const auto expm1 = b->Expm1(ctx->Input(0));
     ctx->SetOutput(0, b->Select(pred, ctx->Input(0), expm1));
   }
 };
@@ -68,13 +67,12 @@ class SeluOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     xla::XlaBuilder* b = ctx->builder();
     const auto zero = XlaHelpers::Zero(b, input_type(0));
-    const auto one = XlaHelpers::One(b, input_type(0));
     const auto scale = XlaHelpers::FloatLiteral(b, input_type(0),
             1.0507009873554804934193349852946);
     const auto scale_alpha = XlaHelpers::FloatLiteral(b, input_type(0),
             1.7580993408473768599402175208123);
     const auto pred = b->Gt(ctx->Input(0), zero);
-    const auto expm1 = b->Sub(b->Exp(ctx->Input(0)), one);
+    const auto expm1 = b->Expm1(ctx->Input(0));
     ctx->SetOutput(0, b->Select(pred, b->Mul(scale, ctx->Input(0)),
                                 b->Mul(scale_alpha, expm1)));
   }
--
2.7.4
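
Note (illustrative only, not part of the patch): a minimal NumPy sketch of
the cancellation that Expm1 avoids. For x near 0, exp(x) rounds to a float32
value near 1.0, so exp(x) - 1 keeps only a couple of significant digits,
while expm1(x) is accurate to full float32 precision. The snippet below is a
hypothetical demonstration, not code from this change.

    import numpy as np

    x = np.float32(-1e-6)

    # Naive form: exp(x) is rounded near 1.0, so subtracting 1 leaves a
    # result quantized in steps of 2**-24 (about 6e-8), i.e. only about
    # two significant digits for an answer of magnitude 1e-6.
    naive = np.exp(x) - np.float32(1)

    # Fused form: computed directly, accurate to float32 precision
    # (about -9.999995e-07, the value the updated ELU test expects).
    fused = np.expm1(x)

    print(naive, fused)

With the rtol=1e-5 used in the updated unary_ops_test.py assertions, the
naive form would not reliably match the expected values near zero, while
the Expm1 form does.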