Make V1 metrics distributed-aware. Also fix a bug where assertAllClose
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 17 May 2018 23:26:24 +0000 (16:26 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 23:29:38 +0000 (16:29 -0700)
was sometimes ignoring its `msg` parameter.

PiperOrigin-RevId: 197070234

tensorflow/contrib/distribute/python/BUILD
tensorflow/contrib/distribute/python/metrics_v1_test.py [new file with mode: 0644]
tensorflow/python/BUILD
tensorflow/python/framework/test_util.py
tensorflow/python/ops/metrics_impl.py

index 64a77bb..aeeaa0b 100644 (file)
@@ -547,3 +547,22 @@ cuda_py_test(
         "no_pip",
     ],
 )
+
+cuda_py_test(
+    name = "metrics_v1_test",
+    srcs = ["metrics_v1_test.py"],
+    additional_deps = [
+        ":combinations",
+        "@absl_py//absl/testing:parameterized",
+        "//tensorflow/contrib/data/python/ops:batching",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:metrics",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/data/ops:dataset_ops",
+        "//tensorflow/python/eager:test",
+    ],
+    tags = [
+        "multi_and_single_gpu",
+        "no_pip",
+    ],
+)
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
new file mode 100644 (file)
index 0000000..6c6bf14
--- /dev/null
@@ -0,0 +1,438 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for V1 metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
+from tensorflow.python.ops import variables
+
+
+def _labeled_dataset_fn():
+  # First four batches of x: labels, predictions -> (labels == predictions)
+  #  0: 0, 0 -> True;   1: 1, 1 -> True;   2: 2, 2 -> True;   3: 3, 0 -> False
+  #  4: 4, 1 -> False;  5: 0, 2 -> False;  6: 1, 0 -> False;  7: 2, 1 -> False
+  #  8: 3, 2 -> False;  9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
+  # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
+  return dataset_ops.Dataset.range(1000).map(
+      lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
+
+
+def _boolean_dataset_fn():
+  # First four batches of labels, predictions: {TP, FP, TN, FN}
+  # with a threshold of 0.5:
+  #   T, T -> TP;  F, T -> FP;   T, F -> FN
+  #   F, F -> TN;  T, T -> TP;   F, T -> FP
+  #   T, F -> FN;  F, F -> TN;   T, T -> TP
+  #   F, T -> FP;  T, F -> FN;   F, F -> TN
+  return dataset_ops.Dataset.from_tensor_slices({
+      "labels": [True, False, True, False],
+      "predictions": [True, True, False, False]}).repeat().batch(3)
+
+
+def _threshold_dataset_fn():
+  # First four batches of labels, predictions: {TP, FP, TN, FN}
+  # with a threshold of 0.5:
+  #   True, 1.0 -> TP;  False, .75 -> FP;   True, .25 -> FN
+  #  False, 0.0 -> TN;   True, 1.0 -> TP;  False, .75 -> FP
+  #   True, .25 -> FN;  False, 0.0 -> TN;   True, 1.0 -> TP
+  #  False, .75 -> FP;   True, .25 -> FN;  False, 0.0 -> TN
+  return dataset_ops.Dataset.from_tensor_slices({
+      "labels": [True, False, True, False],
+      "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
+
+
+def _regression_dataset_fn():
+  return dataset_ops.Dataset.from_tensor_slices({
+      "labels": [1., .5, 1., 0.],
+      "predictions": [1., .75, .25, 0.]}).repeat()
+
+
+def all_combinations():
+  return combinations.combine(
+      distribution=[combinations.default_strategy,
+                    combinations.one_device_strategy,
+                    combinations.mirrored_strategy_with_gpu_and_cpu,
+                    combinations.mirrored_strategy_with_two_gpus],
+      mode=["graph"])
+
+
+# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
+# metrics.precision_at_k
+class MetricsV1Test(test.TestCase, parameterized.TestCase):
+
+  def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
+    with ops.Graph().as_default(), distribution.scope():
+      iterator = distribution.distribute_dataset(
+          dataset_fn).make_one_shot_iterator()
+      value, update = distribution.call_for_each_tower(
+          metric_fn, iterator.get_next())
+      update = distribution.group(update)
+      self.evaluate(variables.local_variables_initializer())
+      # TODO(josh11b): Once we switch to using a global batch size for input,
+      # replace "distribution.num_towers" with "1".
+      batches_per_update = distribution.num_towers
+
+      # Update variables using the first `num_towers` batches.
+      self.evaluate(update)
+      self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
+                          0.001, msg="After first update")
+
+      # Update variables using the second `num_towers` batches.
+      self.evaluate(update)
+      self.assertAllClose(expected_fn(2 * batches_per_update),
+                          self.evaluate(value),
+                          0.001,
+                          msg="After second update")
+
+      if batches_per_update == 1:  # Consume 4 input batches
+        self.evaluate(update)
+        self.assertAllClose(expected_fn(3 * batches_per_update),
+                            self.evaluate(value),
+                            0.001,
+                            msg="After third update")
+        self.evaluate(update)
+        self.assertAllClose(expected_fn(4 * batches_per_update),
+                            self.evaluate(value),
+                            0.001,
+                            msg="After fourth update")
+
+  @combinations.generate(all_combinations())
+  def testMean(self, distribution):
+    def _dataset_fn():
+      return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
+
+    def _expected_fn(num_batches):
+      # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
+      return num_batches * 2 - 0.5
+
+    self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testAccuracy(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.accuracy(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [3./4, 3./8, 3./12, 4./16][num_batches - 1]
+
+    self._test_metric(
+        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testMeanPerClassAccuracy(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.mean_per_class_accuracy(
+          labels, predictions, num_classes=5)
+
+    def _expected_fn(num_batches):
+      mean = lambda x: sum(x) / len(x)
+      return [mean([1., 1., 1., 0., 0.]),
+              mean([0.5, 0.5, 0.5, 0., 0.]),
+              mean([1./3, 1./3, 0.5, 0., 0.]),
+              mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]
+
+    self._test_metric(
+        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testMeanIOU(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.mean_iou(
+          labels, predictions, num_classes=5)
+
+    def _expected_fn(num_batches):
+      mean = lambda x: sum(x) / len(x)
+      return [mean([1./2, 1./1, 1./1, 0.]),  # no class 4 in first batch
+              mean([1./4, 1./4, 1./3, 0., 0.]),
+              mean([1./6, 1./6, 1./5, 0., 0.]),
+              mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]
+
+    self._test_metric(
+        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testMeanTensor(self, distribution):
+    def _dataset_fn():
+      dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
+      # Want to produce a fixed, known shape, so drop remainder when batching.
+      dataset = dataset.apply(batching.batch_and_drop_remainder(4))
+      return dataset
+
+    def _expected_fn(num_batches):
+      # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
+      # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
+      # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
+      # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
+      first = 2. * num_batches - 2.
+      return [first, first + 1., first + 2., first + 3.]
+
+    self._test_metric(
+        distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testAUCROC(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
+                         summation_method="careful_interpolation")
+
+    def _expected_fn(num_batches):
+      return [0.5, 7./9, 0.8, 0.75][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testAUCPR(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
+                         summation_method="careful_interpolation")
+
+    def _expected_fn(num_batches):
+      return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testFalseNegatives(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.false_negatives(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [1., 1., 2., 3.][num_batches - 1]
+
+    self._test_metric(
+        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testFalseNegativesAtThresholds(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.false_negatives_at_thresholds(labels, predictions, [.5])
+
+    def _expected_fn(num_batches):
+      return [[1.], [1.], [2.], [3.]][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testTrueNegatives(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.true_negatives(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [0., 1., 2., 3.][num_batches - 1]
+
+    self._test_metric(
+        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testTrueNegativesAtThresholds(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.true_negatives_at_thresholds(labels, predictions, [.5])
+
+    def _expected_fn(num_batches):
+      return [[0.], [1.], [2.], [3.]][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testFalsePositives(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.false_positives(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [1., 2., 2., 3.][num_batches - 1]
+
+    self._test_metric(
+        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testFalsePositivesAtThresholds(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.false_positives_at_thresholds(labels, predictions, [.5])
+
+    def _expected_fn(num_batches):
+      return [[1.], [2.], [2.], [3.]][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testTruePositives(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.true_positives(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [1., 2., 3., 3.][num_batches - 1]
+
+    self._test_metric(
+        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testTruePositivesAtThresholds(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.true_positives_at_thresholds(labels, predictions, [.5])
+
+    def _expected_fn(num_batches):
+      return [[1.], [2.], [3.], [3.]][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testPrecision(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.precision(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [0.5, 0.5, 0.6, 0.5][num_batches - 1]
+
+    self._test_metric(
+        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testPrecisionAtThreshold(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.precision_at_thresholds(labels, predictions, [0.5])
+
+    def _expected_fn(num_batches):
+      return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testRecall(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.recall(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
+
+    self._test_metric(
+        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testRecallAtThreshold(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.recall_at_thresholds(labels, predictions, [0.5])
+
+    def _expected_fn(num_batches):
+      return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testMeanSquaredError(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.mean_squared_error(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [0., 1./32, 0.208333, 0.15625][num_batches - 1]
+
+    self._test_metric(
+        distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testRootMeanSquaredError(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.root_mean_squared_error(labels, predictions)
+
+    def _expected_fn(num_batches):
+      return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]
+
+    self._test_metric(
+        distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testSensitivityAtSpecificity(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.sensitivity_at_specificity(labels, predictions, 0.8)
+
+    def _expected_fn(num_batches):
+      return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+  @combinations.generate(all_combinations())
+  def testSpecificityAtSensitivity(self, distribution):
+    def _metric_fn(x):
+      labels = x["labels"]
+      predictions = x["predictions"]
+      return metrics.specificity_at_sensitivity(labels, predictions, 0.95)
+
+    def _expected_fn(num_batches):
+      return [0., 1./3, 0.5, 0.5][num_batches - 1]
+
+    self._test_metric(
+        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+
+if __name__ == "__main__":
+  test.main()
index f714d1f..cb72248 100644 (file)
@@ -2461,6 +2461,7 @@ py_library(
         ":check_ops",
         ":confusion_matrix",
         ":control_flow_ops",
+        ":distribute",
         ":framework",
         ":framework_for_generated_wrappers",
         ":math_ops",
index 97cd22e..bf382a2 100644 (file)
@@ -1328,11 +1328,11 @@ class TensorFlowTestCase(googletest.TestCase):
             b,
             rtol=rtol,
             atol=atol,
-            msg="Mismatched value: a%s is different from b%s." % (path_str,
-                                                                  path_str))
+            msg=("Mismatched value: a%s is different from b%s. %s" %
+                 (path_str, path_str, msg)))
       except TypeError as e:
-        msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
-                                                     path_str, type(b))
+        msg = ("Error: a%s has %s, but b%s has %s. %s" %
+               (path_str, type(a), path_str, type(b), msg))
         e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
         raise
 
index 47eea6e..244e28d 100644 (file)
@@ -34,21 +34,54 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import weights_broadcast_ops
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 
 
 def metric_variable(shape, dtype, validate_shape=True, name=None):
-  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
-
-  return variable_scope.variable(
-      lambda: array_ops.zeros(shape, dtype),
-      trainable=False,
-      collections=[
-          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
-      ],
-      validate_shape=validate_shape,
-      name=name)
+  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.
+
+  If running in a `DistributionStrategy` context, the variable will be
+  "tower local". This means:
+
+  *   The returned object will be a container with separate variables
+      per replica/tower of the model.
+
+  *   When writing to the variable, e.g. using `assign_add` in a metric
+      update, the update will be applied to the variable local to the
+      replica/tower.
+
+  *   To get a metric's result value, we need to sum the variable values
+      across the replicas/towers before computing the final answer.
+      Furthermore, the final answer should be computed once instead of
+      in every replica/tower. Both of these are accomplished by
+      running the computation of the final result value inside
+      `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
+      Inside the `merge_call()`, ops are only added to the graph once
+      and access to a tower-local variable in a computation returns
+      the sum across all replicas/towers.
+
+  Args:
+    shape: Shape of the created variable.
+    dtype: Type of the created variable.
+    validate_shape: (Optional) Whether shape validation is enabled for
+      the created variable.
+    name: (Optional) String name of the created variable.
+
+  Returns:
+    A (non-trainable) variable initialized to zero, or if inside a
+    `DistributionStrategy` scope a tower-local variable container.
+  """
+  with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
+    # Note that "tower local" implies trainable=False.
+    return variable_scope.variable(
+        lambda: array_ops.zeros(shape, dtype),
+        collections=[
+            ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+        ],
+        validate_shape=validate_shape,
+        name=name)
 
 
 def _remove_squeezable_dimensions(predictions, labels, weights):
@@ -333,11 +366,15 @@ def mean(values,
     with ops.control_dependencies([values]):
       update_count_op = state_ops.assign_add(count, num_values)
 
-    mean_t = _safe_div(total, count, 'value')
-    update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+    def aggregate_across_towers(_, t, c):
+      mean_t = _safe_div(t, c, 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, mean_t)
+      return mean_t
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, mean_t)
+    mean_t = distribute_lib.get_tower_context().merge_call(
+        aggregate_across_towers, total, count)
+    update_op = _safe_div(update_total_op, update_count_op, 'update_op')
 
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
@@ -572,6 +609,17 @@ def _confusion_matrix_at_thresholds(labels,
   return values, update_ops
 
 
+def _aggregate_variable(v, collections):
+
+  def f(distribution, value):
+    value = distribution.fetch(value)
+    if collections:
+      ops.add_to_collections(collections, value)
+    return value
+
+  return distribute_lib.get_tower_context().merge_call(f, v)
+
+
 @tf_export('metrics.auc')
 def auc(labels,
         predictions,
@@ -757,14 +805,18 @@ def auc(labels,
         raise ValueError('Invalid summation_method: %s' % summation_method)
 
     # sum up the areas of all the trapeziums
-    auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
-                            values['fp'], 'value')
+    def aggregate_auc(_, values):
+      auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
+                              values['fp'], 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, auc_value)
+      return auc_value
+
+    auc_value = distribute_lib.get_tower_context().merge_call(
+        aggregate_auc, values)
     update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                             update_ops['tn'], update_ops['fp'], 'update_op')
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, auc_value)
-
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -992,15 +1044,18 @@ def mean_per_class_accuracy(labels,
     update_total_op = state_ops.scatter_add(total, labels, ones)
     update_count_op = state_ops.scatter_add(count, labels, is_correct)
 
-    per_class_accuracy = _safe_div(count, total, None)
+    def aggregate_mean_accuracy(_, count, total):
+      per_class_accuracy = _safe_div(count, total, None)
+      mean_accuracy_v = math_ops.reduce_mean(
+          per_class_accuracy, name='mean_accuracy')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, mean_accuracy_v)
+      return mean_accuracy_v
 
-    mean_accuracy_v = math_ops.reduce_mean(
-        per_class_accuracy, name='mean_accuracy')
-    update_op = _safe_div(update_count_op, update_total_op, name='update_op')
-
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, mean_accuracy_v)
+    mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
+        aggregate_mean_accuracy, count, total)
 
+    update_op = _safe_div(update_count_op, update_total_op, name='update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -1071,7 +1126,7 @@ def mean_iou(labels,
     total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                       num_classes, weights)
 
-    def compute_mean_iou(name):
+    def compute_mean_iou(total_cm, name):
       """Compute the mean intersection-over-union via the confusion matrix."""
       sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
       sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
@@ -1098,10 +1153,14 @@ def mean_iou(labels,
           math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
       return result
 
-    mean_iou_v = compute_mean_iou('mean_iou')
+    def mean_iou_across_towers(_, v):
+      mean_iou_v = compute_mean_iou(v, 'mean_iou')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, mean_iou_v)
+      return mean_iou_v
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, mean_iou_v)
+    mean_iou_v = distribute_lib.get_tower_context().merge_call(
+        mean_iou_across_towers, total_cm)
 
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
@@ -1310,12 +1369,16 @@ def mean_tensor(values,
     with ops.control_dependencies([values]):
       update_count_op = state_ops.assign_add(count, num_values)
 
-    mean_t = _safe_div(total, count, 'value')
-    update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+    def aggregate_across_towers(_, t, c):
+      mean_t = _safe_div(t, c, 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, mean_t)
+      return mean_t
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, mean_t)
+    mean_t = distribute_lib.get_tower_context().merge_call(
+        aggregate_across_towers, total, count)
 
+    update_op = _safe_div(update_total_op, update_count_op, 'update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -1413,12 +1476,9 @@ def _count_condition(values,
       weights = math_ops.to_float(weights)
       values = math_ops.multiply(values, weights)
 
-  value_tensor = array_ops.identity(count)
-  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
-
-  if metrics_collections:
-    ops.add_to_collections(metrics_collections, value_tensor)
+  value_tensor = _aggregate_variable(count, metrics_collections)
 
+  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
   if updates_collections:
     ops.add_to_collections(updates_collections, update_op)
 
@@ -1525,13 +1585,12 @@ def false_negatives_at_thresholds(labels,
     values, update_ops = _confusion_matrix_at_thresholds(
         labels, predictions, thresholds, weights=weights, includes=('fn',))
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, values['fn'])
+    fn_value = _aggregate_variable(values['fn'], metrics_collections)
 
     if updates_collections:
       ops.add_to_collections(updates_collections, update_ops['fn'])
 
-    return values['fn'], update_ops['fn']
+    return fn_value, update_ops['fn']
 
 
 @tf_export('metrics.false_positives')
@@ -1635,13 +1694,12 @@ def false_positives_at_thresholds(labels,
     values, update_ops = _confusion_matrix_at_thresholds(
         labels, predictions, thresholds, weights=weights, includes=('fp',))
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, values['fp'])
+    fp_value = _aggregate_variable(values['fp'], metrics_collections)
 
     if updates_collections:
       ops.add_to_collections(updates_collections, update_ops['fp'])
 
-    return values['fp'], update_ops['fp']
+    return fp_value, update_ops['fp']
 
 
 @tf_export('metrics.true_negatives')
@@ -1745,13 +1803,12 @@ def true_negatives_at_thresholds(labels,
     values, update_ops = _confusion_matrix_at_thresholds(
         labels, predictions, thresholds, weights=weights, includes=('tn',))
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, values['tn'])
+    tn_value = _aggregate_variable(values['tn'], metrics_collections)
 
     if updates_collections:
       ops.add_to_collections(updates_collections, update_ops['tn'])
 
-    return values['tn'], update_ops['tn']
+    return tn_value, update_ops['tn']
 
 
 @tf_export('metrics.true_positives')
@@ -1855,13 +1912,12 @@ def true_positives_at_thresholds(labels,
     values, update_ops = _confusion_matrix_at_thresholds(
         labels, predictions, thresholds, weights=weights, includes=('tp',))
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, values['tp'])
+    tp_value = _aggregate_variable(values['tp'], metrics_collections)
 
     if updates_collections:
       ops.add_to_collections(updates_collections, update_ops['tp'])
 
-    return values['tp'], update_ops['tp']
+    return tp_value, update_ops['tp']
 
 
 @tf_export('metrics.precision')
@@ -1945,13 +2001,17 @@ def precision(labels,
       return array_ops.where(
           math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
 
-    p = compute_precision(true_p, false_p, 'value')
-    update_op = compute_precision(true_positives_update_op,
-                                  false_positives_update_op, 'update_op')
+    def once_across_towers(_, true_p, false_p):
+      p = compute_precision(true_p, false_p, 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, p)
+      return p
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, p)
+    p = distribute_lib.get_tower_context().merge_call(
+        once_across_towers, true_p, false_p)
 
+    update_op = compute_precision(true_positives_update_op,
+                                  false_positives_update_op, 'update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -2025,13 +2085,17 @@ def precision_at_thresholds(labels,
     def compute_precision(tp, fp, name):
       return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
 
-    prec = compute_precision(values['tp'], values['fp'], 'value')
-    update_op = compute_precision(update_ops['tp'], update_ops['fp'],
-                                  'update_op')
+    def precision_across_towers(_, values):
+      prec = compute_precision(values['tp'], values['fp'], 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, prec)
+      return prec
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, prec)
+    prec = distribute_lib.get_tower_context().merge_call(
+        precision_across_towers, values)
 
+    update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+                                  'update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -2050,7 +2114,7 @@ def recall(labels,
   The `recall` function creates two local variables, `true_positives`
   and `false_negatives`, that are used to compute the recall. This value is
   ultimately returned as `recall`, an idempotent operation that simply divides
-  `true_positives` by the sum of `true_positives`  and `false_negatives`.
+  `true_positives` by the sum of `true_positives` and `false_negatives`.
 
   For estimation of the metric over a stream of data, the function creates an
   `update_op` that updates these variables and returns the `recall`. `update_op`
@@ -2117,13 +2181,17 @@ def recall(labels,
           math_ops.greater(true_p + false_n, 0),
           math_ops.div(true_p, true_p + false_n), 0, name)
 
-    rec = compute_recall(true_p, false_n, 'value')
-    update_op = compute_recall(true_positives_update_op,
-                               false_negatives_update_op, 'update_op')
+    def once_across_towers(_, true_p, false_n):
+      rec = compute_recall(true_p, false_n, 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, rec)
+      return rec
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, rec)
+    rec = distribute_lib.get_tower_context().merge_call(
+        once_across_towers, true_p, false_n)
 
+    update_op = compute_recall(true_positives_update_op,
+                               false_negatives_update_op, 'update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -2552,11 +2620,17 @@ def recall_at_top_k(labels,
         class_id=class_id,
         weights=weights)
 
-    metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+    def aggregate_across_towers(_, tp, fn):
+      metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, metric)
+      return metric
+
+    metric = distribute_lib.get_tower_context().merge_call(
+        aggregate_across_towers, tp, fn)
+
     update = math_ops.div(
         tp_update, math_ops.add(tp_update, fn_update), name='update')
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, metric)
     if updates_collections:
       ops.add_to_collections(updates_collections, update)
     return metric, update
@@ -2627,12 +2701,16 @@ def recall_at_thresholds(labels,
     def compute_recall(tp, fn, name):
       return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
 
-    rec = compute_recall(values['tp'], values['fn'], 'value')
-    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
+    def recall_across_towers(_, values):
+      rec = compute_recall(values['tp'], values['fn'], 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, rec)
+      return rec
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, rec)
+    rec = distribute_lib.get_tower_context().merge_call(
+        recall_across_towers, values)
 
+    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -2698,13 +2776,16 @@ def root_mean_squared_error(labels,
   mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
                                           None, name or
                                           'root_mean_squared_error')
+  def once_across_towers(_, mse):
+    rmse = math_ops.sqrt(mse)
+    if metrics_collections:
+      ops.add_to_collections(metrics_collections, rmse)
+    return rmse
 
-  rmse = math_ops.sqrt(mse)
-  update_rmse_op = math_ops.sqrt(update_mse_op)
-
-  if metrics_collections:
-    ops.add_to_collections(metrics_collections, rmse)
+  rmse = distribute_lib.get_tower_context().merge_call(
+      once_across_towers, mse)
 
+  update_rmse_op = math_ops.sqrt(update_mse_op)
   if updates_collections:
     ops.add_to_collections(updates_collections, update_rmse_op)
 
@@ -2797,15 +2878,19 @@ def sensitivity_at_specificity(labels,
       return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
                           name)
 
-    sensitivity = compute_sensitivity_at_specificity(
-        values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+    def aggregate_across_towers(_, values):
+      sensitivity = compute_sensitivity_at_specificity(
+          values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, sensitivity)
+      return sensitivity
+
+    sensitivity = distribute_lib.get_tower_context().merge_call(
+        aggregate_across_towers, values)
+
     update_op = compute_sensitivity_at_specificity(
         update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
         'update_op')
-
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, sensitivity)
-
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)
 
@@ -3070,11 +3155,16 @@ def _streaming_sparse_average_precision_at_top_k(labels,
       total_update = state_ops.assign_add(total_var, batch_total, name='update')
 
     # Divide total by max to get mean, for both vars and the update ops.
-    mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
-    update = _safe_scalar_div(total_update, max_update, name=scope)
+    def aggregate_across_towers(_, total_var, max_var):
+      mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, mean_average_precision)
+      return mean_average_precision
 
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, mean_average_precision)
+    mean_average_precision = distribute_lib.get_tower_context().merge_call(
+        aggregate_across_towers, total_var, max_var)
+
+    update = _safe_scalar_div(total_update, max_update, name=scope)
     if updates_collections:
       ops.add_to_collections(updates_collections, update)
 
@@ -3351,11 +3441,17 @@ def precision_at_top_k(labels,
         class_id=class_id,
         weights=weights)
 
-    metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+    def aggregate_across_towers(_, tp, fp):
+      metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, metric)
+      return metric
+
+    metric = distribute_lib.get_tower_context().merge_call(
+        aggregate_across_towers, tp, fp)
+
     update = math_ops.div(
         tp_update, math_ops.add(tp_update, fp_update), name='update')
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, metric)
     if updates_collections:
       ops.add_to_collections(updates_collections, update)
     return metric, update
@@ -3583,15 +3679,19 @@ def specificity_at_sensitivity(labels,
       return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
                           name)
 
-    specificity = compute_specificity_at_sensitivity(
-        values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+    def aggregate_across_towers(_, values):
+      specificity = compute_specificity_at_sensitivity(
+          values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+      if metrics_collections:
+        ops.add_to_collections(metrics_collections, specificity)
+      return specificity
+
+    specificity = distribute_lib.get_tower_context().merge_call(
+        aggregate_across_towers, values)
+
     update_op = compute_specificity_at_sensitivity(
         update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
         'update_op')
-
-    if metrics_collections:
-      ops.add_to_collections(metrics_collections, specificity)
-
     if updates_collections:
       ops.add_to_collections(updates_collections, update_op)