    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow/contrib/kfac/python/ops:utils",
+        "//tensorflow/contrib/tpu",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:dtypes",
import numpy.random as npr
from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
      np_inv = np.linalg.inv(x + damp * np.eye(size))
      self.assertAllClose(sess.run(tf_inv), np_inv)
+  def testCrossReplicaMean(self):
+    """Ensures that cross_replica_mean() averages only when num_shards > 1."""
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(4):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+        self.assertNotEqual(mean, tensor)
+
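+    # With a single shard there is nothing to average; the input Tensor is
+    # returned unchanged.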
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(1):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+        self.assertEqual(mean, tensor)
+
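+    # Outside of any TPU shard context, cross_replica_mean() should raise.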
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):  # Outside of TPU context.
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+
if __name__ == '__main__':
  test.main()
    srcs = ["utils.py"],
    srcs_version = "PY2AND3",
    deps = [
+        "//tensorflow/contrib/tpu",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
        "//tensorflow/python:framework_ops",
    new_cov = math_ops.add_n(
        tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
+    # Synchronize value across all TPU cores.
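+    # Each replica computes new_cov from its own shard of the batch; averaging
+    # keeps the moving-average covariance estimates in sync across replicas.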
+    if utils.on_tpu():
+      new_cov = utils.cross_replica_mean(new_cov)
+
    return moving_averages.assign_moving_average(
        self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
import numpy as np
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
  return dysdx
+
+def on_tpu():
+ """Returns True when building a TPU computation."""
+  return tpu_function.get_tpu_context().number_of_shards is not None
+
+
+def cross_replica_mean(tensor, name=None):
+ """Takes mean value of a Tensor across all TPU cores.
+
+ Args:
+ tensor: Tensor to be synchronized.
+ name: None or string. Name of Op.
+
+ Returns:
+ Average of Tensor across all TPU cores.
+
+ Raises:
+ ValueError: If called outside of TPU context.
+ """
+  with ops.name_scope(name, "cross_replica_mean", [tensor]):
+    num_shards = tpu_function.get_tpu_context().number_of_shards
+    if num_shards is None:
+      raise ValueError(
+          "Cannot take cross_replica_mean() outside of a TPU context.")
+    if num_shards == 1:
+      return tensor
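+    # cross_replica_sum() adds the value from every shard, so dividing by
+    # num_shards first yields the mean across shards.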
+    return tpu_ops.cross_replica_sum(tensor / num_shards)
+
+
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.