From 48adc7ba73177f2a9331918b160bc3d0775985b8 Mon Sep 17 00:00:00 2001 From: Surya Bhupatiraju Date: Mon, 19 Mar 2018 18:51:06 -0700 Subject: [PATCH] Make L2 norm computation more stable. Avoids the potentially numerically instable square root in the linalg_ops.norm() function because we 'undo' that operation with a math_ops.square() operation anyway. PiperOrigin-RevId: 189677716 --- .../contrib/gan/python/eval/python/classifier_metrics_impl.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 323cbe6..7e86d10 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -563,7 +563,8 @@ def mean_only_frechet_classifier_distance_from_activations( m_w = math_ops.reduce_mean(generated_activations, 0) # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. mofid = mean if activations_dtype != dtypes.float64: mofid = math_ops.cast(mofid, activations_dtype) @@ -637,7 +638,8 @@ def diagonal_only_frechet_classifier_distance_from_activations( (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. dofid = trace + mean if activations_dtype != dtypes.float64: dofid = math_ops.cast(dofid, activations_dtype) @@ -718,7 +720,8 @@ def frechet_classifier_distance_from_activations(real_activations, trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype) -- 2.7.4