From 4321469f1db7a6ff220c2415c63f433df6e7161d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 20 Mar 2018 17:00:33 -0700 Subject: [PATCH] Fixing bug in MultitaskOptimizerWrapper where types of tensors were mismatching. PiperOrigin-RevId: 189837743 --- .../contrib/opt/python/training/multitask_optimizer_wrapper.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py index cb6c77a..9076cc9 100644 --- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py +++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper.py @@ -22,6 +22,7 @@ import types import six from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops @@ -40,8 +41,10 @@ def _get_wrapper(fn, opt): def wrapper(self, grad, *args, **kwargs): # pylint: disable=unused-argument all_zeros = _is_all_zeros(grad) - return control_flow_ops.cond(all_zeros, control_flow_ops.no_op, - lambda: fn(grad, *args, **kwargs)) + def call_fn(): + with ops.control_dependencies([fn(grad, *args, **kwargs)]): + return control_flow_ops.no_op() + return control_flow_ops.cond(all_zeros, control_flow_ops.no_op, call_fn) wrapper = types.MethodType(wrapper, opt) return wrapper -- 2.7.4