self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
- def testWeighted1d(self):
+ def testBinaryWeighted1d(self):
with self.test_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
sess.run(update_op)
self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
+ def testWeighted1d(self):
+ with self.test_session() as sess:
+ # Create the queue that populates the values.
+ values_queue = data_flow_ops.FIFOQueue(
+ 4, dtypes=dtypes_lib.float32, shapes=(1, 2))
+ _enqueue_vector(sess, values_queue, [0, 1])
+ _enqueue_vector(sess, values_queue, [-4.2, 9.1])
+ _enqueue_vector(sess, values_queue, [6.5, 0])
+ _enqueue_vector(sess, values_queue, [-3.2, 4.0])
+ values = values_queue.dequeue()
+
+ # Create the queue that populates the weights.
+ weights_queue = data_flow_ops.FIFOQueue(
+ 4, dtypes=dtypes_lib.float32, shapes=(1, 1))
+ _enqueue_vector(sess, weights_queue, [[0.0025]])
+ _enqueue_vector(sess, weights_queue, [[0.005]])
+ _enqueue_vector(sess, weights_queue, [[0.01]])
+ _enqueue_vector(sess, weights_queue, [[0.0075]])
+ weights = weights_queue.dequeue()
+
+ mean, update_op = metrics.mean_tensor(values, weights)
+
+ sess.run(variables.local_variables_initializer())
+ for _ in range(4):
+ sess.run(update_op)
+ self.assertAllClose([[0.8, 3.52]], sess.run(mean), 5)
+
def testWeighted2d_1(self):
with self.test_session() as sess:
# Create the queue that populates the values.
with ops.control_dependencies([values]):
update_count_op = state_ops.assign_add(count, num_values)
- def compute_mean(total, count, name):
- non_zero_count = math_ops.maximum(
- count, array_ops.ones_like(count), name=name)
- return math_ops.truediv(total, non_zero_count, name=name)
-
- mean_t = compute_mean(total, count, 'value')
- update_op = compute_mean(update_total_op, update_count_op, 'update_op')
+ mean_t = _safe_div(total, count, 'value')
+ update_op = _safe_div(update_total_op, update_count_op, 'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, mean_t)