From a6a862e90d1b336570ab67816ca14e191f5acb32 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 8 May 2018 08:07:08 -0700 Subject: [PATCH] [TF:XLA] Fix NaN in StatelessRandomNormal if the underlying uniform distribution returned -1. PiperOrigin-RevId: 195819645 --- tensorflow/compiler/tests/stateless_random_ops_test.py | 9 +++++++++ tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py index 4336ebd..b6f8390 100644 --- a/tensorflow/compiler/tests/stateless_random_ops_test.py +++ b/tensorflow/compiler/tests/stateless_random_ops_test.py @@ -86,6 +86,15 @@ class StatelessRandomOpsTest(XLATestCase): # seed were not fixed. self.assertTrue(self._chi_squared(y, 10) < 16.92) + def testRandomNormalIsFinite(self): + with self.test_session() as sess, self.test_scope(): + for dtype in self._random_types(): + seed_t = array_ops.placeholder(dtypes.int32, shape=[2]) + x = stateless.stateless_random_uniform( + shape=[10000], seed=seed_t, dtype=dtype) + y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]}) + self.assertTrue(np.all(np.isfinite(y))) + def _normal_cdf(self, x): """Cumulative distribution function for a standard normal distribution.""" return 0.5 + 0.5 * np.vectorize(math.erf)(x / math.sqrt(2)) diff --git a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc index 6340c22..a99d4dd 100644 --- a/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stateless_random_ops.cc @@ -255,7 +255,8 @@ class StatelessRandomNormalOp : public XlaOpKernel { seed_shape.DebugString())); xla::XlaOp seed = ctx->Input(1); xla::XlaBuilder* builder = ctx->builder(); - auto uniform = RandomUniform(builder, seed, shape, -1.0, 1.0); + auto uniform = + RandomUniform(builder, seed, shape, std::nextafter(-1.0f, 0.0f), 1.0); // Convert uniform distribution to normal distribution by computing // sqrt(2) * erfinv(x) auto normal = builder->Mul(builder->ConstantR0(std::sqrt(2.0)), -- 2.7.4