Raise an error if we try to take the gradient wrt to the initial value of a loop...
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 8 May 2018 00:21:39 +0000 (17:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 01:26:54 +0000 (18:26 -0700)
Fixes #14101

PiperOrigin-RevId: 195748688

tensorflow/python/kernel_tests/control_flow_ops_py_test.py
tensorflow/python/ops/gradients_impl.py

index 77e6f5f..843759f 100644 (file)
@@ -1847,6 +1847,23 @@ class ControlFlowTest(test.TestCase):
       r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
       self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
 
+  def testGradInWhileWrtInitialLoopVal(self):
+    with self.test_session():
+      x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
+      y = x + 1
+
+      def body(i, v):
+        z = v * 2
+        return i + 1, gradients_impl.gradients(z, x)[0]
+
+      with self.assertRaisesRegexp(
+          ValueError,
+          "Cannot compute gradient inside while loop with respect to op 'x'. "
+          "We do not support taking the gradient wrt or through the initial "
+          "value of a loop variable. Gradients can be computed through "
+          "loop invariants or wrt the input parameters to the loop body."):
+        control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+
   def testWhileGradInWhile(self):
     with self.test_session():
       n = ops.convert_to_tensor(1.0, name="n")
index a6b1e6d..069b5a4 100644 (file)
@@ -418,6 +418,30 @@ def _MaybeCompile(scope, op, func, grad_fn):
     return grad_fn()
 
 
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops):
+  """Raises an error if we backprop through a loop var."""
+  # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
+  # message.
+  target_op = None
+  queue = collections.deque([op])
+  visited = set()
+  while queue:
+    curr_op = queue.popleft()
+    if curr_op in visited: continue
+    visited.add(curr_op)
+    if curr_op in from_ops:
+      target_op = curr_op
+      break
+    queue.extend(t.op for t in curr_op.inputs)
+  assert target_op
+  raise ValueError(
+      "Cannot compute gradient inside while loop with respect to op '%s'. "
+      "We do not support taking the gradient wrt or through the initial value "
+      "of a loop variable. Gradients can be computed through loop invariants "
+      "or wrt the input parameters to the loop body."
+      % target_op.name)
+
+
 @tf_export("gradients")
 def gradients(ys,
               xs,
@@ -630,6 +654,21 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
                   (op.name, op.type))
         if loop_state:
           loop_state.EnterGradWhileContext(op, before=False)
+
+        # NOTE(skyewm): We don't support computing gradients wrt a loop variable
+        # unless it's within the context of a single iteration (i.e. the
+        # gradient is wrt to the loop parameter in the body function, not wrt or
+        # through the initial value). This means if we're in a while loop
+        # context, we should never see a switch node from this context.
+        # pylint: disable=protected-access
+        if (control_flow_util.IsSwitch(op) and
+            op._control_flow_context is not None and
+            op._control_flow_context.IsWhileContext() and
+            op._control_flow_context ==
+            ops.get_default_graph()._get_control_flow_context()):
+          _RaiseNoGradWrtInitialLoopValError(op, from_ops)
+        # pylint: enable=protected-access
+
         if (grad_fn or is_func_call) and has_out_grads:
           # NOTE: If _AggregatedGrads didn't compute a value for the i'th
           # output, it means that the cost does not depend on output[i],