From 38d1ac1e4f5b2a6e88eee43d332292898e0afc41 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 6 Apr 2018 17:31:43 -0700 Subject: [PATCH] Initial Python API for specifying outside_compilation blocks that call out from a TPU computation. For now outside_compilation cannot occur inside any compiled control flow (while loop or conditional). If the computation is replicated, the outside_compilation ops are also replicated. Both of these restrictions will be lifted in followup CLs. PiperOrigin-RevId: 191963758 --- .../compiler/tf2xla/functionalize_control_flow.cc | 8 + tensorflow/contrib/tpu/python/tpu/tpu.py | 211 ++++++++++++++++++++- tensorflow/contrib/tpu/python/tpu/tpu_test.py | 2 +- tensorflow/python/eager/function.py | 10 + tensorflow/python/framework/ops.py | 24 ++- tensorflow/python/ops/control_flow_ops.py | 10 + tensorflow/python/ops/gradients_impl.py | 48 +++-- 7 files changed, 292 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 8b7beef..16b9142 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -901,6 +901,14 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() { int src_depth = switch_depth[src_id]; if (!e->IsControlEdge() || new_switch_depth == src_depth) { if (src_depth != new_switch_depth) { + // TODO(b/77601805) remove this when outside_compilation supports + // control flow. + if (str_util::StrContains(src->name(), "outside_compilation") || + str_util::StrContains(n->name(), "outside_compilation")) { + return errors::InvalidArgument( + "outside_compilation is not yet supported within TensorFlow " + "control flow constructs b/77601805"); + } return errors::InvalidArgument( "Unable to functionalize control flow in graph: Operand ('", src->name(), "') and operator ('", n->name(), diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index 3f2db54..a1690da 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -25,6 +25,8 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -56,6 +58,7 @@ _NOT_IMPLEMENTED_OPS = set([ _MAX_WARNING_LINES = 5 _TPU_REPLICATE_ATTR = "_tpu_replicate" +_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation" def _tpu_system_device_name(job): @@ -121,8 +124,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): outside the replicated computation. """ - def __init__(self, name): + def __init__(self, name, num_replicas): super(TPUReplicateContext, self).__init__() + self._num_replicas = num_replicas + self._outer_device_function_stack = None + self._oc_dev_fn_stack = None + self._outside_compilation_cluster = None + self._outside_compilation_counter = 0 + self._in_gradient_colocation = None + self._gradient_colocation_stack = [] + self._host_compute_core = [] self._name = name self._unsupported_ops = [] @@ -136,6 +147,143 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): logging.warning("... 
and %d more" % (len(self._unsupported_ops) - _MAX_WARNING_LINES)) + def EnterGradientColocation(self, op, gradient_uid): + if op is not None: + self._gradient_colocation_stack.append(op) + if not self._outside_compilation_cluster: + try: + outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR) + if self._in_gradient_colocation: + raise NotImplementedError( + "Cannot nest gradient colocation operations outside compilation" + ) + if gradient_uid == "__unsupported__": + raise NotImplementedError( + "No gradient_uid calling gradient within outside_compilation") + # When we take the gradient of an op X in an + # outside_compilation cluster C in a forward computation we + # would like to put the ops corresponding to the gradient of + # X into a new outside_compilation cluster C'. However, if + # we take the gradient of X twice, the second one should get + # yet another new outside_compilation cluster C''. + # + # The mechanism we adopt is to use a 'root_cluster' which is + # the cluster that X was in before we took gradients, and a + # 'gradient_uid' which is different for every invocation of + # gradients, and put the gradient of X in cluster + # 'root_cluster.gradient_uid'. + # + # When the gradient code adds multiple Ops, it asks them to + # be colocated either with the original Op X, or with one of + # the preceding Ops that was added to the gradient. In other + # words, we want to detect the case where we are colocating + # with an Op that is in cluster root_cluster.gradient_uid + # and put the new Op in that same cluster if the + # gradient_uid is the same (the case that we are in the same + # invocation of gradients, and just adding new Ops to the + # cluster); and in a different cluster if the gradient_uids + # are different (the case that we are in a new invocation of + # gradients, taking the gradient of a previously-computed + # gradient). + self._in_gradient_colocation = op + parts = outside_attr.split(".") + if len(parts) > 1: + uid = parts[-1] + if uid == gradient_uid: + # Keep using the same cluster + cluster = outside_attr + else: + # We're taking the gradient of a gradient so make a new + # cluster attr, adding a new '.uid' on the end to + # preserve the invariant that the gradient_uid is the + # suffix after the last '.' in the attr. + cluster = outside_attr + "." + gradient_uid + else: + # We're taking the gradient of an Op in the forward pass, so + # make a new cluster combining the Op's cluster and the + # gradient id. + cluster = outside_attr + "." + gradient_uid + self._EnterOutsideCompilationScope(cluster=cluster) + except ValueError: + # The attr was not present: do nothing. + pass + + def ExitGradientColocation(self, op, gradient_uid): + if op is not None: + if not self._gradient_colocation_stack: + raise errors.InternalError( + op.node_def, op, + "Badly nested gradient colocation: empty stack when popping Op " + + op.name) + last_op = self._gradient_colocation_stack.pop() + if op is last_op: + if op is self._in_gradient_colocation: + self._in_gradient_colocation = None + self._ExitOutsideCompilationScope() + else: + raise errors.InternalError( + op.node_def, op, "Badly nested gradient colocation, expected " + + last_op + ", got " + op.name) + + def _EnterOutsideCompilationScope(self, cluster=None): + + class FakeOp(object): + """A helper class to determine the current device. + + Supports only the device set/get methods needed to run the + graph's _apply_device_function method. 
+ """ + + def __init__(self): + self._device = "" + + @property + def device(self): + return self._device + + def _set_device(self, device): + self._device = device.to_string() + + if self._outside_compilation_cluster: + raise NotImplementedError("Cannot nest outside_compilation clusters") + if cluster: + self._outside_compilation_cluster = cluster + else: + self._outside_compilation_cluster = str(self._outside_compilation_counter) + self._outside_compilation_counter += 1 + graph = ops.get_default_graph() + fake_op = FakeOp() + graph._apply_device_functions(fake_op) # pylint: disable=protected-access + device = pydev.DeviceSpec.from_string(fake_op.device) + if (device.device_type == "TPU_REPLICATED_CORE" and + device.device_index is not None): + self._host_compute_core.append(self._outside_compilation_cluster + ":" + + str(device.device_index)) + self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access + graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access + + def _ExitOutsideCompilationScope(self): + if not self._outside_compilation_cluster: + raise NotImplementedError( + "Attempted to exit outside_compilation scope when not in scope") + self._outside_compilation_cluster = None + graph = ops.get_default_graph() + graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access + + def Enter(self): + if not self._outer_device_function_stack: + # Capture the device function stack at the time of first entry + # since that is the stack that will be used outside_compilation. + graph = ops.get_default_graph() + self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access + super(TPUReplicateContext, self).Enter() + + def Exit(self): + super(TPUReplicateContext, self).Exit() + + def HostComputeCore(self): + return self._host_compute_core + def AddOp(self, op): self._AddOpInternal(op) @@ -157,9 +305,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): raise ValueError("TPU computations cannot be nested") op._set_attr(_TPU_REPLICATE_ATTR, attr_value_pb2.AttrValue(s=compat.as_bytes(self._name))) - # pylint: enable=protected-access - op.graph.prevent_feeding(op) - op.graph.prevent_fetching(op) + if self._outside_compilation_cluster: + op._set_attr( + _OUTSIDE_COMPILATION_ATTR, + attr_value_pb2.AttrValue( + s=compat.as_bytes(self._outside_compilation_cluster))) + if self._num_replicas > 1 or not self._outside_compilation_cluster: + # Prevent feeding or fetching anything that is being compiled, + # and any replicated outside_compilation Op. + op.graph.prevent_feeding(op) + op.graph.prevent_fetching(op) def AddValue(self, val): result = val @@ -181,6 +336,45 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext): return None +def outside_compilation(computation, args=None): + """Builds part of a computation outside any current TPU replicate scope. + + Args: + computation: A Python function that builds the computation to + place on the host. + args: Inputs to pass to computation. + Returns: + The Tensors returned by computation. 
+ """ + graph = ops.get_default_graph() + + # If we are in a TPUReplicateContext, signal that we are now + # outside_compilation + initial_context = graph._get_control_flow_context() # pylint: disable=protected-access + context = initial_context + while context: + if isinstance(context, TPUReplicateContext): + context._EnterOutsideCompilationScope() # pylint: disable=protected-access + context = context.outer_context + + retval = computation(*args) + + # If we are in a TPUReplicateContext, signal that we are no longer + # outside_compilation + final_context = graph._get_control_flow_context() # pylint: disable=protected-access + if initial_context is not final_context: + raise NotImplementedError( + "Control-flow context cannot be different at start and end of an " + "outside_compilation scope") + context = initial_context + while context: + if isinstance(context, TPUReplicateContext): + context._ExitOutsideCompilationScope() # pylint: disable=protected-access + context = context.outer_context + + return retval + + def replicate(computation, inputs=None, infeed_queue=None, @@ -280,7 +474,8 @@ def replicate(computation, computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - context = TPUReplicateContext(name=graph.unique_name("cluster")) + context = TPUReplicateContext( + name=graph.unique_name("cluster"), num_replicas=num_replicas) try: context.Enter() @@ -361,6 +556,12 @@ def replicate(computation, finally: context.report_unsupported_operations() context.Exit() + host_compute_core = context.HostComputeCore() + + if host_compute_core: + attr_value = attr_value_pb2.AttrValue() + attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core]) + metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access # Fan-out: Builds a TPUReplicatedOutput node for each output. 
outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_test.py index 336d826..c3882b8 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_test.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_test.py @@ -37,7 +37,7 @@ class TPUContextTest(test.TestCase): def testIsInContext(self): """Test that control_flow_util can check that we're in a TPU context.""" z1 = array_ops.identity(1) - context = tpu.TPUReplicateContext(b"context") + context = tpu.TPUReplicateContext(b"context", 1) context.Enter() z2 = array_ops.identity(1) context.Exit() diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 61859d6..5168ad3 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -223,6 +223,16 @@ class HelperContext(object): else: return val + def EnterGradientColocation(self, op, gradient_uid): + """Start building a gradient colocated with an op.""" + if self._outer_context: + self._outer_context.EnterGradientColocation(op, gradient_uid) + + def ExitGradientColocation(self, op, gradient_uid): + """Start building a gradient colocated with an op.""" + if self._outer_context: + self._outer_context.ExitGradientColocation(op, gradient_uid) + def __enter__(self): # pylint: disable=protected-access self._g = ops.get_default_graph() diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 2574fa5..e3ca5a4 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -4180,6 +4180,19 @@ class Graph(object): return self._name_stack @tf_contextlib.contextmanager + def _colocate_with_for_gradient(self, op, gradient_uid, + ignore_existing=False): + with self.colocate_with(op, ignore_existing): + if gradient_uid is not None and self._control_flow_context is not None: + try: + self._control_flow_context.EnterGradientColocation(op, gradient_uid) + yield + finally: + self._control_flow_context.ExitGradientColocation(op, gradient_uid) + else: + yield + + @tf_contextlib.contextmanager def colocate_with(self, op, ignore_existing=False): """Returns a context manager that specifies an op to colocate with. 
@@ -4958,8 +4971,7 @@ def container(container_name): return get_default_graph().container(container_name) -@tf_export("colocate_with") -def colocate_with(op, ignore_existing=False): +def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False): if context.executing_eagerly(): if op is not None: return device(op.device) @@ -4973,7 +4985,13 @@ def colocate_with(op, ignore_existing=False): else: raise ValueError("Encountered an Eager-defined Tensor during graph " "construction, but a function was not being built.") - return default_graph.colocate_with(op, ignore_existing) + return default_graph._colocate_with_for_gradient( + op, gradient_uid=gradient_uid, ignore_existing=ignore_existing) + + +@tf_export("colocate_with") +def colocate_with(op, ignore_existing=False): + return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing) @tf_export("control_dependencies") diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index e56ab93..7be8628 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1595,6 +1595,16 @@ class ControlFlowContext(object): last_context = self._context_stack.pop() graph._set_control_flow_context(last_context) + def EnterGradientColocation(self, op, gradient_uid): + """Start building a gradient colocated with an op.""" + if self._outer_context: + self._outer_context.EnterGradientColocation(op, gradient_uid) + + def ExitGradientColocation(self, op, gradient_uid): + """Start building a gradient colocated with an op.""" + if self._outer_context: + self._outer_context.ExitGradientColocation(op, gradient_uid) + def ExitResult(self, result): """Make a list of tensors available in the outer context.""" if self._outer_context: diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 44473ec..13420b7 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -208,7 +208,10 @@ def _AsList(x): return x if isinstance(x, (list, tuple)) else [x] -def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops): +def _DefaultGradYs(grad_ys, + ys, + colocate_gradients_with_ops, + gradient_uid="__unsupported__"): """Fill in default values for grad_ys. Args: @@ -216,6 +219,9 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops): ys: List of tensors. colocate_gradients_with_ops: If True, try colocating gradients with the corresponding op. + gradient_uid: A unique identifier within the graph indicating + which invocation of gradients is being executed. Used to cluster + ops for compilation. Returns: A list of gradients to use, without None. 
@@ -231,7 +237,7 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops): for i in xrange(len(grad_ys)): grad_y = grad_ys[i] y = ys[i] - with _maybe_colocate_with(y.op, colocate_gradients_with_ops): + with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops): if grad_y is None: if y.dtype.is_complex: raise TypeError( @@ -338,10 +344,10 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count): @contextlib.contextmanager -def _maybe_colocate_with(op, colocate_gradients_with_ops): +def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name """Context to colocate with `op` if `colocate_gradients_with_ops`.""" if colocate_gradients_with_ops: - with ops.colocate_with(op): + with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access yield else: yield @@ -506,6 +512,9 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, with ops.name_scope( name, "gradients", list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope: + # Get a uid for this call to gradients that can be used to help + # cluster ops for compilation. + gradient_uid = ops.get_default_graph().unique_name("uid") ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y") xs = [ x.handle if resource_variable_ops.is_resource_variable(x) else x @@ -513,7 +522,8 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, ] xs = ops.internal_convert_n_to_tensor_or_indexed_slices( xs, name="x", as_ref=True) - grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops) + grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, + gradient_uid) # The approach we take here is as follows: Create a list of all ops in the # subgraph between the ys and xs. Visit these ops in reverse order of ids @@ -570,10 +580,11 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, while queue: # generate gradient subgraph for op. op = queue.popleft() - with _maybe_colocate_with(op, colocate_gradients_with_ops): + with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): if loop_state: loop_state.EnterGradWhileContext(op, before=True) - out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method) + out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state, + aggregation_method) if loop_state: loop_state.ExitGradWhileContext(op, before=True) @@ -633,7 +644,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, if gate_gradients and len([x for x in in_grads if x is not None]) > 1: with ops.device(None): - with ops.colocate_with(None, ignore_existing=True): + with ops._colocate_with_for_gradient( # pylint: disable=protected-access + None, + gradient_uid, + ignore_existing=True): in_grads = control_flow_ops.tuple(in_grads) _LogOpGradients(op, out_grads, in_grads) else: @@ -789,7 +803,7 @@ def _LogOpGradients(op, out_grads, in_grads): ", ".join([x.name for x in in_grads if _FilterGrad(x)])) -def _MultiDeviceAddN(tensor_list): +def _MultiDeviceAddN(tensor_list, gradient_uid): """Adds tensors from potentially multiple devices.""" # Basic function structure comes from control_flow_ops.group(). # Sort tensors according to their devices. 
@@ -808,7 +822,10 @@ def _MultiDeviceAddN(tensor_list): for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey): tensors = tensors_on_device[dev] - with ops.colocate_with(tensors[0].op, ignore_existing=True): + with ops._colocate_with_for_gradient( # pylint: disable=protected-access + tensors[0].op, + gradient_uid, + ignore_existing=True): summands.append(math_ops.add_n(tensors)) return math_ops.add_n(summands) @@ -834,12 +851,19 @@ class AggregationMethod(object): EXPERIMENTAL_ACCUMULATE_N = 2 -def _AggregatedGrads(grads, op, loop_state, aggregation_method=None): +def _AggregatedGrads(grads, + op, + gradient_uid, + loop_state, + aggregation_method=None): """Get the aggregated gradients for op. Args: grads: The map of memoized gradients. op: The op to get gradients for. + gradient_uid: A unique identifier within the graph indicating + which invocation of gradients is being executed. Used to cluster + ops for compilation. loop_state: An object for maintaining the state of the while loops in the graph. It is of type ControlFlowState. None if the graph contains no while loops. @@ -916,7 +940,7 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None): out_grads[i] = running_sum else: used = "add_n" - out_grads[i] = _MultiDeviceAddN(out_grad) + out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid) logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad), tensor_shape, used) else: -- 2.7.4
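
Note (not part of the patch): a minimal usage sketch of the new `tpu.outside_compilation` API added in tpu.py above. Only `tpu.replicate` and `tpu.outside_compilation` come from this module; the model function, shapes, and names below are illustrative assumptions.

    # Illustrative sketch only -- the model, shapes, and names are assumptions.
    import tensorflow as tf
    from tensorflow.contrib.tpu.python.tpu import tpu

    def host_print(loss):
      # Built outside the TPU cluster: these ops get the
      # _xla_outside_compilation attribute instead of being compiled.
      # With num_replicas > 1 they are replicated, and (per the commit
      # message) they may not yet appear inside compiled control flow.
      return tf.Print(loss, [loss], message="loss on host: ")

    def tpu_computation(x):
      loss = tf.reduce_mean(x * x)
      # Call out of the replicated computation for the host-only op.
      loss = tpu.outside_compilation(host_print, args=[loss])
      return loss

    # One inner list of inputs per replica; two replicas here.
    outputs = tpu.replicate(tpu_computation,
                            inputs=[[tf.ones([2, 4])], [tf.ones([2, 4])]])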
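
Also not part of the patch: a plain-Python restatement of the cluster-naming rule documented in TPUReplicateContext.EnterGradientColocation, showing how a forward cluster (e.g. "0") combines with the per-call gradient_uid; the uid values are hypothetical.

    # Restates the naming rule from EnterGradientColocation for illustration.
    def derive_gradient_cluster(outside_attr, gradient_uid):
      parts = outside_attr.split(".")
      if len(parts) > 1 and parts[-1] == gradient_uid:
        # Same invocation of gradients(): keep adding ops to the same cluster.
        return outside_attr
      # Forward-pass op, or a new invocation (gradient of a gradient):
      # open a new cluster whose suffix is the current gradient_uid.
      return outside_attr + "." + gradient_uid

    assert derive_gradient_cluster("0", "uid") == "0.uid"
    assert derive_gradient_cluster("0.uid", "uid") == "0.uid"
    assert derive_gradient_cluster("0.uid", "uid_1") == "0.uid.uid_1"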