Add 'src_graph' argument to gradients_impl._GradientsHelper.
authorSkye Wanderman-Milne <skyewm@google.com>
Sat, 19 May 2018 01:29:54 +0000 (18:29 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 19 May 2018 01:32:35 +0000 (18:32 -0700)
This allows the gradient graph to be built in a _FuncGraph separate
from the forward graph (a _FuncGraph is necessary to capture needed
tensors from the forward graph. It's up to the capturing logic what
how to feed the forward tensors to the gradient graph).

PiperOrigin-RevId: 197230736

tensorflow/python/ops/control_flow_ops.py
tensorflow/python/ops/gradients_impl.py
tensorflow/python/ops/gradients_test.py

index 5ebdb19..ee024ce 100644 (file)
@@ -1192,20 +1192,18 @@ class ControlFlowState(object):
       to backprop.
     """
     loop_exits = []
-    for _, grad_state in self._map.items():
-      # pylint: disable=protected-access
+    for grad_state in self._map.values():
       for y in grad_state.forward_loop_exits:
-        if pending_count[y.op._id] == 0:
+        if pending_count[y.op] == 0:
           grad_state.pending_exits_count -= 1
-          if y.op._id not in to_ops_set:
+          if y.op not in to_ops_set:
             grad_state.unused_exits.append(y)
           if grad_state.pending_exits_count == 0:
             loop_exits.extend(grad_state.unused_exits)
       # Need to include Enters in backprop for higher-order gradients.
       for y in grad_state.forward_context.loop_enters:
-        if pending_count[y.op._id] == 0:
-          pending_count[y.op._id] = 1
-      # pylint: enable=protected-access
+        if pending_count[y.op] == 0:
+          pending_count[y.op] = 1
     return loop_exits
 
   def EnterGradWhileContext(self, op, before):
@@ -1243,8 +1241,8 @@ class ControlFlowState(object):
 
       # We need to include all exits of a loop for backprop.
       for loop_exit in grad_state.forward_loop_exits:
-        if not between_ops[loop_exit.op._id]:
-          between_ops[loop_exit.op._id] = True
+        if loop_exit.op not in between_ops:
+          between_ops.add(loop_exit.op)
           between_op_list.append(loop_exit.op)
 
   def ZerosLikeForExit(self, val):
index 069b5a4..716b54f 100644 (file)
@@ -112,14 +112,14 @@ def _MarkReachedOps(from_ops, reached_ops):
 
   Args:
     from_ops: list of Operations.
-    reached_ops: list of booleans, indexed by operation id.
+    reached_ops: set of Operations.
   """
   queue = collections.deque()
   queue.extend(from_ops)
   while queue:
     op = queue.popleft()
-    if not reached_ops[op._id]:
-      reached_ops[op._id] = True
+    if op not in reached_ops:
+      reached_ops.add(op)
       for output in op.outputs:
         if _IsBackpropagatable(output):
           queue.extend(output.consumers())
@@ -130,7 +130,7 @@ def _GatherInputs(to_ops, reached_ops):
 
   Args:
     to_ops: list of Operations.
-    reached_ops: list of booleans, indexed by operation id.
+    reached_ops: set of Operations.
 
   Returns:
     The list of all inputs of to_ops that are in reached_ops.
@@ -142,58 +142,57 @@ def _GatherInputs(to_ops, reached_ops):
   while queue:
     op = queue.popleft()
     # We are interested in this op.
-    if reached_ops[op._id]:
+    if op in reached_ops:
       inputs.append(op)
       # Clear the boolean so we won't add the inputs again.
-      reached_ops[op._id] = False
+      reached_ops.remove(op)
       for inp in op.inputs:
         queue.append(inp.op)
   return inputs
 
 
-def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
   """Initialize the pending count for ops between two lists of Operations.
 
-  'pending_count[op._id]' indicates the number of backprop inputs
+  'pending_count[op]' indicates the number of backprop inputs
   to this operation.
 
   Args:
-    graph: a Graph.
     to_ops: list of Operations.
     from_ops: list of Operations.
     colocate_gradients_with_ops: Python bool.  See docstring of gradients().
 
   Returns:
-    A tuple containing: (1) the subset of to_ops ids reachable from from_ops
-    by a path of zero or more backpropagatable tensors, (2) a list of integers
-    indexed by operation id, indicating the number of backprop inputs to this
-    operation, and (3) a ControlFlowState object which is not None if the ops
-    between from_ops and to_ops contain control flow loops.
+    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
+    path of zero or more backpropagatable tensors, (2) a mapping from operation
+    to the number of backprop inputs to that op, and (3) a ControlFlowState
+    object which is not None if the ops between from_ops and to_ops contain
+    control flow loops.
   """
   # Mark reachable ops from from_ops.
-  reached_ops = [False] * (graph._last_id + 1)
+  reached_ops = set()
   _MarkReachedOps(from_ops, reached_ops)
-  # reached_ops[X] iff X is reachable from from_ops by a path of zero or more
+  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
   # backpropagatable tensors.
 
-  reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id])  # pylint: disable=protected-access
+  reachable_to_ops = set(op for op in to_ops if op in reached_ops)
 
   # Mark between ops.
-  between_ops = [False] * (graph._last_id + 1)
+  between_ops = set()
   between_op_list = []
   queue = collections.deque()
   queue.extend(to_ops)
   while queue:
     op = queue.popleft()
     # We are interested in this op.
-    if reached_ops[op._id]:
-      between_ops[op._id] = True
+    if op in reached_ops:
+      between_ops.add(op)
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
-      reached_ops[op._id] = False
+      reached_ops.remove(op)
       for inp in op.inputs:
         queue.append(inp.op)
-  # between_ops[X] iff X is on a path of zero or more backpropagatable tensors
+  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
 
   # 'loop_state' is None if there are no while loops.
@@ -201,11 +200,11 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
       between_op_list, between_ops, colocate_gradients_with_ops)
 
   # Initialize pending count for between ops.
-  pending_count = [0] * (graph._last_id + 1)
+  pending_count = collections.defaultdict(int)
   for op in between_op_list:
     for x in op.inputs:
-      if between_ops[x.op._id]:
-        pending_count[x.op._id] += 1
+      if x.op in between_ops:
+        pending_count[x.op] += 1
 
   return reachable_to_ops, pending_count, loop_state
 
@@ -331,15 +330,15 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
   should stop. Operations in the returned set will not be differentiated.
   This set is defined as the subset of `from_ops` containing ops that have
   no predecessor in `from_ops`. `pending_count` is the result of
-  `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
-  iff pending_count[op._id] > 0.
+  `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
+  iff pending_count[op] > 0.
 
   In addition, none of `stop_gradient_ops` will be differentiated.
 
   Args:
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
-    pending_count: List of integers, indexed by operation id.
+    pending_count: mapping from operation to number of backprop inputs.
 
   Returns:
     The set of operations.
@@ -348,12 +347,12 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
   for op in from_ops:
     is_stop_op = True
     for inp in op.inputs:
-      if pending_count[inp.op._id] > 0:
+      if pending_count[inp.op] > 0:
         is_stop_op = False
         break
     if is_stop_op:
-      stop_ops.add(op._id)
-  stop_ops.update(op._id for op in stop_gradient_ops)  # pylint: disable=protected-access
+      stop_ops.add(op)
+  stop_ops.update(op for op in stop_gradient_ops)
   return stop_ops
 
 
@@ -375,9 +374,7 @@ def _SymGrad(op, out_grads):
   f.name = op.type
   for k in op.node_def.attr:
     f.attr[k].CopyFrom(op.node_def.attr[k])
-  # pylint: disable=protected-access
   in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
-  # pylint: enable=protected-access
   return in_grads
 
 
@@ -535,13 +532,23 @@ def gradients(ys,
                             gate_gradients, aggregation_method, stop_gradients)
 
 
-def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
-                     gate_gradients, aggregation_method, stop_gradients):
+def _GradientsHelper(ys,
+                     xs,
+                     grad_ys=None,
+                     name="gradients",
+                     colocate_gradients_with_ops=False,
+                     gate_gradients=False,
+                     aggregation_method=None,
+                     stop_gradients=None,
+                     src_graph=None):
   """Implementation of gradients()."""
   if context.executing_eagerly():
     raise RuntimeError("tf.gradients not supported when eager execution "
                        "is enabled. Use tf.contrib.eager.GradientTape "
                        "instead.")
+  if src_graph is None:
+    src_graph = ops.get_default_graph()
+
   ys = _AsList(ys)
   xs = _AsList(xs)
   stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
@@ -581,7 +588,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
+        to_ops, from_ops, colocate_gradients_with_ops)
 
     # Iterate over the collected ops.
     #
@@ -603,12 +610,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     for op in to_ops:
       # 'ready' handles the case where one output gradient relies on
       # another output's gradient.
-      # pylint: disable=protected-access
-      ready = (pending_count[op._id] == 0)
-      if ready and op._id not in to_ops_set and op._id in reachable_to_ops:
-        to_ops_set.add(op._id)
+      ready = (pending_count[op] == 0)
+      if ready and op not in to_ops_set and op in reachable_to_ops:
+        to_ops_set.add(op)
         queue.append(op)
-      # pylint: enable=protected-access
 
     if loop_state:
       loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
@@ -632,12 +637,12 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
         grad_fn = None
         func_call = None
         # pylint: disable=protected-access
-        is_func_call = ops.get_default_graph()._is_function(op.type)
+        is_func_call = src_graph._is_function(op.type)
         # pylint: enable=protected-access
         has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
-        if has_out_grads and (op._id not in stop_ops):
+        if has_out_grads and (op not in stop_ops):
           if is_func_call:
-            func_call = ops.get_default_graph()._get_function(op.type)
+            func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
             # Note that __defun is not set if the graph is
             # imported. If it's set, we prefer to access the original
             # defun.
@@ -687,7 +692,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
                 out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
           with ops.name_scope(op.name + "_grad"):
             # pylint: disable=protected-access
-            with ops.get_default_graph()._original_op(op):
+            with src_graph._original_op(op):
               # pylint: enable=protected-access
               if grad_fn:
                 # If grad_fn was found, do not use SymbolicGradient even for
@@ -754,13 +759,10 @@ def _HasAnyNotNoneGrads(grads, op):
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
   """Update pending count for the inputs of op and enqueue ready ops."""
   for x in op.inputs:
-    # pylint: disable=protected-access
-    pending_count[x.op._id] -= 1
-    ready = (pending_count[x.op._id] == 0)
+    pending_count[x.op] -= 1
+    ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
-      ready = (
-          pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
-    # pylint: enable=protected-access
+      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
     if ready:
       if control_flow_util.IsLoopExit(x.op):
         # if x is an exit without real gradient, defer processing them.
index 096d0ce..70d500a 100644 (file)
@@ -57,11 +57,10 @@ from tensorflow.python.ops.nn_ops import bias_add
 from tensorflow.python.platform import googletest
 
 
-def _OpsBetween(graph, to_ops, from_ops):
+def _OpsBetween(to_ops, from_ops):
   """Build the list of operations between two lists of Operations.
 
   Args:
-    graph: a Graph.
     to_ops: list of Operations.
     from_ops: list of Operations.
 
@@ -72,13 +71,12 @@ def _OpsBetween(graph, to_ops, from_ops):
     TODO(touts): Think about returning an empty list if from_ops are not
     reachable from to_ops.  Presently it returns to_ops in that case.
   """
-  # List of booleans, indexed by operation id, indicating if
-  # an op is reached from the output of "input_ops".
-  reached_ops = [False] * (graph._last_id + 1)
+  # Ops that are reachable from the output of "input_ops".
+  reached_ops = set()
   # We only care to reach up to "output_ops" so we mark the
   # output ops as reached to avoid recursing past them.
   for op in to_ops:
-    reached_ops[op._id] = True
+    reached_ops.add(op)
   gradients_impl._MarkReachedOps(from_ops, reached_ops)
   between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
   between_ops.sort(key=lambda x: -x._id)
@@ -95,18 +93,18 @@ class GradientsTest(test_util.TensorFlowTestCase):
     self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))
 
   def testOpsBetweenSimple(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
     # Full graph
     self._assertOpListEqual([t3.op, t2.op, t1.op],
-                            _OpsBetween(g, [t3.op], [t1.op, t2.op]))
+                            _OpsBetween([t3.op], [t1.op, t2.op]))
     # Only t1, t3.
-    self._assertOpListEqual([t3.op, t1.op], _OpsBetween(g, [t3.op], [t1.op]))
+    self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op]))
 
   def testOpsBetweenUnreachable(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       _ = array_ops.stack([t1, t2])
@@ -114,10 +112,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t5 = constant(2.0)
       t6 = array_ops.stack([t4, t5])
     # Elements of to_ops are always listed.
-    self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
+    self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op]))
 
   def testOpsBetweenCut(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
@@ -126,10 +124,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t6 = constant([2.0])
       t7 = array_ops.concat([t5, t6], 0)
     self._assertOpListEqual([t7.op, t5.op, t4.op],
-                            _OpsBetween(g, [t7.op], [t4.op]))
+                            _OpsBetween([t7.op], [t4.op]))
 
   def testOpsBetweenCycle(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
@@ -138,11 +136,11 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t6 = array_ops.concat([t4, t5], 0)
       t7 = array_ops.concat([t6, t3], 0)
     self._assertOpListEqual([t6.op, t4.op, t3.op],
-                            _OpsBetween(g, [t6.op], [t3.op]))
+                            _OpsBetween([t6.op], [t3.op]))
     self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
-                            _OpsBetween(g, [t7.op], [t1.op, t5.op]))
+                            _OpsBetween([t7.op], [t1.op, t5.op]))
     self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
-                            _OpsBetween(g, [t6.op], [t2.op, t5.op]))
+                            _OpsBetween([t6.op], [t2.op, t5.op]))
 
   def testGradients(self):
     with ops.Graph().as_default():