From 3dc861e754ffb86286038ef9c78327f59384eaad Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 27 Mar 2018 17:13:22 -0700 Subject: [PATCH] K-FAC: Bugfixes for TPU compatibility with covariance update ops. PiperOrigin-RevId: 190699635 --- tensorflow/contrib/kfac/python/ops/fisher_factors.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index 353e1c6..0d40d26 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -336,12 +336,16 @@ class FisherFactor(object): new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers) - # I have no idea if the TPU code below is still correct since I don't know - # what it actually does. Also, this code is not present in some of the - # other versions of make_covariance_update_op. Does it matter? - # Synchronize value across all TPU cores. + # Compute average of 'new_cov' across all TPU cores. On a TPU, each + # instance of 'new_cov' will be based on a different minibatch. This ensures + # that by the end of assign_moving_average(), all TPU cores see the same + # value for self._cov. + # + # Other implementations of make_covariance_update_op() that accumulate + # statistics in other variables should mimic this behavior. if utils.on_tpu(): new_cov = utils.cross_replica_mean(new_cov) + return moving_averages.assign_moving_average( self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) @@ -1398,6 +1402,10 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor): new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs) / float(self._num_towers)) + # See comments in FisherFactor.make_covariance_update_op() for details. + if utils.on_tpu(): + new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1) + op2 = moving_averages.assign_moving_average( self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS) -- 2.7.4