From: Skye Wanderman-Milne
Date: Sat, 19 May 2018 01:29:54 +0000 (-0700)
Subject: Add 'src_graph' argument to gradients_impl._GradientsHelper.
X-Git-Tag: upstream/v1.9.0_rc1~56^2~2^2~54
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=92cdc99c98da1763fcba65477303b46ecee628a6;p=platform%2Fupstream%2Ftensorflow.git

Add 'src_graph' argument to gradients_impl._GradientsHelper.

This allows the gradient graph to be built in a _FuncGraph separate from the
forward graph (a _FuncGraph is necessary to capture needed tensors from the
forward graph; it is up to the capturing logic how to feed the forward
tensors to the gradient graph).

PiperOrigin-RevId: 197230736
---

diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 5ebdb19..ee024ce 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1192,20 +1192,18 @@ class ControlFlowState(object):
       to backprop.
     """
     loop_exits = []
-    for _, grad_state in self._map.items():
-      # pylint: disable=protected-access
+    for grad_state in self._map.values():
       for y in grad_state.forward_loop_exits:
-        if pending_count[y.op._id] == 0:
+        if pending_count[y.op] == 0:
           grad_state.pending_exits_count -= 1
-          if y.op._id not in to_ops_set:
+          if y.op not in to_ops_set:
             grad_state.unused_exits.append(y)
           if grad_state.pending_exits_count == 0:
             loop_exits.extend(grad_state.unused_exits)
       # Need to include Enters in backprop for higher-order gradients.
       for y in grad_state.forward_context.loop_enters:
-        if pending_count[y.op._id] == 0:
-          pending_count[y.op._id] = 1
-      # pylint: enable=protected-access
+        if pending_count[y.op] == 0:
+          pending_count[y.op] = 1
     return loop_exits
 
   def EnterGradWhileContext(self, op, before):
@@ -1243,8 +1241,8 @@ class ControlFlowState(object):
 
       # We need to include all exits of a loop for backprop.
       for loop_exit in grad_state.forward_loop_exits:
-        if not between_ops[loop_exit.op._id]:
-          between_ops[loop_exit.op._id] = True
+        if loop_exit.op not in between_ops:
+          between_ops.add(loop_exit.op)
           between_op_list.append(loop_exit.op)
 
   def ZerosLikeForExit(self, val):
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 069b5a4..716b54f 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -112,14 +112,14 @@ def _MarkReachedOps(from_ops, reached_ops):
 
   Args:
     from_ops: list of Operations.
-    reached_ops: list of booleans, indexed by operation id.
+    reached_ops: set of Operations.
   """
   queue = collections.deque()
   queue.extend(from_ops)
   while queue:
     op = queue.popleft()
-    if not reached_ops[op._id]:
-      reached_ops[op._id] = True
+    if op not in reached_ops:
+      reached_ops.add(op)
       for output in op.outputs:
         if _IsBackpropagatable(output):
           queue.extend(output.consumers())
@@ -130,7 +130,7 @@ def _GatherInputs(to_ops, reached_ops):
 
   Args:
     to_ops: list of Operations.
-    reached_ops: list of booleans, indexed by operation id.
+    reached_ops: set of Operations.
 
   Returns:
     The list of all inputs of to_ops that are in reached_ops.
@@ -142,58 +142,57 @@ def _GatherInputs(to_ops, reached_ops):
   while queue:
     op = queue.popleft()
     # We are interested in this op.
-    if reached_ops[op._id]:
+    if op in reached_ops:
       inputs.append(op)
       # Clear the boolean so we won't add the inputs again.
-      reached_ops[op._id] = False
+      reached_ops.remove(op)
       for inp in op.inputs:
         queue.append(inp.op)
   return inputs
 
 
-def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops):
   """Initialize the pending count for ops between two lists of Operations.
 
-  'pending_count[op._id]' indicates the number of backprop inputs
+  'pending_count[op]' indicates the number of backprop inputs
   to this operation.
 
   Args:
-    graph: a Graph.
     to_ops: list of Operations.
     from_ops: list of Operations.
     colocate_gradients_with_ops: Python bool. See docstring of gradients().
 
   Returns:
-    A tuple containing: (1) the subset of to_ops ids reachable from from_ops
-    by a path of zero or more backpropagatable tensors, (2) a list of integers
-    indexed by operation id, indicating the number of backprop inputs to this
-    operation, and (3) a ControlFlowState object which is not None if the ops
-    between from_ops and to_ops contain control flow loops.
+    A tuple containing: (1) the subset of to_ops reachable from from_ops by a
+    path of zero or more backpropagatable tensors, (2) a mapping from operation
+    to the number of backprop inputs to that op, and (3) a ControlFlowState
+    object which is not None if the ops between from_ops and to_ops contain
+    control flow loops.
   """
   # Mark reachable ops from from_ops.
-  reached_ops = [False] * (graph._last_id + 1)
+  reached_ops = set()
   _MarkReachedOps(from_ops, reached_ops)
-  # reached_ops[X] iff X is reachable from from_ops by a path of zero or more
+  # X in reached_ops iff X is reachable from from_ops by a path of zero or more
  # backpropagatable tensors.
 
-  reachable_to_ops = set(op._id for op in to_ops if reached_ops[op._id])  # pylint: disable=protected-access
+  reachable_to_ops = set(op for op in to_ops if op in reached_ops)
 
   # Mark between ops.
-  between_ops = [False] * (graph._last_id + 1)
+  between_ops = set()
   between_op_list = []
   queue = collections.deque()
   queue.extend(to_ops)
   while queue:
     op = queue.popleft()
     # We are interested in this op.
-    if reached_ops[op._id]:
-      between_ops[op._id] = True
+    if op in reached_ops:
+      between_ops.add(op)
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
-      reached_ops[op._id] = False
+      reached_ops.remove(op)
       for inp in op.inputs:
         queue.append(inp.op)
-  # between_ops[X] iff X is on a path of zero or more backpropagatable tensors
+  # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
 
   # 'loop_state' is None if there are no while loops.
@@ -201,11 +200,11 @@ def _PendingCount(graph, to_ops, from_ops, colocate_gradients_with_ops):
       between_op_list, between_ops, colocate_gradients_with_ops)
 
   # Initialize pending count for between ops.
-  pending_count = [0] * (graph._last_id + 1)
+  pending_count = collections.defaultdict(int)
   for op in between_op_list:
     for x in op.inputs:
-      if between_ops[x.op._id]:
-        pending_count[x.op._id] += 1
+      if x.op in between_ops:
+        pending_count[x.op] += 1
 
   return reachable_to_ops, pending_count, loop_state
 
@@ -331,15 +330,15 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
   should stop. Operations in the returned set will not be differentiated.
   This set is defined as the subset of `from_ops` containing ops that have
   no predecessor in `from_ops`. `pending_count` is the result of
-  `_PendingCount(g, xs, from_ops)`. An 'op' has predecessors in `from_ops`
-  iff pending_count[op._id] > 0.
+  `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
+  iff pending_count[op] > 0.
 
   In addition, none of `stop_gradient_ops` will be differentiated.
 
   Args:
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
-    pending_count: List of integers, indexed by operation id.
+    pending_count: mapping from operation to number of backprop inputs.
 
   Returns:
     The set of operations.
@@ -348,12 +347,12 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
   for op in from_ops:
     is_stop_op = True
     for inp in op.inputs:
-      if pending_count[inp.op._id] > 0:
+      if pending_count[inp.op] > 0:
         is_stop_op = False
         break
     if is_stop_op:
-      stop_ops.add(op._id)
-  stop_ops.update(op._id for op in stop_gradient_ops)  # pylint: disable=protected-access
+      stop_ops.add(op)
+  stop_ops.update(op for op in stop_gradient_ops)
   return stop_ops
 
 
@@ -375,9 +374,7 @@ def _SymGrad(op, out_grads):
   f.name = op.type
   for k in op.node_def.attr:
     f.attr[k].CopyFrom(op.node_def.attr[k])
-  # pylint: disable=protected-access
   in_grads = functional_ops.symbolic_gradient(input=f_in, Tout=f_types, f=f)
-  # pylint: enable=protected-access
   return in_grads
 
 
@@ -535,13 +532,23 @@ def gradients(ys,
                           gate_gradients, aggregation_method, stop_gradients)
 
 
-def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
-                     gate_gradients, aggregation_method, stop_gradients):
+def _GradientsHelper(ys,
+                     xs,
+                     grad_ys=None,
+                     name="gradients",
+                     colocate_gradients_with_ops=False,
+                     gate_gradients=False,
+                     aggregation_method=None,
+                     stop_gradients=None,
+                     src_graph=None):
   """Implementation of gradients()."""
   if context.executing_eagerly():
     raise RuntimeError("tf.gradients not supported when eager execution "
                        "is enabled. Use tf.contrib.eager.GradientTape "
                        "instead.")
+  if src_graph is None:
+    src_graph = ops.get_default_graph()
+
   ys = _AsList(ys)
   xs = _AsList(xs)
   stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
@@ -581,7 +588,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        ops.get_default_graph(), to_ops, from_ops, colocate_gradients_with_ops)
+        to_ops, from_ops, colocate_gradients_with_ops)
 
     # Iterate over the collected ops.
     #
@@ -603,12 +610,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     for op in to_ops:
       # 'ready' handles the case where one output gradient relies on
       # another output's gradient.
-      # pylint: disable=protected-access
-      ready = (pending_count[op._id] == 0)
-      if ready and op._id not in to_ops_set and op._id in reachable_to_ops:
-        to_ops_set.add(op._id)
+      ready = (pending_count[op] == 0)
+      if ready and op not in to_ops_set and op in reachable_to_ops:
+        to_ops_set.add(op)
         queue.append(op)
-      # pylint: enable=protected-access
 
     if loop_state:
       loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
@@ -632,12 +637,12 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
         grad_fn = None
         func_call = None
         # pylint: disable=protected-access
-        is_func_call = ops.get_default_graph()._is_function(op.type)
+        is_func_call = src_graph._is_function(op.type)
         # pylint: enable=protected-access
         has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
-        if has_out_grads and (op._id not in stop_ops):
+        if has_out_grads and (op not in stop_ops):
           if is_func_call:
-            func_call = ops.get_default_graph()._get_function(op.type)
+            func_call = src_graph._get_function(op.type)  # pylint: disable=protected-access
             # Note that __defun is not set if the graph is
             # imported. If it's set, we prefer to access the original
             # defun.
@@ -687,7 +692,7 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
                 out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
           with ops.name_scope(op.name + "_grad"):
             # pylint: disable=protected-access
-            with ops.get_default_graph()._original_op(op):
+            with src_graph._original_op(op):
               # pylint: enable=protected-access
               if grad_fn:
                 # If grad_fn was found, do not use SymbolicGradient even for
@@ -754,13 +759,10 @@ def _HasAnyNotNoneGrads(grads, op):
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state):
   """Update pending count for the inputs of op and enqueue ready ops."""
   for x in op.inputs:
-    # pylint: disable=protected-access
-    pending_count[x.op._id] -= 1
-    ready = (pending_count[x.op._id] == 0)
+    pending_count[x.op] -= 1
+    ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
-      ready = (
-          pending_count[x.op._id] > 0 and control_flow_util.IsLoopSwitch(x.op))
-    # pylint: enable=protected-access
+      ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
     if ready:
       if control_flow_util.IsLoopExit(x.op):
         # if x is an exit without real gradient, defer processing them.
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 096d0ce..70d500a 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -57,11 +57,10 @@ from tensorflow.python.ops.nn_ops import bias_add
 from tensorflow.python.platform import googletest
 
 
-def _OpsBetween(graph, to_ops, from_ops):
+def _OpsBetween(to_ops, from_ops):
   """Build the list of operations between two lists of Operations.
 
   Args:
-    graph: a Graph.
     to_ops: list of Operations.
     from_ops: list of Operations.
 
@@ -72,13 +71,12 @@ def _OpsBetween(graph, to_ops, from_ops):
   TODO(touts): Think about returning an empty list if from_ops are not
   reachable from to_ops. Presently it returns to_ops in that case.
   """
-  # List of booleans, indexed by operation id, indicating if
-  # an op is reached from the output of "input_ops".
-  reached_ops = [False] * (graph._last_id + 1)
+  # Ops that are reachable from the output of "input_ops".
+  reached_ops = set()
   # We only care to reach up to "output_ops" so we mark the
   # output ops as reached to avoid recursing past them.
   for op in to_ops:
-    reached_ops[op._id] = True
+    reached_ops.add(op)
   gradients_impl._MarkReachedOps(from_ops, reached_ops)
   between_ops = gradients_impl._GatherInputs(to_ops, reached_ops)
   between_ops.sort(key=lambda x: -x._id)
@@ -95,18 +93,18 @@ class GradientsTest(test_util.TensorFlowTestCase):
     self.assertEquals(self._OpNames(ops1), self._OpNames(ops2))
 
   def testOpsBetweenSimple(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
       # Full graph
       self._assertOpListEqual([t3.op, t2.op, t1.op],
-                              _OpsBetween(g, [t3.op], [t1.op, t2.op]))
+                              _OpsBetween([t3.op], [t1.op, t2.op]))
       # Only t1, t3.
-      self._assertOpListEqual([t3.op, t1.op], _OpsBetween(g, [t3.op], [t1.op]))
+      self._assertOpListEqual([t3.op, t1.op], _OpsBetween([t3.op], [t1.op]))
 
   def testOpsBetweenUnreachable(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       _ = array_ops.stack([t1, t2])
@@ -114,10 +112,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t5 = constant(2.0)
       t6 = array_ops.stack([t4, t5])
       # Elements of to_ops are always listed.
-      self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
+      self._assertOpListEqual([t6.op], _OpsBetween([t6.op], [t1.op]))
 
   def testOpsBetweenCut(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
@@ -126,10 +124,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t6 = constant([2.0])
       t7 = array_ops.concat([t5, t6], 0)
       self._assertOpListEqual([t7.op, t5.op, t4.op],
-                              _OpsBetween(g, [t7.op], [t4.op]))
+                              _OpsBetween([t7.op], [t4.op]))
 
   def testOpsBetweenCycle(self):
-    with ops.Graph().as_default() as g:
+    with ops.Graph().as_default():
       t1 = constant(1.0)
       t2 = constant(2.0)
       t3 = array_ops.stack([t1, t2])
@@ -138,11 +136,11 @@ class GradientsTest(test_util.TensorFlowTestCase):
       t6 = array_ops.concat([t4, t5], 0)
       t7 = array_ops.concat([t6, t3], 0)
       self._assertOpListEqual([t6.op, t4.op, t3.op],
-                              _OpsBetween(g, [t6.op], [t3.op]))
+                              _OpsBetween([t6.op], [t3.op]))
       self._assertOpListEqual([t7.op, t6.op, t5.op, t4.op, t3.op, t1.op],
-                              _OpsBetween(g, [t7.op], [t1.op, t5.op]))
+                              _OpsBetween([t7.op], [t1.op, t5.op]))
       self._assertOpListEqual([t6.op, t5.op, t4.op, t3.op, t2.op],
-                              _OpsBetween(g, [t6.op], [t2.op, t5.op]))
+                              _OpsBetween([t6.op], [t2.op, t5.op]))
 
   def testGradients(self):
     with ops.Graph().as_default():
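
A minimal usage sketch of the new signature (not part of the commit; it assumes
the TF 1.9-era module layout shown in the diff). In the ordinary single-graph
case the new keyword defaults let _GradientsHelper be called like tf.gradients,
with src_graph=None falling back to the default graph; the separate-_FuncGraph
case described in the commit message additionally needs capturing logic and is
not shown here.

    # Hedged sketch against the signature added above; src_graph defaults to
    # ops.get_default_graph(), so this call is equivalent to tf.gradients([y], [x]).
    import tensorflow as tf
    from tensorflow.python.ops import gradients_impl

    with tf.Graph().as_default():
      x = tf.constant(3.0)
      y = x * x
      # Forward ops are looked up in, and gradient ops are added to, the
      # current default graph because src_graph is left as None.
      dy_dx, = gradients_impl._GradientsHelper([y], [x])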
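
The mechanical change that lets the graph argument disappear from _PendingCount
is the same pattern throughout the diff: boolean and integer lists sized by
graph._last_id and indexed by the private op._id are replaced with sets and a
collections.defaultdict(int) keyed by the Operation objects themselves. A
self-contained sketch of that bookkeeping pattern (plain Python; FakeOp is only
an illustrative stand-in for tf.Operation, not TensorFlow code):

    import collections

    class FakeOp(object):
      """Toy stand-in for tf.Operation: hashable by identity, with upstream inputs."""

      def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = tuple(inputs)

    a = FakeOp("a")
    b = FakeOp("b")
    c = FakeOp("c", inputs=(a, b))

    # Old style needed the Graph to size the containers:
    #   between_ops = [False] * (graph._last_id + 1); pending_count = [0] * ...
    # New style keys plain containers by the op object itself.
    between_ops = {a, b, c}
    pending_count = collections.defaultdict(int)
    for op in (c,):                 # plays the role of between_op_list
      for inp in op.inputs:
        if inp in between_ops:      # membership test replaces between_ops[x.op._id]
          pending_count[inp] += 1   # defaultdict replaces the id-indexed int list

    assert pending_count[a] == 1 and pending_count[b] == 1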