From 307cfe7ab7e2c475b2741fc2a2f7663b46223e6d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 26 Mar 2018 16:19:50 -0700
Subject: [PATCH] Save the last loss reduction method (for future use).

PiperOrigin-RevId: 190543066
---
 tensorflow/python/framework/ops.py          | 3 +++
 tensorflow/python/ops/losses/losses_impl.py | 5 +++++
 2 files changed, 8 insertions(+)

diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index e579289..25a951a 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2788,6 +2788,9 @@ class Graph(object):
     # being called inside function definitions behave as if they were seeing the
     # actual outside graph).
     self._graph_key = "grap-key-%d/" % (uid(),)
+    # A string with the last reduction method passed to
+    # losses.compute_weighted_loss(), or None.
+    self._last_loss_reduction = None
     self._container = ""
     self._registered_ops = op_def_registry.get_registered_ops()
 
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 0840760..34ca1ad 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -194,6 +194,11 @@ def compute_weighted_loss(
   """
   Reduction.validate(reduction)
   with ops.name_scope(scope, "weighted_loss", (losses, weights)):
+    # Save the `reduction` argument for loss normalization when distributing
+    # to multiple towers.
+    # TODO(josh11b): Associate it with the returned op for more precision.
+    ops.get_default_graph()._last_loss_reduction = reduction  # pylint: disable=protected-access
+
     with ops.control_dependencies((
         weights_broadcast_ops.assert_broadcastable(weights, losses),)):
       losses = ops.convert_to_tensor(losses)
-- 
2.7.4
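
A minimal sketch, assuming the TF 1.x graph-mode API at the time of this patch: build a weighted loss with tf.losses.compute_weighted_loss(), then read the private _last_loss_reduction attribute that this change adds to the default Graph. Reading the attribute from user code is purely illustrative (the TODO above notes it is not yet associated with the returned op); in practice only TensorFlow's own tower/distribution code is expected to consult it.

    import tensorflow as tf

    # Build a toy weighted loss; compute_weighted_loss() records the
    # `reduction` argument it was called with on the default graph.
    example_losses = tf.constant([0.5, 1.5, 2.0])
    loss = tf.losses.compute_weighted_loss(
        example_losses,
        weights=1.0,
        reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

    # Protected attribute added by this patch; a multi-tower loss-normalization
    # step could inspect it to decide how to combine per-tower losses.
    print(tf.get_default_graph()._last_loss_reduction)  # pylint: disable=protected-access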