From 485571dd92532ed5d6989e419b4ee87342c18cbf Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 May 2018 16:26:24 -0700 Subject: [PATCH] Make V1 metrics distributed-aware. Also fix a bug where assertAllClose was sometimes ignoring its `msg` parameter. PiperOrigin-RevId: 197070234 --- tensorflow/contrib/distribute/python/BUILD | 19 + .../contrib/distribute/python/metrics_v1_test.py | 438 +++++++++++++++++++++ tensorflow/python/BUILD | 1 + tensorflow/python/framework/test_util.py | 8 +- tensorflow/python/ops/metrics_impl.py | 296 +++++++++----- 5 files changed, 660 insertions(+), 102 deletions(-) create mode 100644 tensorflow/contrib/distribute/python/metrics_v1_test.py diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index 64a77bb..aeeaa0b 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -547,3 +547,22 @@ cuda_py_test( "no_pip", ], ) + +cuda_py_test( + name = "metrics_v1_test", + srcs = ["metrics_v1_test.py"], + additional_deps = [ + ":combinations", + "@absl_py//absl/testing:parameterized", + "//tensorflow/contrib/data/python/ops:batching", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:test", + ], + tags = [ + "multi_and_single_gpu", + "no_pip", + ], +) diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py new file mode 100644 index 0000000..6c6bf14 --- /dev/null +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -0,0 +1,438 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for V1 metrics.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import combinations +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables + + +def _labeled_dataset_fn(): + # First four batches of x: labels, predictions -> (labels == predictions) + # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False + # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False + # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False + # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True + return dataset_ops.Dataset.range(1000).map( + lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4) + + +def _boolean_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # T, T -> TP; F, T -> FP; T, F -> FN + # F, F -> TN; T, T -> TP; F, T -> FP + # T, F -> FN; F, F -> TN; T, T -> TP + # F, T -> FP; T, F -> FN; F, F -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [True, True, False, False]}).repeat().batch(3) + + +def _threshold_dataset_fn(): + # First four batches of labels, predictions: {TP, FP, TN, FN} + # with a threshold of 0.5: + # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN + # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP + # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP + # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [True, False, True, False], + "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3) + + +def _regression_dataset_fn(): + return dataset_ops.Dataset.from_tensor_slices({ + "labels": [1., .5, 1., 0.], + "predictions": [1., .75, .25, 0.]}).repeat() + + +def all_combinations(): + return combinations.combine( + distribution=[combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus], + mode=["graph"]) + + +# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k, +# metrics.precision_at_k +class MetricsV1Test(test.TestCase, parameterized.TestCase): + + def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): + with ops.Graph().as_default(), distribution.scope(): + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() + value, update = distribution.call_for_each_tower( + metric_fn, iterator.get_next()) + update = distribution.group(update) + self.evaluate(variables.local_variables_initializer()) + # TODO(josh11b): Once we switch to using a global batch size for input, + # replace "distribution.num_towers" with "1". + batches_per_update = distribution.num_towers + + # Update variables using the first `num_towers` batches. + self.evaluate(update) + self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value), + 0.001, msg="After first update") + + # Update variables using the second `num_towers` batches. 
+ self.evaluate(update) + self.assertAllClose(expected_fn(2 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After second update") + + if batches_per_update == 1: # Consume 4 input batches + self.evaluate(update) + self.assertAllClose(expected_fn(3 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After third update") + self.evaluate(update) + self.assertAllClose(expected_fn(4 * batches_per_update), + self.evaluate(value), + 0.001, + msg="After fourth update") + + @combinations.generate(all_combinations()) + def testMean(self, distribution): + def _dataset_fn(): + return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4) + + def _expected_fn(num_batches): + # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc. + return num_batches * 2 - 0.5 + + self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn) + + @combinations.generate(all_combinations()) + def testAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.accuracy(labels, predictions) + + def _expected_fn(num_batches): + return [3./4, 3./8, 3./12, 4./16][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanPerClassAccuracy(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_per_class_accuracy( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1., 1., 1., 0., 0.]), + mean([0.5, 0.5, 0.5, 0., 0.]), + mean([1./3, 1./3, 0.5, 0., 0.]), + mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanIOU(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_iou( + labels, predictions, num_classes=5) + + def _expected_fn(num_batches): + mean = lambda x: sum(x) / len(x) + return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch + mean([1./4, 1./4, 1./3, 0., 0.]), + mean([1./6, 1./6, 1./5, 0., 0.]), + mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1] + + self._test_metric( + distribution, _labeled_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanTensor(self, distribution): + def _dataset_fn(): + dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float) + # Want to produce a fixed, known shape, so drop remainder when batching. + dataset = dataset.apply(batching.batch_and_drop_remainder(4)) + return dataset + + def _expected_fn(num_batches): + # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2 + # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1 + # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches + # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1 + first = 2. * num_batches - 2. + return [first, first + 1., first + 2., first + 3.] 
+ + self._test_metric( + distribution, _dataset_fn, metrics.mean_tensor, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCROC(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.5, 7./9, 0.8, 0.75][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testAUCPR(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.auc(labels, predictions, num_thresholds=8, curve="PR", + summation_method="careful_interpolation") + + def _expected_fn(num_batches): + return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalseNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegatives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTrueNegativesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_negatives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[0.], [1.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositives(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 2., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testFalsePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.false_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [2.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositives(self, distribution): + def 
_metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives(labels, predictions) + + def _expected_fn(num_batches): + return [1., 2., 3., 3.][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testTruePositivesAtThresholds(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.true_positives_at_thresholds(labels, predictions, [.5]) + + def _expected_fn(num_batches): + return [[1.], [2.], [3.], [3.]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecision(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 0.5, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testPrecisionAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.precision_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecall(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall(labels, predictions) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _boolean_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRecallAtThreshold(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.recall_at_thresholds(labels, predictions, [0.5]) + + def _expected_fn(num_batches): + return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 1./32, 0.208333, 0.15625][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testRootMeanSquaredError(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.root_mean_squared_error(labels, predictions) + + def _expected_fn(num_batches): + return [0., 0.176777, 0.456435, 0.395285][num_batches - 1] + + self._test_metric( + distribution, _regression_dataset_fn, _metric_fn, _expected_fn) + + @combinations.generate(all_combinations()) + def testSensitivityAtSpecificity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.sensitivity_at_specificity(labels, predictions, 0.8) + + def _expected_fn(num_batches): + return [0.5, 2./3, 0.6, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, 
_expected_fn) + + @combinations.generate(all_combinations()) + def testSpecificityAtSensitivity(self, distribution): + def _metric_fn(x): + labels = x["labels"] + predictions = x["predictions"] + return metrics.specificity_at_sensitivity(labels, predictions, 0.95) + + def _expected_fn(num_batches): + return [0., 1./3, 0.5, 0.5][num_batches - 1] + + self._test_metric( + distribution, _threshold_dataset_fn, _metric_fn, _expected_fn) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f714d1f..cb72248 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2461,6 +2461,7 @@ py_library( ":check_ops", ":confusion_matrix", ":control_flow_ops", + ":distribute", ":framework", ":framework_for_generated_wrappers", ":math_ops", diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 97cd22e..bf382a2 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1328,11 +1328,11 @@ class TensorFlowTestCase(googletest.TestCase): b, rtol=rtol, atol=atol, - msg="Mismatched value: a%s is different from b%s." % (path_str, - path_str)) + msg=("Mismatched value: a%s is different from b%s. %s" % + (path_str, path_str, msg))) except TypeError as e: - msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a), - path_str, type(b)) + msg = ("Error: a%s has %s, but b%s has %s. %s" % + (path_str, type(a), path_str, type(b), msg)) e.args = ((e.args[0] + " : " + msg,) + e.args[1:]) raise diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 47eea6e..244e28d 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -34,21 +34,54 @@ from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import distribute as distribute_lib from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export def metric_variable(shape, dtype, validate_shape=True, name=None): - """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections.""" - - return variable_scope.variable( - lambda: array_ops.zeros(shape, dtype), - trainable=False, - collections=[ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ], - validate_shape=validate_shape, - name=name) + """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections. + + If running in a `DistributionStrategy` context, the variable will be + "tower local". This means: + + * The returned object will be a container with separate variables + per replica/tower of the model. + + * When writing to the variable, e.g. using `assign_add` in a metric + update, the update will be applied to the variable local to the + replica/tower. + + * To get a metric's result value, we need to sum the variable values + across the replicas/towers before computing the final answer. + Furthermore, the final answer should be computed once instead of + in every replica/tower. Both of these are accomplished by + running the computation of the final result value inside + `tf.contrib.distribute.get_tower_context().merge_call(fn)`. + Inside the `merge_call()`, ops are only added to the graph once + and access to a tower-local variable in a computation returns + the sum across all replicas/towers. 
+ + Args: + shape: Shape of the created variable. + dtype: Type of the created variable. + validate_shape: (Optional) Whether shape validation is enabled for + the created variable. + name: (Optional) String name of the created variable. + + Returns: + A (non-trainable) variable initialized to zero, or if inside a + `DistributionStrategy` scope a tower-local variable container. + """ + with distribute_lib.get_tower_context().tower_local_var_scope('sum'): + # Note that "tower local" implies trainable=False. + return variable_scope.variable( + lambda: array_ops.zeros(shape, dtype), + collections=[ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ], + validate_shape=validate_shape, + name=name) def _remove_squeezable_dimensions(predictions, labels, weights): @@ -333,11 +366,15 @@ def mean(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - mean_t = _safe_div(total, count, 'value') - update_op = _safe_div(update_total_op, update_count_op, 'update_op') + def aggregate_across_towers(_, t, c): + mean_t = _safe_div(t, c, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_t) + return mean_t - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) + mean_t = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, total, count) + update_op = _safe_div(update_total_op, update_count_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -572,6 +609,17 @@ def _confusion_matrix_at_thresholds(labels, return values, update_ops +def _aggregate_variable(v, collections): + + def f(distribution, value): + value = distribution.fetch(value) + if collections: + ops.add_to_collections(collections, value) + return value + + return distribute_lib.get_tower_context().merge_call(f, v) + + @tf_export('metrics.auc') def auc(labels, predictions, @@ -757,14 +805,18 @@ def auc(labels, raise ValueError('Invalid summation_method: %s' % summation_method) # sum up the areas of all the trapeziums - auc_value = compute_auc(values['tp'], values['fn'], values['tn'], - values['fp'], 'value') + def aggregate_auc(_, values): + auc_value = compute_auc(values['tp'], values['fn'], values['tn'], + values['fp'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, auc_value) + return auc_value + + auc_value = distribute_lib.get_tower_context().merge_call( + aggregate_auc, values) update_op = compute_auc(update_ops['tp'], update_ops['fn'], update_ops['tn'], update_ops['fp'], 'update_op') - if metrics_collections: - ops.add_to_collections(metrics_collections, auc_value) - if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -992,15 +1044,18 @@ def mean_per_class_accuracy(labels, update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) - per_class_accuracy = _safe_div(count, total, None) + def aggregate_mean_accuracy(_, count, total): + per_class_accuracy = _safe_div(count, total, None) + mean_accuracy_v = math_ops.reduce_mean( + per_class_accuracy, name='mean_accuracy') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_accuracy_v) + return mean_accuracy_v - mean_accuracy_v = math_ops.reduce_mean( - per_class_accuracy, name='mean_accuracy') - update_op = _safe_div(update_count_op, update_total_op, name='update_op') - - if metrics_collections: - ops.add_to_collections(metrics_collections, 
mean_accuracy_v) + mean_accuracy_v = distribute_lib.get_tower_context().merge_call( + aggregate_mean_accuracy, count, total) + update_op = _safe_div(update_count_op, update_total_op, name='update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1071,7 +1126,7 @@ def mean_iou(labels, total_cm, update_op = _streaming_confusion_matrix(labels, predictions, num_classes, weights) - def compute_mean_iou(name): + def compute_mean_iou(total_cm, name): """Compute the mean intersection-over-union via the confusion matrix.""" sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0)) sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1)) @@ -1098,10 +1153,14 @@ def mean_iou(labels, math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0) return result - mean_iou_v = compute_mean_iou('mean_iou') + def mean_iou_across_towers(_, v): + mean_iou_v = compute_mean_iou(v, 'mean_iou') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_iou_v) + return mean_iou_v - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_iou_v) + mean_iou_v = distribute_lib.get_tower_context().merge_call( + mean_iou_across_towers, total_cm) if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1310,12 +1369,16 @@ def mean_tensor(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - mean_t = _safe_div(total, count, 'value') - update_op = _safe_div(update_total_op, update_count_op, 'update_op') + def aggregate_across_towers(_, t, c): + mean_t = _safe_div(t, c, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_t) + return mean_t - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_t) + mean_t = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, total, count) + update_op = _safe_div(update_total_op, update_count_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1413,12 +1476,9 @@ def _count_condition(values, weights = math_ops.to_float(weights) values = math_ops.multiply(values, weights) - value_tensor = array_ops.identity(count) - update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) - - if metrics_collections: - ops.add_to_collections(metrics_collections, value_tensor) + value_tensor = _aggregate_variable(count, metrics_collections) + update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -1525,13 +1585,12 @@ def false_negatives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('fn',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['fn']) + fn_value = _aggregate_variable(values['fn'], metrics_collections) if updates_collections: ops.add_to_collections(updates_collections, update_ops['fn']) - return values['fn'], update_ops['fn'] + return fn_value, update_ops['fn'] @tf_export('metrics.false_positives') @@ -1635,13 +1694,12 @@ def false_positives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('fp',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['fp']) + fp_value = _aggregate_variable(values['fp'], metrics_collections) if updates_collections: 
ops.add_to_collections(updates_collections, update_ops['fp']) - return values['fp'], update_ops['fp'] + return fp_value, update_ops['fp'] @tf_export('metrics.true_negatives') @@ -1745,13 +1803,12 @@ def true_negatives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('tn',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['tn']) + tn_value = _aggregate_variable(values['tn'], metrics_collections) if updates_collections: ops.add_to_collections(updates_collections, update_ops['tn']) - return values['tn'], update_ops['tn'] + return tn_value, update_ops['tn'] @tf_export('metrics.true_positives') @@ -1855,13 +1912,12 @@ def true_positives_at_thresholds(labels, values, update_ops = _confusion_matrix_at_thresholds( labels, predictions, thresholds, weights=weights, includes=('tp',)) - if metrics_collections: - ops.add_to_collections(metrics_collections, values['tp']) + tp_value = _aggregate_variable(values['tp'], metrics_collections) if updates_collections: ops.add_to_collections(updates_collections, update_ops['tp']) - return values['tp'], update_ops['tp'] + return tp_value, update_ops['tp'] @tf_export('metrics.precision') @@ -1945,13 +2001,17 @@ def precision(labels, return array_ops.where( math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name) - p = compute_precision(true_p, false_p, 'value') - update_op = compute_precision(true_positives_update_op, - false_positives_update_op, 'update_op') + def once_across_towers(_, true_p, false_p): + p = compute_precision(true_p, false_p, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, p) + return p - if metrics_collections: - ops.add_to_collections(metrics_collections, p) + p = distribute_lib.get_tower_context().merge_call( + once_across_towers, true_p, false_p) + update_op = compute_precision(true_positives_update_op, + false_positives_update_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2025,13 +2085,17 @@ def precision_at_thresholds(labels, def compute_precision(tp, fp, name): return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name) - prec = compute_precision(values['tp'], values['fp'], 'value') - update_op = compute_precision(update_ops['tp'], update_ops['fp'], - 'update_op') + def precision_across_towers(_, values): + prec = compute_precision(values['tp'], values['fp'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, prec) + return prec - if metrics_collections: - ops.add_to_collections(metrics_collections, prec) + prec = distribute_lib.get_tower_context().merge_call( + precision_across_towers, values) + update_op = compute_precision(update_ops['tp'], update_ops['fp'], + 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2050,7 +2114,7 @@ def recall(labels, The `recall` function creates two local variables, `true_positives` and `false_negatives`, that are used to compute the recall. This value is ultimately returned as `recall`, an idempotent operation that simply divides - `true_positives` by the sum of `true_positives` and `false_negatives`. + `true_positives` by the sum of `true_positives` and `false_negatives`. For estimation of the metric over a stream of data, the function creates an `update_op` that updates these variables and returns the `recall`. 
`update_op` @@ -2117,13 +2181,17 @@ def recall(labels, math_ops.greater(true_p + false_n, 0), math_ops.div(true_p, true_p + false_n), 0, name) - rec = compute_recall(true_p, false_n, 'value') - update_op = compute_recall(true_positives_update_op, - false_negatives_update_op, 'update_op') + def once_across_towers(_, true_p, false_n): + rec = compute_recall(true_p, false_n, 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, rec) + return rec - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) + rec = distribute_lib.get_tower_context().merge_call( + once_across_towers, true_p, false_n) + update_op = compute_recall(true_positives_update_op, + false_negatives_update_op, 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2552,11 +2620,17 @@ def recall_at_top_k(labels, class_id=class_id, weights=weights) - metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) + def aggregate_across_towers(_, tp, fn): + metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) + if metrics_collections: + ops.add_to_collections(metrics_collections, metric) + return metric + + metric = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, tp, fn) + update = math_ops.div( tp_update, math_ops.add(tp_update, fn_update), name='update') - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) if updates_collections: ops.add_to_collections(updates_collections, update) return metric, update @@ -2627,12 +2701,16 @@ def recall_at_thresholds(labels, def compute_recall(tp, fn, name): return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) - rec = compute_recall(values['tp'], values['fn'], 'value') - update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') + def recall_across_towers(_, values): + rec = compute_recall(values['tp'], values['fn'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, rec) + return rec - if metrics_collections: - ops.add_to_collections(metrics_collections, rec) + rec = distribute_lib.get_tower_context().merge_call( + recall_across_towers, values) + update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -2698,13 +2776,16 @@ def root_mean_squared_error(labels, mse, update_mse_op = mean_squared_error(labels, predictions, weights, None, None, name or 'root_mean_squared_error') + def once_across_towers(_, mse): + rmse = math_ops.sqrt(mse) + if metrics_collections: + ops.add_to_collections(metrics_collections, rmse) + return rmse - rmse = math_ops.sqrt(mse) - update_rmse_op = math_ops.sqrt(update_mse_op) - - if metrics_collections: - ops.add_to_collections(metrics_collections, rmse) + rmse = distribute_lib.get_tower_context().merge_call( + once_across_towers, mse) + update_rmse_op = math_ops.sqrt(update_mse_op) if updates_collections: ops.add_to_collections(updates_collections, update_rmse_op) @@ -2797,15 +2878,19 @@ def sensitivity_at_specificity(labels, return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon, name) - sensitivity = compute_sensitivity_at_specificity( - values['tp'], values['tn'], values['fp'], values['fn'], 'value') + def aggregate_across_towers(_, values): + sensitivity = compute_sensitivity_at_specificity( + values['tp'], values['tn'], values['fp'], values['fn'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, sensitivity) + return 
sensitivity + + sensitivity = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, values) + update_op = compute_sensitivity_at_specificity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 'update_op') - - if metrics_collections: - ops.add_to_collections(metrics_collections, sensitivity) - if updates_collections: ops.add_to_collections(updates_collections, update_op) @@ -3070,11 +3155,16 @@ def _streaming_sparse_average_precision_at_top_k(labels, total_update = state_ops.assign_add(total_var, batch_total, name='update') # Divide total by max to get mean, for both vars and the update ops. - mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') - update = _safe_scalar_div(total_update, max_update, name=scope) + def aggregate_across_towers(_, total_var, max_var): + mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') + if metrics_collections: + ops.add_to_collections(metrics_collections, mean_average_precision) + return mean_average_precision - if metrics_collections: - ops.add_to_collections(metrics_collections, mean_average_precision) + mean_average_precision = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, total_var, max_var) + + update = _safe_scalar_div(total_update, max_update, name=scope) if updates_collections: ops.add_to_collections(updates_collections, update) @@ -3351,11 +3441,17 @@ def precision_at_top_k(labels, class_id=class_id, weights=weights) - metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) + def aggregate_across_towers(_, tp, fp): + metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) + if metrics_collections: + ops.add_to_collections(metrics_collections, metric) + return metric + + metric = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, tp, fp) + update = math_ops.div( tp_update, math_ops.add(tp_update, fp_update), name='update') - if metrics_collections: - ops.add_to_collections(metrics_collections, metric) if updates_collections: ops.add_to_collections(updates_collections, update) return metric, update @@ -3583,15 +3679,19 @@ def specificity_at_sensitivity(labels, return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon, name) - specificity = compute_specificity_at_sensitivity( - values['tp'], values['tn'], values['fp'], values['fn'], 'value') + def aggregate_across_towers(_, values): + specificity = compute_specificity_at_sensitivity( + values['tp'], values['tn'], values['fp'], values['fn'], 'value') + if metrics_collections: + ops.add_to_collections(metrics_collections, specificity) + return specificity + + specificity = distribute_lib.get_tower_context().merge_call( + aggregate_across_towers, values) + update_op = compute_specificity_at_sensitivity( update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 'update_op') - - if metrics_collections: - ops.add_to_collections(metrics_collections, specificity) - if updates_collections: ops.add_to_collections(updates_collections, update_op) -- 2.7.4