Fix a bug in tf.metrics.mean_tensor for case that the weights are very small.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 21 Feb 2018 20:42:51 +0000 (12:42 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 20:51:15 +0000 (12:51 -0800)
We have renamed metrics_test.MeanTensorTest.testWeighted1d as metrics_test.MeanTensorTest.testBinaryWeighted1d, since the weights on the instances are zeros and ones.

We have added a new metrics_test.MeanTensorTest.testWeighted1d that has small weights. It was failing for the previous implementation, but passes now.

Now the code for mean_tensor() and mean() now use the same _safe_div method. Previously, mean_tensor() used a different means to ensure that we don't divide by zero. This set the denominator to max(1., sum(weights)), which was inaccurate when sum(weights) is non-zero, but less than one.

PiperOrigin-RevId: 186503714

tensorflow/python/kernel_tests/metrics_test.py
tensorflow/python/ops/metrics_impl.py

index fd78c026c273da1ffecf9e1dfe8c9e6042a4be69..59e7afa2dcb1e02ed9c66e5cf75753f96552b4e0 100644 (file)
@@ -417,7 +417,7 @@ class MeanTensorTest(test.TestCase):
 
       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
 
-  def testWeighted1d(self):
+  def testBinaryWeighted1d(self):
     with self.test_session() as sess:
       # Create the queue that populates the values.
       values_queue = data_flow_ops.FIFOQueue(
@@ -444,6 +444,33 @@ class MeanTensorTest(test.TestCase):
         sess.run(update_op)
       self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
 
+  def testWeighted1d(self):
+    with self.test_session() as sess:
+      # Create the queue that populates the values.
+      values_queue = data_flow_ops.FIFOQueue(
+          4, dtypes=dtypes_lib.float32, shapes=(1, 2))
+      _enqueue_vector(sess, values_queue, [0, 1])
+      _enqueue_vector(sess, values_queue, [-4.2, 9.1])
+      _enqueue_vector(sess, values_queue, [6.5, 0])
+      _enqueue_vector(sess, values_queue, [-3.2, 4.0])
+      values = values_queue.dequeue()
+
+      # Create the queue that populates the weights.
+      weights_queue = data_flow_ops.FIFOQueue(
+          4, dtypes=dtypes_lib.float32, shapes=(1, 1))
+      _enqueue_vector(sess, weights_queue, [[0.0025]])
+      _enqueue_vector(sess, weights_queue, [[0.005]])
+      _enqueue_vector(sess, weights_queue, [[0.01]])
+      _enqueue_vector(sess, weights_queue, [[0.0075]])
+      weights = weights_queue.dequeue()
+
+      mean, update_op = metrics.mean_tensor(values, weights)
+
+      sess.run(variables.local_variables_initializer())
+      for _ in range(4):
+        sess.run(update_op)
+      self.assertAllClose([[0.8, 3.52]], sess.run(mean), 5)
+
   def testWeighted2d_1(self):
     with self.test_session() as sess:
       # Create the queue that populates the values.
index 44c2f304cf9245539e42da2ce54260990de980e0..043c0e30cd8476b1a91e136df60edfbedf85ab24 100644 (file)
@@ -1247,13 +1247,8 @@ def mean_tensor(values,
     with ops.control_dependencies([values]):
       update_count_op = state_ops.assign_add(count, num_values)
 
-    def compute_mean(total, count, name):
-      non_zero_count = math_ops.maximum(
-          count, array_ops.ones_like(count), name=name)
-      return math_ops.truediv(total, non_zero_count, name=name)
-
-    mean_t = compute_mean(total, count, 'value')
-    update_op = compute_mean(update_total_op, update_count_op, 'update_op')
+    mean_t = _safe_div(total, count, 'value')
+    update_op = _safe_div(update_total_op, update_count_op, 'update_op')
 
     if metrics_collections:
       ops.add_to_collections(metrics_collections, mean_t)