      r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+  def testGradInWhileWrtInitialLoopVal(self):
+    with self.test_session():
+      x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
+      y = x + 1
+
+      def body(i, v):
+        z = v * 2
+        return i + 1, gradients_impl.gradients(z, x)[0]
+
+      with self.assertRaisesRegexp(
+          ValueError,
+          "Cannot compute gradient inside while loop with respect to op 'x'. "
+          "We do not support taking the gradient wrt or through the initial "
+          "value of a loop variable. Gradients can be computed through "
+          "loop invariants or wrt the input parameters to the loop body."):
+        control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+
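  # For contrast, a minimal sketch (not part of this change; hypothetical test
  # name, same test-file imports assumed) of the single-iteration case that
  # remains supported: the gradient is taken wrt the body parameter 'v' rather
  # than wrt or through the initial loop value 'y'.
  def testGradInWhileWrtBodyParamSketch(self):
    with self.test_session() as sess:
      y = constant_op.constant(1.0)

      def body(i, v):
        z = v * 2
        # Supported: gradient wrt this iteration's loop parameter 'v'.
        return i + 1, gradients_impl.gradients(z, v)[0]

      _, r = control_flow_ops.while_loop(lambda i, _: i < 3, body, [0, y])
      # d(z)/d(v) == 2.0 on every iteration, so the final loop value is 2.0.
      self.assertAllClose(2.0, sess.run(r))
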
  def testWhileGradInWhile(self):
    with self.test_session():
      n = ops.convert_to_tensor(1.0, name="n")
return grad_fn()
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
+ """Raises an error if we backprop through a loop var."""
+ # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
+ # message.
+  target_op = None
+  queue = collections.deque([op])
+  visited = set()
+  while queue:
+    curr_op = queue.popleft()
+    if curr_op in visited: continue
+    visited.add(curr_op)
+    if curr_op in from_ops:
+      target_op = curr_op
+      break
+    queue.extend(t.op for t in curr_op.inputs)
+  assert target_op
+  raise ValueError(
+      "Cannot compute gradient inside while loop with respect to op '%s'. "
+      "We do not support taking the gradient wrt or through the initial value "
+      "of a loop variable. Gradients can be computed through loop invariants "
+      "or wrt the input parameters to the loop body."
+      % target_op.name)
+
+
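# The helper above is a breadth-first search backward through 'op.inputs' for
# the nearest member of 'from_ops'. As a standalone illustration (hypothetical
# names, no TensorFlow dependency), the same nearest-target search over a DAG
# looks like this:
import collections  # repeated here so the sketch is self-contained


def _find_nearest_sketch(start, targets, parents_of):
  """Returns the member of `targets` closest to `start`, or None."""
  queue = collections.deque([start])
  visited = set()
  while queue:
    node = queue.popleft()
    if node in visited:
      continue
    visited.add(node)
    if node in targets:
      return node
    queue.extend(parents_of(node))
  return None


# Example: 'd' reads from 'c' and 'b', and 'c' reads from 'a', so the nearest
# member of {'a', 'b'} reachable from 'd' is 'b'.
_deps = {"d": ["c", "b"], "c": ["a"], "b": [], "a": []}
assert _find_nearest_sketch("d", {"a", "b"}, lambda n: _deps[n]) == "b"
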
@tf_export("gradients")
def gradients(ys,
xs,
(op.name, op.type))
        if loop_state:
          loop_state.EnterGradWhileContext(op, before=False)
+
+        # NOTE(skyewm): We don't support computing gradients wrt a loop
+        # variable unless it's within the context of a single iteration (i.e.
+        # the gradient is wrt the loop parameter in the body function, not
+        # wrt or through the initial value). This means if we're in a while
+        # loop context, we should never see a switch node from this context.
+        # pylint: disable=protected-access
+        if (control_flow_util.IsSwitch(op) and
+            op._control_flow_context is not None and
+            op._control_flow_context.IsWhileContext() and
+            op._control_flow_context ==
+            ops.get_default_graph()._get_control_flow_context()):
+          _RaiseNoGradWrtInitialLoopValError(op, from_ops)
+        # pylint: enable=protected-access
+
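        # By contrast, the common pattern of differentiating a while_loop
        # result from *outside* the loop never reaches this branch, since
        # backprop then runs in the outer context rather than inside the while
        # context itself. A minimal sketch of that supported case (assumes the
        # TF 1.x graph-mode public API; kept as a comment because it is an
        # illustration, not part of this function):
        #
        #   import tensorflow as tf
        #
        #   x = tf.placeholder(tf.float32, shape=())
        #   _, r = tf.while_loop(lambda i, v: i < 3,
        #                        lambda i, v: (i + 1, v * 2.0),
        #                        [0, x])
        #   dr_dx = tf.gradients(r, x)[0]  # r == 8 * x, so the gradient is 8
        #   with tf.Session() as sess:
        #     print(sess.run(dr_dx, feed_dict={x: 1.0}))  # prints 8.0
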
        if (grad_fn or is_func_call) and has_out_grads:
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],