Args:
from_ops: list of Operations.
- reached_ops: list of booleans, indexed by operation id.
+ reached_ops: set of Operations.
"""
queue = collections.deque()
queue.extend(from_ops)
while queue:
op = queue.popleft()
- if not reached_ops[op._id]:
- reached_ops[op._id] = True
+ if op not in reached_ops:
+ reached_ops.add(op)
for output in op.outputs:
if _IsBackpropagatable(output):
queue.extend(output.consumers())
Args:
to_ops: list of Operations.
- reached_ops: list of booleans, indexed by operation id.
+ reached_ops: set of Operations.
Returns:
The list of all inputs of to_ops that are in reached_ops.
while queue:
op = queue.popleft()
# We are interested in this op.
- if reached_ops[op._id]:
+ if op in reached_ops:
inputs.append(op)
# Unmark the op so we won't add its inputs again.
- reached_ops[op._id] = False
+ reached_ops.remove(op)
for inp in op.inputs:
queue.append(inp.op)
return inputs
-def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
"""Initialize the pending count for ops between two lists of Operations.
- 'pending_count[op._id]' indicates the number of backprop inputs
+ 'pending_count[op]' indicates the number of backprop inputs
to this operation.
Args:
- graph: a Graph.
to_ops: list of Operations.
from_ops: list of Operations.
colocate_gradients_with_ops: Python bool. See docstring of gradients().
Returns:
- A tuple containing: (1) the subset of to_ops ids reachable from from_ops
- by a path of zero or more backpropagatable tensors, (2) a list of integers
- indexed by operation id, indicating the number of backprop inputs to this
- operation, and (3) a ControlFlowState object which is not None if the ops
- between from_ops and to_ops contain control flow loops.
+ A tuple containing: (1) the subset of to_ops reachable from from_ops by a
+ path of zero or more backpropagatable tensors, (2) a mapping from operation
+ to the number of backprop inputs to that op, and (3) a ControlFlowState
+ object which is not None if the ops between from_ops and to_ops contain
+ control flow loops.
"""
# Mark reachable ops from from_ops.
- reached_ops = [False] * (graph._last_id + 1)
+ reached_ops = set()
_MarkReachedOps(from_ops, reached_ops)
- # reached_ops[X] iff X is reachable from from_ops by a path of zero or more
+ # X in reached_ops iff X is reachable from from_ops by a path of zero or more
# backpropagatable tensors.
- reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id]) # pylint: disable=protected-access
+ reachable_to_ops = set(op for op in to_ops if op in reached_ops)
# Mark between ops.
- between_ops = [False] * (graph._last_id + 1)
+ between_ops = set()
between_op_list = []
queue = collections.deque()
queue.extend(to_ops)
while queue:
op = queue.popleft()
# We are interested in this op.
- if reached_ops[op._id]:
- between_ops[op._id] = True
+ if op in reached_ops:
+ between_ops.add(op)
between_op_list.append(op)
# Unmark the op so we won't add its inputs again.
- reached_ops[op._id] = False
+ reached_ops.remove(op)
for inp in op.inputs:
queue.append(inp.op)
- # between_ops[X] iff X is on a path of zero or more backpropagatable tensors
+ # X in between_ops iff X is on a path of zero or more backpropagatable tensors
# between from_ops and to_ops
# 'loop_state' is None if there are no while loops.
between_op_list, between_ops, colocate_gradients_with_ops)
# Initialize pending count for between ops.
- pending_count = [0] * (graph._last_id + 1)
+ pending_count = collections.defaultdict(int)
for op in between_op_list:
for x in op.inputs:
- if between_ops[x.op._id]:
- pending_count[x.op._id] += 1
+ if x.op in between_ops:
+ pending_count[x.op] += 1
return reachable_to_ops, pending_count, loop_state
should stop. Operations in the returned set will not be differentiated.
This set is defined as the subset of `from_ops` containing ops that have
no predecessor in `from_ops`. `pending_count` is the result of
- `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
- iff pending_count[op._id] > 0.
+ `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
+ iff pending_count[op] > 0.
In addition, none of `stop_gradient_ops` will be differentiated.
Args:
from_ops: list of Operations.
stop_gradient_ops: list of Operations never to backprop through.
- pending_count: List of integers, indexed by operation id.
+ pending_count: mapping from operation to number of backprop inputs.
Returns:
The set of operations.
for op in from_ops:
is_stop_op = True
for inp in op.inputs:
- if pending_count[inp.op._id] > 0:
+ if pending_count[inp.op] > 0:
is_stop_op = False
break
if is_stop_op:
- stop_ops.add(op._id)
- stop_ops.update(op._id for op in stop_gradient_ops) # pylint: disable=protected-access
+ stop_ops.add(op)
+ stop_ops.update(op for op in stop_gradient_ops)
return stop_ops
f.name = op.type
for k in op.node_def.attr:
f.attr[k].CopyFrom(op.node_def.attr[k])
- # pylint: disable=protected-access
in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
- # pylint: enable=protected-access
return in_grads
gate_gradients, aggregation_method, stop_gradients)
-def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
- gate_gradients, aggregation_method, stop_gradients):
+def _GradientsHelper(ys,
+ xs,
+ grad_ys=None,
+ name="gradients",
+ colocate_gradients_with_ops=False,
+ gate_gradients=False,
+ aggregation_method=None,
+ stop_gradients=None,
+ src_graph=None):
"""Implementation of gradients()."""
if context.executing_eagerly():
raise RuntimeError("tf.gradients not supported when eager execution "
"is enabled. Use tf.contrib.eager.GradientTape "
"instead.")
+ if src_graph is None:
+ src_graph = ops.get_default_graph()
+
ys = _AsList(ys)
xs = _AsList(xs)
stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
from_ops = [t.op for t in xs]
stop_gradient_ops = [t.op for t in stop_gradients]
reachable_to_ops, pending_count, loop_state = _PendingCount(
- ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
+ to_ops, from_ops, colocate_gradients_with_ops)
# Iterate over the collected ops.
#
for op in to_ops:
# 'ready' handles the case where one output gradient relies on
# another output's gradient.
- # pylint: disable=protected-access
- ready = (pending_count[op._id] == 0)
- if ready and op._id not in to_ops_set and op._id in reachable_to_ops:
- to_ops_set.add(op._id)
+ ready = (pending_count[op] == 0)
+ if ready and op not in to_ops_set and op in reachable_to_ops:
+ to_ops_set.add(op)
queue.append(op)
- # pylint: enable=protected-access
if loop_state:
loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
grad_fn = None
func_call = None
# pylint: disable=protected-access
- is_func_call = ops.get_default_graph()._is_function(op.type)
+ is_func_call = src_graph._is_function(op.type)
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
- if has_out_grads and (op._id not in stop_ops):
+ if has_out_grads and (op not in stop_ops):
if is_func_call:
- func_call = ops.get_default_graph()._get_function(op.type)
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
# Note that __defun is not set if the graph is
# imported. If it's set, we prefer to access the original
# defun.
out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
with ops.name_scope(op.name + "_grad"):
# pylint: disable=protected-access
- with ops.get_default_graph()._original_op(op):
+ with src_graph._original_op(op):
# pylint: enable=protected-access
if grad_fn:
# If grad_fn was found, do not use SymbolicGradient even for
def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
"""Update pending count for the inputs of op and enqueue ready ops."""
for x in op.inputs:
- # pylint: disable=protected-access
- pending_count[x.op._id] -= 1
- ready = (pending_count[x.op._id] == 0)
+ pending_count[x.op] -= 1
+ ready = (pending_count[x.op] == 0)
if loop_state and not ready:
- ready = (
- pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
- # pylint: enable=protected-access
+ ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
if ready:
if control_flow_util.IsLoopExit(x.op):
# if x is an exit without real gradient, defer processing them.
from tensorflow.python.platform import googletest
-def _OpsBetween(graph, to_ops, from_ops):
+def _OpsBetween(to_ops, from_ops):
"""Build the list of operations between two lists of Operations.
Args:
- graph: a Graph.
to_ops: list of Operations.
from_ops: list of Operations.
TODO(touts): Think about returning an empty list if from_ops are not
reachable from to_ops. Presently it returns to_ops in that case.
"""
- # List of booleans, indexed by operation id, indicating if
- # an op is reached from the output of "input_ops".
- reached_ops = [False] * (graph._last_id + 1)
+ # Ops that are reachable from the output of "from_ops".
+ reached_ops = set()
# We only care to reach up to "to_ops" so we mark the
# to_ops as reached to avoid recursing past them.
for op in to_ops:
- reached_ops[op._id] = True
+ reached_ops.add(op)
gradients_impl._MarkReachedOps(from_ops, reached_ops)
between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
between_ops.sort(key=lambda x: -x._id)
self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))
def testOpsBetweenSimple(self):
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
t1 = constant(1.0)
t2 = constant(2.0)
t3 = array_ops.stack([t1, t2])
# Full graph
self._assertOpListEqual([t3.op, t2.op, t1.op],
- _OpsBetween(g, [t3.op], [t1.op, t2.op]))
+ _OpsBetween([t3.op], [t1.op, t2.op]))
# Only t1, t3.
- self._assertOpListEqual([t3.op, t1.op], _OpsBetween(g, [t3.op], [t1.op]))
+ self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op]))
def testOpsBetweenUnreachable(self):
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
t1 = constant(1.0)
t2 = constant(2.0)
_ = array_ops.stack([t1, t2])
t5 = constant(2.0)
t6 = array_ops.stack([t4, t5])
# Elements of to_ops are always listed.
- self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
+ self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op]))
def testOpsBetweenCut(self):
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
t1 = constant(1.0)
t2 = constant(2.0)
t3 = array_ops.stack([t1, t2])
t6 = constant([2.0])
t7 = array_ops.concat([t5, t6], 0)
self._assertOpListEqual([t7.op, t5.op, t4.op],
- _OpsBetween(g, [t7.op], [t4.op]))
+ _OpsBetween([t7.op], [t4.op]))
def testOpsBetweenCycle(self):
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
t1 = constant(1.0)
t2 = constant(2.0)
t3 = array_ops.stack([t1, t2])
t6 = array_ops.concat([t4, t5], 0)
t7 = array_ops.concat([t6, t3], 0)
self._assertOpListEqual([t6.op, t4.op, t3.op],
- _OpsBetween(g, [t6.op], [t3.op]))
+ _OpsBetween([t6.op], [t3.op]))
self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
- _OpsBetween(g, [t7.op], [t1.op, t5.op]))
+ _OpsBetween([t7.op], [t1.op, t5.op]))
self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
- _OpsBetween(g, [t6.op], [t2.op, t5.op]))
+ _OpsBetween([t6.op], [t2.op, t5.op]))
def testGradients(self):
with ops.Graph().as_default():