From: Surya Bhupatiraju Date: Tue, 20 Mar 2018 01:51:06 +0000 (-0700) Subject: Make L2 norm computation more stable. X-Git-Tag: tflite-v0.1.7~145^2^2~35 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=48adc7ba73177f2a9331918b160bc3d0775985b8;p=platform%2Fupstream%2Ftensorflow.git Make L2 norm computation more stable. Avoids the potentially numerically instable square root in the linalg_ops.norm() function because we 'undo' that operation with a math_ops.square() operation anyway. PiperOrigin-RevId: 189677716 --- diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 323cbe6..7e86d10 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -563,7 +563,8 @@ def mean_only_frechet_classifier_distance_from_activations( m_w = math_ops.reduce_mean(generated_activations, 0) # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. mofid = mean if activations_dtype != dtypes.float64: mofid = math_ops.cast(mofid, activations_dtype) @@ -637,7 +638,8 @@ def diagonal_only_frechet_classifier_distance_from_activations( (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. dofid = trace + mean if activations_dtype != dtypes.float64: dofid = math_ops.cast(dofid, activations_dtype) @@ -718,7 +720,8 @@ def frechet_classifier_distance_from_activations(real_activations, trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype)