Make L2 norm computation more stable.
authorSurya Bhupatiraju <sbhupatiraju@google.com>
Tue, 20 Mar 2018 01:51:06 +0000 (18:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 01:55:52 +0000 (18:55 -0700)
Avoids the potentially numerically instable square root in the linalg_ops.norm() function because we 'undo' that operation with a math_ops.square() operation anyway.

PiperOrigin-RevId: 189677716

tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py

index 323cbe6..7e86d10 100644 (file)
@@ -563,7 +563,8 @@ def mean_only_frechet_classifier_distance_from_activations(
   m_w = math_ops.reduce_mean(generated_activations, 0)
 
   # Next the distance between means.
-  mean = math_ops.square(linalg_ops.norm(m - m_w))  # This uses the L2 norm.
+  mean = math_ops.reduce_sum(
+      math_ops.squared_difference(m, m_w))  # Equivalent to L2 but more stable.
   mofid = mean
   if activations_dtype != dtypes.float64:
     mofid = math_ops.cast(mofid, activations_dtype)
@@ -637,7 +638,8 @@ def diagonal_only_frechet_classifier_distance_from_activations(
       (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w)))
 
   # Next the distance between means.
-  mean = math_ops.square(linalg_ops.norm(m - m_w))  # This uses the L2 norm.
+  mean = math_ops.reduce_sum(
+      math_ops.squared_difference(m, m_w))  # Equivalent to L2 but more stable.
   dofid = trace + mean
   if activations_dtype != dtypes.float64:
     dofid = math_ops.cast(dofid, activations_dtype)
@@ -718,7 +720,8 @@ def frechet_classifier_distance_from_activations(real_activations,
   trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
 
   # Next the distance between means.
-  mean = math_ops.square(linalg_ops.norm(m - m_w))  # This uses the L2 norm.
+  mean = math_ops.reduce_sum(
+      math_ops.squared_difference(m, m_w))  # Equivalent to L2 but more stable.
   fid = trace + mean
   if activations_dtype != dtypes.float64:
     fid = math_ops.cast(fid, activations_dtype)