--- /dev/null
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for V1 metrics."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import test
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics
+from tensorflow.python.ops import variables
+
+
+def _labeled_dataset_fn():
+ # First four batches of x: labels, predictions -> (labels == predictions)
+ # 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False
+ # 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False
+ # 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
+ # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
+ return dataset_ops.Dataset.range(1000).map(
+ lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
+
+
+def _boolean_dataset_fn():
+ # First four batches of labels, predictions: {TP, FP, TN, FN}
+ # with a threshold of 0.5:
+ # T, T -> TP; F, T -> FP; T, F -> FN
+ # F, F -> TN; T, T -> TP; F, T -> FP
+ # T, F -> FN; F, F -> TN; T, T -> TP
+ # F, T -> FP; T, F -> FN; F, F -> TN
+ return dataset_ops.Dataset.from_tensor_slices({
+ "labels": [True, False, True, False],
+ "predictions": [True, True, False, False]}).repeat().batch(3)
+
+
+def _threshold_dataset_fn():
+ # First four batches of labels, predictions: {TP, FP, TN, FN}
+ # with a threshold of 0.5:
+ # True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN
+ # False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP
+ # True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP
+ # False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN
+ return dataset_ops.Dataset.from_tensor_slices({
+ "labels": [True, False, True, False],
+ "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
+
+
+def _regression_dataset_fn():
+ return dataset_ops.Dataset.from_tensor_slices({
+ "labels": [1., .5, 1., 0.],
+ "predictions": [1., .75, .25, 0.]}).repeat()
+
+
+def all_combinations():
+ return combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus],
+ mode=["graph"])
+
+
+# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
+# metrics.precision_at_k
+class MetricsV1Test(test.TestCase, parameterized.TestCase):
+
+ def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
+ with ops.Graph().as_default(), distribution.scope():
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
+ value, update = distribution.call_for_each_tower(
+ metric_fn, iterator.get_next())
+ update = distribution.group(update)
+ self.evaluate(variables.local_variables_initializer())
+ # TODO(josh11b): Once we switch to using a global batch size for input,
+ # replace "distribution.num_towers" with "1".
+ batches_per_update = distribution.num_towers
+
+ # Update variables using the first `num_towers` batches.
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
+ 0.001, msg="After first update")
+
+ # Update variables using the second `num_towers` batches.
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(2 * batches_per_update),
+ self.evaluate(value),
+ 0.001,
+ msg="After second update")
+
+ if batches_per_update == 1: # Consume 4 input batches
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(3 * batches_per_update),
+ self.evaluate(value),
+ 0.001,
+ msg="After third update")
+ self.evaluate(update)
+ self.assertAllClose(expected_fn(4 * batches_per_update),
+ self.evaluate(value),
+ 0.001,
+ msg="After fourth update")
+
+ @combinations.generate(all_combinations())
+ def testMean(self, distribution):
+ def _dataset_fn():
+ return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
+
+ def _expected_fn(num_batches):
+ # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
+ return num_batches * 2 - 0.5
+
+ self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testAccuracy(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.accuracy(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [3./4, 3./8, 3./12, 4./16][num_batches - 1]
+
+ self._test_metric(
+ distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanPerClassAccuracy(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.mean_per_class_accuracy(
+ labels, predictions, num_classes=5)
+
+ def _expected_fn(num_batches):
+ mean = lambda x: sum(x) / len(x)
+ return [mean([1., 1., 1., 0., 0.]),
+ mean([0.5, 0.5, 0.5, 0., 0.]),
+ mean([1./3, 1./3, 0.5, 0., 0.]),
+ mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]
+
+ self._test_metric(
+ distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanIOU(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.mean_iou(
+ labels, predictions, num_classes=5)
+
+ def _expected_fn(num_batches):
+ mean = lambda x: sum(x) / len(x)
+ return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch
+ mean([1./4, 1./4, 1./3, 0., 0.]),
+ mean([1./6, 1./6, 1./5, 0., 0.]),
+ mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]
+
+ self._test_metric(
+ distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanTensor(self, distribution):
+ def _dataset_fn():
+ dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
+ # Want to produce a fixed, known shape, so drop remainder when batching.
+ dataset = dataset.apply(batching.batch_and_drop_remainder(4))
+ return dataset
+
+ def _expected_fn(num_batches):
+ # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
+ # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
+ # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
+ # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
+ first = 2. * num_batches - 2.
+ return [first, first + 1., first + 2., first + 3.]
+
+ self._test_metric(
+ distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testAUCROC(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
+ summation_method="careful_interpolation")
+
+ def _expected_fn(num_batches):
+ return [0.5, 7./9, 0.8, 0.75][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testAUCPR(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
+ summation_method="careful_interpolation")
+
+ def _expected_fn(num_batches):
+ return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalseNegatives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_negatives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [1., 1., 2., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalseNegativesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_negatives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[1.], [1.], [2.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTrueNegatives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_negatives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0., 1., 2., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTrueNegativesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_negatives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[0.], [1.], [2.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalsePositives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_positives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [1., 2., 2., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testFalsePositivesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.false_positives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[1.], [2.], [2.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTruePositives(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_positives(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [1., 2., 3., 3.][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testTruePositivesAtThresholds(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.true_positives_at_thresholds(labels, predictions, [.5])
+
+ def _expected_fn(num_batches):
+ return [[1.], [2.], [3.], [3.]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testPrecision(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.precision(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0.5, 0.5, 0.6, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testPrecisionAtThreshold(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.precision_at_thresholds(labels, predictions, [0.5])
+
+ def _expected_fn(num_batches):
+ return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testRecall(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.recall(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testRecallAtThreshold(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.recall_at_thresholds(labels, predictions, [0.5])
+
+ def _expected_fn(num_batches):
+ return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testMeanSquaredError(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.mean_squared_error(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0., 1./32, 0.208333, 0.15625][num_batches - 1]
+
+ self._test_metric(
+ distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testRootMeanSquaredError(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.root_mean_squared_error(labels, predictions)
+
+ def _expected_fn(num_batches):
+ return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]
+
+ self._test_metric(
+ distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testSensitivityAtSpecificity(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.sensitivity_at_specificity(labels, predictions, 0.8)
+
+ def _expected_fn(num_batches):
+ return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+ @combinations.generate(all_combinations())
+ def testSpecificityAtSensitivity(self, distribution):
+ def _metric_fn(x):
+ labels = x["labels"]
+ predictions = x["predictions"]
+ return metrics.specificity_at_sensitivity(labels, predictions, 0.95)
+
+ def _expected_fn(num_batches):
+ return [0., 1./3, 0.5, 0.5][num_batches - 1]
+
+ self._test_metric(
+ distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
+
+
+if __name__ == "__main__":
+ test.main()
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
def metric_variable(shape, dtype, validate_shape=True, name=None):
- """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
-
- return variable_scope.variable(
- lambda: array_ops.zeros(shape, dtype),
- trainable=False,
- collections=[
- ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
- ],
- validate_shape=validate_shape,
- name=name)
+ """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.
+
+ If running in a `DistributionStrategy` context, the variable will be
+ "tower local". This means:
+
+ * The returned object will be a container with separate variables
+ per replica/tower of the model.
+
+ * When writing to the variable, e.g. using `assign_add` in a metric
+ update, the update will be applied to the variable local to the
+ replica/tower.
+
+ * To get a metric's result value, we need to sum the variable values
+ across the replicas/towers before computing the final answer.
+ Furthermore, the final answer should be computed once instead of
+ in every replica/tower. Both of these are accomplished by
+ running the computation of the final result value inside
+ `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
+ Inside the `merge_call()`, ops are only added to the graph once
+ and access to a tower-local variable in a computation returns
+ the sum across all replicas/towers.
+
+ Args:
+ shape: Shape of the created variable.
+ dtype: Type of the created variable.
+ validate_shape: (Optional) Whether shape validation is enabled for
+ the created variable.
+ name: (Optional) String name of the created variable.
+
+ Returns:
+ A (non-trainable) variable initialized to zero, or if inside a
+ `DistributionStrategy` scope a tower-local variable container.
+ """
+ with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
+ # Note that "tower local" implies trainable=False.
+ return variable_scope.variable(
+ lambda: array_ops.zeros(shape, dtype),
+ collections=[
+ ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+ ],
+ validate_shape=validate_shape,
+ name=name)
def _remove_squeezable_dimensions(predictions, labels, weights):
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- mean_t = _safe_div(total, count, 'value')
- update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+ def aggregate_across_towers(_, t, c):
+ mean_t = _safe_div(t, c, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_t)
+ return mean_t
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
+ mean_t = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, total, count)
+ update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
return values, update_ops
+def _aggregate_variable(v, collections):
+
+ def f(distribution, value):
+ value = distribution.fetch(value)
+ if collections:
+ ops.add_to_collections(collections, value)
+ return value
+
+ return distribute_lib.get_tower_context().merge_call(f, v)
+
+
@tf_export('metrics.auc')
def auc(labels,
predictions,
raise ValueError('Invalid summation_method: %s' % summation_method)
# sum up the areas of all the trapeziums
- auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
- values['fp'], 'value')
+ def aggregate_auc(_, values):
+ auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
+ values['fp'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, auc_value)
+ return auc_value
+
+ auc_value = distribute_lib.get_tower_context().merge_call(
+ aggregate_auc, values)
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
update_ops['tn'], update_ops['fp'], 'update_op')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, auc_value)
-
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
update_total_op = state_ops.scatter_add(total, labels, ones)
update_count_op = state_ops.scatter_add(count, labels, is_correct)
- per_class_accuracy = _safe_div(count, total, None)
+ def aggregate_mean_accuracy(_, count, total):
+ per_class_accuracy = _safe_div(count, total, None)
+ mean_accuracy_v = math_ops.reduce_mean(
+ per_class_accuracy, name='mean_accuracy')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_accuracy_v)
+ return mean_accuracy_v
- mean_accuracy_v = math_ops.reduce_mean(
- per_class_accuracy, name='mean_accuracy')
- update_op = _safe_div(update_count_op, update_total_op, name='update_op')
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_accuracy_v)
+ mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
+ aggregate_mean_accuracy, count, total)
+ update_op = _safe_div(update_count_op, update_total_op, name='update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
num_classes, weights)
- def compute_mean_iou(name):
+ def compute_mean_iou(total_cm, name):
"""Compute the mean intersection-over-union via the confusion matrix."""
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
return result
- mean_iou_v = compute_mean_iou('mean_iou')
+ def mean_iou_across_towers(_, v):
+ mean_iou_v = compute_mean_iou(v, 'mean_iou')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_iou_v)
+ return mean_iou_v
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_iou_v)
+ mean_iou_v = distribute_lib.get_tower_context().merge_call(
+ mean_iou_across_towers, total_cm)
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- mean_t = _safe_div(total, count, 'value')
- update_op = _safe_div(update_total_op, update_count_op, 'update_op')
+ def aggregate_across_towers(_, t, c):
+ mean_t = _safe_div(t, c, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_t)
+ return mean_t
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_t)
+ mean_t = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, total, count)
+ update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
weights = math_ops.to_float(weights)
values = math_ops.multiply(values, weights)
- value_tensor = array_ops.identity(count)
- update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, value_tensor)
+ value_tensor = _aggregate_variable(count, metrics_collections)
+ update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('fn',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['fn'])
+ fn_value = _aggregate_variable(values['fn'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['fn'])
- return values['fn'], update_ops['fn']
+ return fn_value, update_ops['fn']
@tf_export('metrics.false_positives')
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('fp',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['fp'])
+ fp_value = _aggregate_variable(values['fp'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['fp'])
- return values['fp'], update_ops['fp']
+ return fp_value, update_ops['fp']
@tf_export('metrics.true_negatives')
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('tn',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['tn'])
+ tn_value = _aggregate_variable(values['tn'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['tn'])
- return values['tn'], update_ops['tn']
+ return tn_value, update_ops['tn']
@tf_export('metrics.true_positives')
values, update_ops = _confusion_matrix_at_thresholds(
labels, predictions, thresholds, weights=weights, includes=('tp',))
- if metrics_collections:
- ops.add_to_collections(metrics_collections, values['tp'])
+ tp_value = _aggregate_variable(values['tp'], metrics_collections)
if updates_collections:
ops.add_to_collections(updates_collections, update_ops['tp'])
- return values['tp'], update_ops['tp']
+ return tp_value, update_ops['tp']
@tf_export('metrics.precision')
return array_ops.where(
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
- p = compute_precision(true_p, false_p, 'value')
- update_op = compute_precision(true_positives_update_op,
- false_positives_update_op, 'update_op')
+ def once_across_towers(_, true_p, false_p):
+ p = compute_precision(true_p, false_p, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, p)
+ return p
- if metrics_collections:
- ops.add_to_collections(metrics_collections, p)
+ p = distribute_lib.get_tower_context().merge_call(
+ once_across_towers, true_p, false_p)
+ update_op = compute_precision(true_positives_update_op,
+ false_positives_update_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
def compute_precision(tp, fp, name):
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
- prec = compute_precision(values['tp'], values['fp'], 'value')
- update_op = compute_precision(update_ops['tp'], update_ops['fp'],
- 'update_op')
+ def precision_across_towers(_, values):
+ prec = compute_precision(values['tp'], values['fp'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, prec)
+ return prec
- if metrics_collections:
- ops.add_to_collections(metrics_collections, prec)
+ prec = distribute_lib.get_tower_context().merge_call(
+ precision_across_towers, values)
+ update_op = compute_precision(update_ops['tp'], update_ops['fp'],
+ 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
The `recall` function creates two local variables, `true_positives`
and `false_negatives`, that are used to compute the recall. This value is
ultimately returned as `recall`, an idempotent operation that simply divides
- `true_positives` by the sum of `true_positives` and `false_negatives`.
+ `true_positives` by the sum of `true_positives` and `false_negatives`.
For estimation of the metric over a stream of data, the function creates an
`update_op` that updates these variables and returns the `recall`. `update_op`
math_ops.greater(true_p + false_n, 0),
math_ops.div(true_p, true_p + false_n), 0, name)
- rec = compute_recall(true_p, false_n, 'value')
- update_op = compute_recall(true_positives_update_op,
- false_negatives_update_op, 'update_op')
+ def once_across_towers(_, true_p, false_n):
+ rec = compute_recall(true_p, false_n, 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rec)
+ return rec
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
+ rec = distribute_lib.get_tower_context().merge_call(
+ once_across_towers, true_p, false_n)
+ update_op = compute_recall(true_positives_update_op,
+ false_negatives_update_op, 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
class_id=class_id,
weights=weights)
- metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+ def aggregate_across_towers(_, tp, fn):
+ metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
+ return metric
+
+ metric = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, tp, fn)
+
update = math_ops.div(
tp_update, math_ops.add(tp_update, fn_update), name='update')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return metric, update
def compute_recall(tp, fn, name):
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
- rec = compute_recall(values['tp'], values['fn'], 'value')
- update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
+ def recall_across_towers(_, values):
+ rec = compute_recall(values['tp'], values['fn'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rec)
+ return rec
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rec)
+ rec = distribute_lib.get_tower_context().merge_call(
+ recall_across_towers, values)
+ update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
None, name or
'root_mean_squared_error')
+ def once_across_towers(_, mse):
+ rmse = math_ops.sqrt(mse)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, rmse)
+ return rmse
- rmse = math_ops.sqrt(mse)
- update_rmse_op = math_ops.sqrt(update_mse_op)
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, rmse)
+ rmse = distribute_lib.get_tower_context().merge_call(
+ once_across_towers, mse)
+ update_rmse_op = math_ops.sqrt(update_mse_op)
if updates_collections:
ops.add_to_collections(updates_collections, update_rmse_op)
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
name)
- sensitivity = compute_sensitivity_at_specificity(
- values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ def aggregate_across_towers(_, values):
+ sensitivity = compute_sensitivity_at_specificity(
+ values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, sensitivity)
+ return sensitivity
+
+ sensitivity = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, values)
+
update_op = compute_sensitivity_at_specificity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
'update_op')
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, sensitivity)
-
if updates_collections:
ops.add_to_collections(updates_collections, update_op)
total_update = state_ops.assign_add(total_var, batch_total, name='update')
# Divide total by max to get mean, for both vars and the update ops.
- mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
- update = _safe_scalar_div(total_update, max_update, name=scope)
+ def aggregate_across_towers(_, total_var, max_var):
+ mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, mean_average_precision)
+ return mean_average_precision
- if metrics_collections:
- ops.add_to_collections(metrics_collections, mean_average_precision)
+ mean_average_precision = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, total_var, max_var)
+
+ update = _safe_scalar_div(total_update, max_update, name=scope)
if updates_collections:
ops.add_to_collections(updates_collections, update)
class_id=class_id,
weights=weights)
- metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+ def aggregate_across_towers(_, tp, fp):
+ metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, metric)
+ return metric
+
+ metric = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, tp, fp)
+
update = math_ops.div(
tp_update, math_ops.add(tp_update, fp_update), name='update')
- if metrics_collections:
- ops.add_to_collections(metrics_collections, metric)
if updates_collections:
ops.add_to_collections(updates_collections, update)
return metric, update
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
name)
- specificity = compute_specificity_at_sensitivity(
- values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ def aggregate_across_towers(_, values):
+ specificity = compute_specificity_at_sensitivity(
+ values['tp'], values['tn'], values['fp'], values['fn'], 'value')
+ if metrics_collections:
+ ops.add_to_collections(metrics_collections, specificity)
+ return specificity
+
+ specificity = distribute_lib.get_tower_context().merge_call(
+ aggregate_across_towers, values)
+
update_op = compute_specificity_at_sensitivity(
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
'update_op')
-
- if metrics_collections:
- ops.add_to_collections(metrics_collections, specificity)
-
if updates_collections:
ops.add_to_collections(updates_collections, update_op)