m_w = math_ops.reduce_mean(generated_activations, 0)
# Next the distance between means.
- mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ mean = math_ops.reduce_sum(
+ math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable.
mofid = mean
if activations_dtype != dtypes.float64:
mofid = math_ops.cast(mofid, activations_dtype)
(var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w)))
# Next the distance between means.
- mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ mean = math_ops.reduce_sum(
+ math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable.
dofid = trace + mean
if activations_dtype != dtypes.float64:
dofid = math_ops.cast(dofid, activations_dtype)
trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
# Next the distance between means.
- mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ mean = math_ops.reduce_sum(
+ math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable.
fid = trace + mean
if activations_dtype != dtypes.float64:
fid = math_ops.cast(fid, activations_dtype)