K-FAC: Cross Replica Mean for TPU
author    A. Unique TensorFlower <gardener@tensorflow.org>
          Wed, 20 Dec 2017 00:31:31 +0000 (16:31 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Wed, 20 Dec 2017 02:14:09 +0000 (18:14 -0800)
Adds an op for taking the average of a Tensor across all TPU cores, and uses it
before updating covariance statistics. This is a no-op if TPUs aren't used.

PiperOrigin-RevId: 179620193
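
A minimal usage sketch (not part of the commit; assumes the TF 1.x contrib
layout this change targets): inside a TPU shard context, the new
cross_replica_mean() helper averages a value over all cores, and with a
single shard it simply returns its input unchanged.

    from tensorflow.contrib.kfac.python.ops import utils
    from tensorflow.contrib.tpu.python.tpu import tpu_function
    import tensorflow as tf

    with tpu_function.tpu_shard_context(8):  # pretend we compile for 8 cores
      stat = tf.constant(2.0)
      # Builds cross_replica_sum(stat / 8); callers that may run off-TPU can
      # guard the call with utils.on_tpu().
      synced = utils.cross_replica_mean(stat)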

tensorflow/contrib/kfac/python/kernel_tests/BUILD
tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
tensorflow/contrib/kfac/python/ops/BUILD
tensorflow/contrib/kfac/python/ops/fisher_factors.py
tensorflow/contrib/kfac/python/ops/utils.py

diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 95fba59e3c96ae3c69e0b154740785b0d2bcb3c9..4928bf2c10e5063cec6ce8d7374748239079998c 100644
@@ -110,6 +110,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/contrib/kfac/python/ops:utils",
+        "//tensorflow/contrib/tpu",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:dtypes",
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
index 2622fdd0ed43cf8343ff672869850b91133f4541..c8631ed89ba58c5176b4b62344663e5d9e330926 100644
@@ -22,6 +22,7 @@ import numpy as np
 import numpy.random as npr
 
 from tensorflow.contrib.kfac.python.ops import utils
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
@@ -267,6 +268,25 @@ class UtilsTest(test.TestCase):
       np_inv = np.linalg.inv(x + damp * np.eye(size))
       self.assertAllClose(sess.run(tf_inv), np_inv)
 
+  def testCrossReplicaMean(self):
+    """Ensures that cross_replica_mean() executes only when num_shards > 1."""
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(4):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+      self.assertNotEqual(mean, tensor)
+
+    with ops.Graph().as_default():
+      with tpu_function.tpu_shard_context(1):
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+      self.assertEqual(mean, tensor)
+
+    with ops.Graph().as_default():
+      with self.assertRaises(ValueError):  # Outside of TPU context.
+        tensor = array_ops.zeros([], dtype=dtypes.float32)
+        mean = utils.cross_replica_mean(tensor)
+
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index 3d731c7bc206d6f168e9b8f29b66bf4f1dbe8542..9be3d60dc06e0282370fed3882ad994f7e4bc64c 100644
@@ -196,6 +196,7 @@ py_library(
     srcs = ["utils.py"],
     srcs_version = "PY2AND3",
     deps = [
+        "//tensorflow/contrib/tpu",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 5a6d1a93ff217c3922f45a047b4d548086ac5258..e4e81cd13de9de90048683da66dc82d2453866ad 100644
@@ -267,6 +267,10 @@ class FisherFactor(object):
     new_cov = math_ops.add_n(
         tuple(self._compute_new_cov(idx) for idx in range(self._num_sources)))
 
+    # Synchronize value across all TPU cores.
+    if utils.on_tpu():
+      new_cov = utils.cross_replica_mean(new_cov)
+
     return moving_averages.assign_moving_average(
         self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
 
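Why a mean rather than a sum: each core computes its covariance update from
its local batch shard, so with equal-sized shards the average of the per-core
estimates equals the estimate a single worker would compute on the combined
batch. A small NumPy sketch (hypothetical data, not from the commit)
illustrating this:

    import numpy as np

    rng = np.random.RandomState(0)
    a = rng.randn(16, 4)  # activations seen by core 0
    b = rng.randn(16, 4)  # activations seen by core 1

    cov_a = a.T.dot(a) / len(a)  # per-core second-moment estimates
    cov_b = b.T.dot(b) / len(b)
    both = np.concatenate([a, b])
    cov_all = both.T.dot(both) / len(both)  # single-worker estimate

    assert np.allclose((cov_a + cov_b) / 2, cov_all)
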
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index 5faed7c979b0db24798e4554fa11e8e38be074fb..48b191ef501f60118dd4583483d9d27dd5265621 100644
@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import numpy as np
 
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu_function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -313,5 +315,34 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
 
   return dysdx
 
+
+def on_tpu():
+  """Returns True when building a TPU computation."""
+  return tpu_function.get_tpu_context().number_of_shards is not None
+
+
+def cross_replica_mean(tensor, name=None):
+  """Takes mean value of a Tensor across all TPU cores.
+
+  Args:
+    tensor: Tensor to be synchronized.
+    name: None or string. Name of Op.
+
+  Returns:
+    Average of Tensor across all TPU cores.
+
+  Raises:
+    ValueError: If called outside of TPU context.
+  """
+  with ops.name_scope(name, "cross_replica_mean", [tensor]):
+    num_shards = tpu_function.get_tpu_context().number_of_shards
+    if num_shards is None:
+      raise ValueError(
+          "Cannot take cross_replica_mean() outside of TPU Context.")
+    if num_shards == 1:
+      return tensor
+    return tpu_ops.cross_replica_sum(tensor / num_shards)
+
+
 # TODO(b/69623235): Add a function for finding tensors that share gradients
 # to eliminate redundant fisher factor computations.
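
For reference, a pure-Python model (hypothetical values, no TPU needed) of the
arithmetic cross_replica_mean() performs: each core contributes
tensor / num_shards, and the all-reduce sums those contributions, leaving
every core holding the same average.

    per_core = [3.0, 5.0]  # hypothetical per-core tensor values
    num_shards = len(per_core)
    mean_on_every_core = sum(v / num_shards for v in per_core)
    assert mean_on_every_core == 4.0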