From c9096fd166a9d7fdb62c6cb747a74edb73630b0c Mon Sep 17 00:00:00 2001
From: Eugene Brevdo
Date: Tue, 16 Jan 2018 16:17:36 -0800
Subject: [PATCH] [TF] Fix XLA Control Flow gradient stacks max_size creation.

Stack creation uses tf.while_loop's maximum_iterations iff the while_loop
was created inside an XLA/TPU context. Added several error checks so that
useful error messages are raised when this limited use case is not
supported.

A minimal usage sketch of the supported and rejected cases follows the
patch.

PiperOrigin-RevId: 182128135
---
 .../kernel_tests/control_flow_ops_py_test.py | 162 +++++++++++++++++-
 tensorflow/python/ops/control_flow_ops.py    | 120 ++++++++++---
 tensorflow/python/ops/control_flow_util.py   |   9 +-
 3 files changed, 253 insertions(+), 38 deletions(-)

diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 7f2c2545dc..6e18ed132c 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -747,18 +747,162 @@ class ControlFlowTest(test.TestCase):
           maximum_iterations=1)
       self.assertEqual(1, r.eval())
 
-  def testInvalidMaximumIterationsContext(self):
-    def outer_body(i, r):
-      r = control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + 1, [0],
-                                      maximum_iterations=r.shape[0])
-      return i, r
+  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
+    v = constant_op.constant(1.0)
+    def training_loop_with_gradient(i):
+      out = control_flow_ops.while_loop(
+          lambda i_, _: i_ < 3,
+          lambda i_, j: [i_ + 1, j * v],
+          [0, 1.0],
+          maximum_iterations=i)
+      g = gradients_impl.gradients(out, v)
+      with ops.control_dependencies(g):
+        return i + 1
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    # Create the training loop and ensure we can call gradients() on the
+    # while_loop inside the training loop.
+    loop = control_flow_ops.while_loop(
+        lambda i: i < 3, training_loop_with_gradient, [0])
+    xla_context.Exit()
+
+    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.
+
+    # Should execute without issue.
+    self.assertEqual(3, self.evaluate(loop_execute))
+
+  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
+    v = constant_op.constant(1.0)
+    def inner_body(i, x):
+      out = control_flow_ops.while_loop(
+          lambda i, _: i < 3,
+          lambda i, j: [i + 1, j * v],
+          [0, x],
+          maximum_iterations=i)
+      return out
+
+    def create_while_loop(maximum_iterations=None):
+      return control_flow_ops.while_loop(
+          lambda i, _: i < 3, inner_body, [0, 1.0],
+          maximum_iterations=maximum_iterations)
+
+    loop_no_xla = create_while_loop(maximum_iterations=5)
+    # maximum_iterations is fine outside of an XLA scope.
+    gs = gradients_impl.gradients(loop_no_xla, v)
+    self.evaluate(gs)  # This should execute without error.
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    loop_no_maxiter = create_while_loop()
+    loop_with_maxiter = create_while_loop(maximum_iterations=2)
+    xla_context.Exit()
 
     with self.assertRaisesRegexp(
         ValueError,
-        "maximum_iterations tensor cannot be declared in tf.cond or "
-        "tf.while_loop"):
-      control_flow_ops.while_loop(lambda i, r: i < 3, outer_body,
-                                  [0, constant_op.constant([1])])
+        r"Cannot create a gradient accumulator for tensor '.+' inside "
+        r"XLA while_loop because maximum_iterations was not passed to "
+        r"the tf.while_loop call \('.+'\)."):
+      _ = gradients_impl.gradients(loop_no_maxiter, v)
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
+        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
+        r"'.+' must be statically known \(e.g. a constant value or known "
+        r"shape dimension\), or be defined at or outside the while loop "
+        r"context '.*' \(currently defined in '.*'\)"):
+      _ = gradients_impl.gradients(loop_with_maxiter, v)
+
+  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
+    v = constant_op.constant(1.0)
+
+    def create_while_loop():
+      max_iter_holder = []
+      def create_mi():
+        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
+        return 1.0
+      _ = control_flow_ops.cond(constant_op.constant(True),
+                                create_mi, create_mi)
+
+      return control_flow_ops.while_loop(
+          lambda i, _: i < 3, lambda i, x: (i + 1, v * x), (0, 1.0),
+          maximum_iterations=max_iter_holder[0])
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    loop = create_while_loop()
+    xla_context.Exit()
+
+    with self.assertRaisesRegexp(
+        ValueError,
+        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
+        r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
+        r"while_loop context '.+' must be statically known \(e.g. a constant "
+        r"value or known shape dimension\), or be defined at or outside the "
+        r"while loop context '' \(currently defined in 'cond/.+'\)"):
+      _ = gradients_impl.gradients(loop, v)
+
+  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
+    v = constant_op.constant(1.0)
+
+    p = array_ops.placeholder(dtype=dtypes.int32)
+
+    def mid_body_builder(iterations):
+      def mid_body(i, x):
+        r = control_flow_ops.while_loop(
+            lambda *_: True,
+            lambda i, x: (i + 1, v * x),
+            (0, x),
+            maximum_iterations=iterations, name="inner")
+        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
+      return mid_body
+
+    def outer_body(i, x):
+      iterations = array_ops.size(p, name="iterations")
+      return (
+          i + 1,
+          x + control_flow_ops.while_loop(
+              lambda *_: True, mid_body_builder(iterations), (0, x),
+              maximum_iterations=iterations, name="mid")[1])
+
+    def create_while_loop():
+      with ops.device("/cpu:0"):
+        r = control_flow_ops.while_loop(
+            lambda *_: True, outer_body, (0, 1.0),
+            maximum_iterations=5, name="outer")
+        return array_ops.identity(r[1])
+
+    xla_context = control_flow_ops.XLAControlFlowContext()
+    xla_context.Enter()
+    final_with_xla_context = create_while_loop()
+    xla_context.Exit()
+
+    final_without_xla_context = create_while_loop()
+
+    with self.test_session(use_gpu=False) as sess:
+      opts = config_pb2.RunOptions(
+          trace_level=config_pb2.RunOptions.FULL_TRACE)
+      run_metadata = config_pb2.RunMetadata()
+
+      final_value_without_xla_context = sess.run(
+          final_without_xla_context,
+          feed_dict={p: [0, 0, 0]})
+
+      final_value_with_xla_context = sess.run(
+          final_with_xla_context,
+          feed_dict={p: [0, 0, 0]},
+          options=opts, run_metadata=run_metadata)
+
+      node_stats = run_metadata.step_stats.dev_stats[0].node_stats
+      stack_push_count = len(
+          [x for x in node_stats if x.node_name.endswith("StackPushV2")])
+      # Pushes to the stack = product of maximum_iterations values;
+      # the last two "3"s come from size(p), when p == [0, 0, 0].
+      self.assertEqual(stack_push_count, 5 * 3 * 3)
+
+      self.assertAllClose(
+          final_value_with_xla_context, final_value_without_xla_context)
 
   # Have more than 10 parallel iterations and hence exercise k-bound
   # most of the time.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 3fca3f522f..86941a7f2a 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -681,6 +681,78 @@ def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
   return v
 
 
+def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
+  """Calculate a max_size for use by stack ops inside an XLA while_loop.
+
+  Args:
+    value: The value inside the while_loop forward context. Used for printing
+      error messages.
+    while_ctxt: The forward context inside which value resides. This does
+      not always match the value's immediate context, as `value` may be
+      inside e.g. a cond context inside the while_loop.
+
+  Returns:
+    A tensor containing the `max_size` to feed to a Stack initializer.
+
+  Raises:
+    ValueError: If `value` is nested inside a `while_loop` that either
+      lacks a `maximum_iterations` parameter, or whose `maximum_iterations`
+      parameter:
+
+      - is inside a `while_loop` that is a parent of the calling context, and
+      - cannot be evaluated at graph build time to a constant.
+  """
+  value_name = value.name
+  # curr_ctxt is the context that tf.gradients was called in.
+  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
+
+  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
+  max_size = constant_op.constant(1)
+
+  # Loop through all containing while contexts between value and the
+  # current context, multiplying together each context's
+  # maximum_iterations to get the maximum stack size.
+  while while_ctxt not in (None, curr_ctxt):
+    max_iter = while_ctxt.maximum_iterations
+    if max_iter is None:
+      raise ValueError(
+          "Cannot create a gradient accumulator for tensor '%s' inside "
+          "XLA while_loop because maximum_iterations was not passed to "
+          "the tf.while_loop call ('%s')."
+          % (value_name, while_ctxt.name))
+
+    # pylint: disable=protected-access
+    max_iter_ctxt = max_iter.op._get_control_flow_context()
+    # pylint: enable=protected-access
+
+    # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
+    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
+      max_size *= max_iter
+    else:
+      # We cannot use max_iter because it's defined in a nested while
+      # or cond context, so it will fail if we try to use it as input to
+      # any ops in curr_ctxt (e.g. max_size or the final accumulator
+      # stack). Attempt to get a constant value out to use instead.
+      const_max_iter = tensor_util.constant_value(max_iter)
+      if const_max_iter is None:
+        raise ValueError(
+            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
+            "while_loop. maximum_iterations tensor '%s' for while_loop context "
+            "'%s' must be statically known (e.g. a constant value or known "
+            "shape dimension), or be defined at or outside the while loop "
+            "context '%s' (currently defined in '%s')." % (
+                value_name, max_iter.name, while_ctxt.name,
+                curr_ctxt_name, max_iter_ctxt.name))
+      max_size *= const_max_iter
+
+    # Find the next outer WhileContext (or stop if we reach the
+    # tf.gradients context).
+    while_ctxt = util.GetContainingWhileContext(
+        while_ctxt.outer_context, stop_ctxt=curr_ctxt)
+
+  return max_size
+
+
 class GradLoopState(object):
   """The state used for constructing the gradient graph for a while loop.
 
@@ -893,17 +965,24 @@ class GradLoopState(object):
 
     Raises:
       TypeError: For internal errors involving the value condition context.
+      ValueError: If `value` is inside an XLA scope and a valid max size
+        for the stack can't be found.
     """
-    curr_ctxt = ops.get_default_graph()._get_control_flow_context()
+    # curr_ctxt is the context that tf.gradients was called in.
+    curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
     with ops.control_dependencies(None):
       if curr_ctxt:
         curr_ctxt.Enter()
       with ops.colocate_with(value):
-        maximum_iterations = self.forward_context.maximum_iterations
-        if maximum_iterations is None:
-          maximum_iterations = constant_op.constant(-1, dtypes.int32)
+        # We only need to pass maximum_iterations to the stack if
+        # we're inside an XLA context.
+        if not util.IsInXLAContext(value.op):
+          max_size = constant_op.constant(-1, dtypes.int32)
+        else:
+          max_size = GetMaxSizeFromNestedMaximumIterations(
+              value, self.forward_context)
         # pylint: disable=protected-access
         acc = gen_data_flow_ops._stack_v2(
-            max_size=maximum_iterations,
+            max_size=max_size,
             elem_type=value.dtype.base_dtype,
             name="f_acc")
         # pylint: enable=protected-access
@@ -2902,27 +2981,6 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
       raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
                        maximum_iterations.shape)
 
-    # If/when we generated the gradient for this while loop, the
-    # maximum_iterations tensor will be used as the input to any generated
-    # stack ops. It's likely the stacks will be outside any control flow
-    # context (i.e. if gradients() is called outside any control flow
-    # context), which will result in the maximum_iterations tensor being an
-    # illegal input (see control_flow_util.CheckInputFromValidContext).
-    #
-    # NOTE(skyewm): we could technically allow tensors from CondContexts, but
-    # that will be error-prone and hard to reason about for users.
-    #
-    # TODO(skyewm): make this work (it's tricky).
-    if (context.in_graph_mode() and
-        (util.IsInWhileLoop(maximum_iterations.op) or
-         util.IsInCond(maximum_iterations.op))):
-      raise ValueError(
-          "maximum_iterations tensor cannot be declared in tf.cond or "
-          "tf.while_loop. Please file an issue at "
-          "https://github.com/tensorflow/tensorflow/issues if you require "
-          "this functionality. (Control flow context: %s)" %
-          maximum_iterations.op._get_control_flow_context().name)  # pylint: disable=protected-access
-
     counter = constant_op.constant(
         0, dtype=maximum_iterations.dtype, name="iteration_counter")
     orig_cond = cond
@@ -3384,9 +3442,19 @@ def case(pred_fn_pairs,
 class XLAControlFlowContext(ControlFlowContext):
   """Base class for XLA and TPU control flow contexts."""
 
+  def __init__(self):
+    super(XLAControlFlowContext, self).__init__()
+    self._name = "XLAControlFlowContext"
+
   def IsXLAContext(self):
     return True
 
+  def AddOp(self, _):
+    pass
+
+  def AddValue(self, x):
+    return x
+
 
 ops.register_proto_function(ops.GraphKeys.COND_CONTEXT,
                             proto_type=control_flow_pb2.CondContextDef,
diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py
index 247c9f7299..eee31102db 100644
--- a/tensorflow/python/ops/control_flow_util.py
+++ b/tensorflow/python/ops/control_flow_util.py
@@ -96,7 +96,7 @@ def GetOutputContext(op):
   return ctxt
 
 
-def GetContainingWhileContext(ctxt):
+def GetContainingWhileContext(ctxt, stop_ctxt=None):
   """Returns the first ancestor WhileContext of `ctxt`.
 
   Returns `ctxt` if `ctxt` is a WhileContext, or None if `ctxt` is not in a
@@ -104,13 +104,16 @@ def GetContainingWhileContext(ctxt):
 
   Args:
     ctxt: ControlFlowContext
+    stop_ctxt: ControlFlowContext, optional. If provided, the traversal stops
+      when it reaches `stop_ctxt`.
 
   Returns:
    `ctxt` if `ctxt` is a WhileContext, the most nested WhileContext containing
-    `ctxt`, or None if `ctxt` is not in a while loop.
+    `ctxt`, or None if `ctxt` is not in a while loop. If `stop_ctxt` is not
+    `None` and is encountered before a WhileContext, `stop_ctxt` is returned.
   """
   while ctxt:
-    if ctxt.IsWhileContext(): return ctxt
+    if ctxt.IsWhileContext() or ctxt == stop_ctxt: return ctxt
     ctxt = ctxt.outer_context
   return None
-- 
2.34.1
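
Usage sketch of the supported case. This is a minimal sketch, assuming TF 1.x
graph mode at this revision and entering XLAControlFlowContext directly (the
way the unit tests above do) in place of a real XLA/TPU compilation scope.
Because maximum_iterations is statically known here, the gradient accumulator
stacks are created with max_size=10 rather than the unbounded -1 used outside
XLA contexts; for nested loops, GetMaxSizeFromNestedMaximumIterations
multiplies the nested maximum_iterations values together (hence the 5 * 3 * 3
StackPushV2 count asserted in the nested test).

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops

    v = tf.constant(2.0)

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    # maximum_iterations is a constant, so the gradient stacks for this
    # loop get max_size=10.
    _, out = tf.while_loop(
        lambda i, _: i < 10,
        lambda i, x: (i + 1, x * v),
        (0, 1.0),
        maximum_iterations=10)
    grad = tf.gradients(out, v)
    xla_context.Exit()

    # Ops created inside the XLA context are not directly fetchable;
    # mirror the tests and fetch through an identity created outside it.
    grad_value = tf.identity(grad[0])

    with tf.Session() as sess:
      print(sess.run(grad_value))  # d(v**10)/dv = 10 * v**9 = 5120.0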
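
Conversely, a sketch of the rejected case under the same assumptions (the toy
loop is hypothetical; the quoted message is abbreviated from the error text
added in this patch). Omitting maximum_iterations from a while_loop built
inside an XLA context leaves the gradient stacks unbounded, so taking
gradients raises a descriptive ValueError up front instead of failing later
when XLA cannot compile an unbounded stack.

    import tensorflow as tf
    from tensorflow.python.ops import control_flow_ops

    v = tf.constant(1.0)

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    _, out = tf.while_loop(
        lambda i, _: i < 3,
        lambda i, x: (i + 1, x * v),
        (0, 1.0))  # Note: no maximum_iterations.
    xla_context.Exit()

    try:
      tf.gradients(out, v)
    except ValueError as e:
      # "Cannot create a gradient accumulator for tensor '...' inside
      # XLA while_loop because maximum_iterations was not passed to the
      # tf.while_loop call ('...')."
      print(e)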