Initial Python API for specifying outside_compilation blocks that call out from a...
author	A. Unique TensorFlower <gardener@tensorflow.org>
Mon, 9 Apr 2018 16:11:01 +0000 (09:11 -0700)
committer	TensorFlower Gardener <gardener@tensorflow.org>
Mon, 9 Apr 2018 16:15:35 +0000 (09:15 -0700)
For now, outside_compilation cannot occur inside any compiled control flow (while loop or conditional). If the computation is replicated, the outside_compilation ops are also replicated. Both of these restrictions will be lifted in follow-up CLs.
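A minimal usage sketch of the API this change adds (illustrative only: the host_log and computation functions below are hypothetical, and the import path is the contrib location touched by this CL):

  import tensorflow as tf
  from tensorflow.contrib.tpu.python.tpu import tpu

  def host_log(x):
    # Built through outside_compilation, so it runs on the host rather
    # than being compiled into the TPU program.
    return tf.Print(x, [x], "value on host: ")

  def computation(x):
    y = x * 2.0
    # Per this change, this call must not appear inside a compiled while
    # loop or conditional; if the computation is replicated, the host ops
    # are replicated as well.
    y = tpu.outside_compilation(host_log, args=(y,))
    return y + 1.0

  # One replica, with a single per-replica input.
  outputs = tpu.replicate(computation, inputs=[[tf.constant(3.0)]])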

PiperOrigin-RevId: 192135901

tensorflow/compiler/tf2xla/functionalize_control_flow.cc
tensorflow/contrib/tpu/python/tpu/tpu.py
tensorflow/contrib/tpu/python/tpu/tpu_test.py
tensorflow/python/eager/function.py
tensorflow/python/framework/ops.py
tensorflow/python/ops/control_flow_ops.py
tensorflow/python/ops/gradients_impl.py

index 8b7beef..16b9142 100644
@@ -901,6 +901,14 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() {
       int src_depth = switch_depth[src_id];
       if (!e->IsControlEdge() || new_switch_depth == src_depth) {
         if (src_depth != new_switch_depth) {
+          // TODO(b/77601805) remove this when outside_compilation supports
+          // control flow.
+          if (str_util::StrContains(src->name(), "outside_compilation") ||
+              str_util::StrContains(n->name(), "outside_compilation")) {
+            return errors::InvalidArgument(
+                "outside_compilation is not yet supported within TensorFlow "
+                "control flow constructs (b/77601805)");
+          }
           return errors::InvalidArgument(
               "Unable to functionalize control flow in graph: Operand ('",
               src->name(), "') and operator ('", n->name(),
index 3f2db54..a1690da 100644
@@ -25,6 +25,8 @@ from tensorflow.contrib.tpu.python.ops import tpu_ops
 from tensorflow.contrib.tpu.python.tpu import tpu_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -56,6 +58,7 @@ _NOT_IMPLEMENTED_OPS = set([
 _MAX_WARNING_LINES = 5
 
 _TPU_REPLICATE_ATTR = "_tpu_replicate"
+_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
 
 
 def _tpu_system_device_name(job):
@@ -121,8 +124,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
   outside the replicated computation.
   """
 
-  def __init__(self, name):
+  def __init__(self, name, num_replicas):
     super(TPUReplicateContext, self).__init__()
+    self._num_replicas = num_replicas
+    self._outer_device_function_stack = None
+    self._oc_dev_fn_stack = None
+    self._outside_compilation_cluster = None
+    self._outside_compilation_counter = 0
+    self._in_gradient_colocation = None
+    self._gradient_colocation_stack = []
+    self._host_compute_core = []
     self._name = name
     self._unsupported_ops = []
 
@@ -136,6 +147,143 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
         logging.warning("... and %d more" %
                         (len(self._unsupported_ops) - _MAX_WARNING_LINES))
 
+  def EnterGradientColocation(self, op, gradient_uid):
+    if op is not None:
+      self._gradient_colocation_stack.append(op)
+      if not self._outside_compilation_cluster:
+        try:
+          outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR)
+          if self._in_gradient_colocation:
+            raise NotImplementedError(
+                "Cannot nest gradient colocation operations outside compilation"
+            )
+          if gradient_uid == "__unsupported__":
+            raise NotImplementedError(
+                "No gradient_uid calling gradient within outside_compilation")
+          # When we take the gradient of an op X in an
+          # outside_compilation cluster C in a forward computation we
+          # would like to put the ops corresponding to the gradient of
+          # X into a new outside_compilation cluster C'. However, if
+          # we take the gradient of X twice, the second one should get
+          # yet another new outside_compilation cluster C''.
+          #
+          # The mechanism we adopt is to use a 'root_cluster' which is
+          # the cluster that X was in before we took gradients, and a
+          # 'gradient_uid' which is different for every invocation of
+          # gradients, and put the gradient of X in cluster
+          # 'root_cluster.gradient_uid'.
+          #
+          # When the gradient code adds multiple Ops, it asks them to
+          # be colocated either with the original Op X, or with one of
+          # the preceding Ops that was added to the gradient. In other
+          # words, we want to detect the case where we are colocating
+          # with an Op that is in cluster root_cluster.gradient_uid
+          # and put the new Op in that same cluster if the
+          # gradient_uid is the same (the case that we are in the same
+          # invocation of gradients, and just adding new Ops to the
+          # cluster); and in a different cluster if the gradient_uids
+          # are different (the case that we are in a new invocation of
+          # gradients, taking the gradient of a previously-computed
+          # gradient).
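+          #
+          # For example (cluster names and uids below are illustrative):
+          # an Op X in forward cluster "0" whose gradient is built with
+          # gradient_uid "g1" is placed in cluster "0.g1"; taking the
+          # gradient of that gradient with a different gradient_uid "g2"
+          # places those new Ops in cluster "0.g1.g2".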
+          self._in_gradient_colocation = op
+          parts = outside_attr.split(".")
+          if len(parts) > 1:
+            uid = parts[-1]
+            if uid == gradient_uid:
+              # Keep using the same cluster
+              cluster = outside_attr
+            else:
+              # We're taking the gradient of a gradient so make a new
+              # cluster attr, adding a new '.uid' on the end to
+              # preserve the invariant that the gradient_uid is the
+              # suffix after the last '.' in the attr.
+              cluster = outside_attr + "." + gradient_uid
+          else:
+            # We're taking the gradient of an Op in the forward pass, so
+            # make a new cluster combining the Op's cluster and the
+            # gradient id.
+            cluster = outside_attr + "." + gradient_uid
+          self._EnterOutsideCompilationScope(cluster=cluster)
+        except ValueError:
+          # The attr was not present: do nothing.
+          pass
+
+  def ExitGradientColocation(self, op, gradient_uid):
+    if op is not None:
+      if not self._gradient_colocation_stack:
+        raise errors.InternalError(
+            op.node_def, op,
+            "Badly nested gradient colocation: empty stack when popping Op " +
+            op.name)
+      last_op = self._gradient_colocation_stack.pop()
+      if op is last_op:
+        if op is self._in_gradient_colocation:
+          self._in_gradient_colocation = None
+          self._ExitOutsideCompilationScope()
+      else:
+        raise errors.InternalError(
+            op.node_def, op, "Badly nested gradient colocation, expected " +
+            last_op.name + ", got " + op.name)
+
+  def _EnterOutsideCompilationScope(self, cluster=None):
+
+    class FakeOp(object):
+      """A helper class to determine the current device.
+
+      Supports only the device set/get methods needed to run the
+      graph's _apply_device_functions method.
+      """
+
+      def __init__(self):
+        self._device = ""
+
+      @property
+      def device(self):
+        return self._device
+
+      def _set_device(self, device):
+        self._device = device.to_string()
+
+    if self._outside_compilation_cluster:
+      raise NotImplementedError("Cannot nest outside_compilation clusters")
+    if cluster:
+      self._outside_compilation_cluster = cluster
+    else:
+      self._outside_compilation_cluster = str(self._outside_compilation_counter)
+      self._outside_compilation_counter += 1
+    graph = ops.get_default_graph()
+    fake_op = FakeOp()
+    graph._apply_device_functions(fake_op)  # pylint: disable=protected-access
+    device = pydev.DeviceSpec.from_string(fake_op.device)
+    if (device.device_type == "TPU_REPLICATED_CORE" and
+        device.device_index is not None):
+      self._host_compute_core.append(self._outside_compilation_cluster + ":" +
+                                     str(device.device_index))
+    self._oc_dev_fn_stack = graph._device_function_stack  # pylint: disable=protected-access
+    graph._device_function_stack = self._outer_device_function_stack  # pylint: disable=protected-access
+
+  def _ExitOutsideCompilationScope(self):
+    if not self._outside_compilation_cluster:
+      raise NotImplementedError(
+          "Attempted to exit outside_compilation scope when not in scope")
+    self._outside_compilation_cluster = None
+    graph = ops.get_default_graph()
+    graph._device_function_stack = self._oc_dev_fn_stack  # pylint: disable=protected-access
+
+  def Enter(self):
+    if not self._outer_device_function_stack:
+      # Capture the device function stack at the time of first entry
+      # since that is the stack that will be used by outside_compilation.
+      graph = ops.get_default_graph()
+      self._outer_device_function_stack = list(graph._device_function_stack)  # pylint: disable=protected-access
+    super(TPUReplicateContext, self).Enter()
+
+  def Exit(self):
+    super(TPUReplicateContext, self).Exit()
+
+  def HostComputeCore(self):
+    return self._host_compute_core
+
   def AddOp(self, op):
     self._AddOpInternal(op)
 
@@ -157,9 +305,16 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
       raise ValueError("TPU computations cannot be nested")
     op._set_attr(_TPU_REPLICATE_ATTR,
                  attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
-    # pylint: enable=protected-access
-    op.graph.prevent_feeding(op)
-    op.graph.prevent_fetching(op)
+    if self._outside_compilation_cluster:
+      op._set_attr(
+          _OUTSIDE_COMPILATION_ATTR,
+          attr_value_pb2.AttrValue(
+              s=compat.as_bytes(self._outside_compilation_cluster)))
+    if self._num_replicas > 1 or not self._outside_compilation_cluster:
+      # Prevent feeding or fetching anything that is being compiled,
+      # and any replicated outside_compilation Op.
+      op.graph.prevent_feeding(op)
+      op.graph.prevent_fetching(op)
 
   def AddValue(self, val):
     result = val
@@ -181,6 +336,45 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
     return None
 
 
+def outside_compilation(computation, args=None):
+  """Builds part of a computation outside any current TPU replicate scope.
+
+  Args:
+    computation: A Python function that builds the computation to
+      place on the host.
+    args: Inputs to pass to computation.
+  Returns:
+    The Tensors returned by computation.
+  """
+  graph = ops.get_default_graph()
+
+  # If we are in a TPUReplicateContext, signal that we are now
+  # outside_compilation
+  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
+  context = initial_context
+  while context:
+    if isinstance(context, TPUReplicateContext):
+      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
+    context = context.outer_context
+
+  # Treat args=None as calling computation with no arguments.
+  args = [] if args is None else args
+  retval = computation(*args)
+
+  # If we are in a TPUReplicateContext, signal that we are no longer
+  # outside_compilation
+  final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
+  if initial_context is not final_context:
+    raise NotImplementedError(
+        "Control-flow context cannot be different at start and end of an "
+        "outside_compilation scope")
+  context = initial_context
+  while context:
+    if isinstance(context, TPUReplicateContext):
+      context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
+    context = context.outer_context
+
+  return retval
+
+
 def replicate(computation,
               inputs=None,
               infeed_queue=None,
@@ -280,7 +474,8 @@ def replicate(computation,
     computation_inputs.append(
         tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
 
-  context = TPUReplicateContext(name=graph.unique_name("cluster"))
+  context = TPUReplicateContext(
+      name=graph.unique_name("cluster"), num_replicas=num_replicas)
   try:
     context.Enter()
 
@@ -361,6 +556,12 @@ def replicate(computation,
   finally:
     context.report_unsupported_operations()
     context.Exit()
+    host_compute_core = context.HostComputeCore()
+
+  if host_compute_core:
+    attr_value = attr_value_pb2.AttrValue()
+    attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
+    metadata._set_attr("host_compute_core", attr_value)  # pylint: disable=protected-access
 
   # Fan-out: Builds a TPUReplicatedOutput node for each output.
   outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
index 336d826..c3882b8 100644
@@ -37,7 +37,7 @@ class TPUContextTest(test.TestCase):
   def testIsInContext(self):
     """Test that control_flow_util can check that we're in a TPU context."""
     z1 = array_ops.identity(1)
-    context = tpu.TPUReplicateContext(b"context")
+    context = tpu.TPUReplicateContext(b"context", 1)
     context.Enter()
     z2 = array_ops.identity(1)
     context.Exit()
index 61859d6..5168ad3 100644
@@ -223,6 +223,16 @@ class HelperContext(object):
     else:
       return val
 
+  def EnterGradientColocation(self, op, gradient_uid):
+    """Start building a gradient colocated with an op."""
+    if self._outer_context:
+      self._outer_context.EnterGradientColocation(op, gradient_uid)
+
+  def ExitGradientColocation(self, op, gradient_uid):
+    """Stop building a gradient colocated with an op."""
+    if self._outer_context:
+      self._outer_context.ExitGradientColocation(op, gradient_uid)
+
   def __enter__(self):
     # pylint: disable=protected-access
     self._g = ops.get_default_graph()
index 2574fa5..e3ca5a4 100644
@@ -4180,6 +4180,19 @@ class Graph(object):
     return self._name_stack
 
   @tf_contextlib.contextmanager
+  def _colocate_with_for_gradient(self, op, gradient_uid,
+                                  ignore_existing=False):
+    with self.colocate_with(op, ignore_existing):
+      if gradient_uid is not None and self._control_flow_context is not None:
+        try:
+          self._control_flow_context.EnterGradientColocation(op, gradient_uid)
+          yield
+        finally:
+          self._control_flow_context.ExitGradientColocation(op, gradient_uid)
+      else:
+        yield
+
+  @tf_contextlib.contextmanager
   def colocate_with(self, op, ignore_existing=False):
     """Returns a context manager that specifies an op to colocate with.
 
@@ -4958,8 +4971,7 @@ def container(container_name):
   return get_default_graph().container(container_name)
 
 
-@tf_export("colocate_with")
-def colocate_with(op, ignore_existing=False):
+def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
   if context.executing_eagerly():
     if op is not None:
       return device(op.device)
@@ -4973,7 +4985,13 @@ def colocate_with(op, ignore_existing=False):
       else:
         raise ValueError("Encountered an Eager-defined Tensor during graph "
                          "construction, but a function was not being built.")
-    return default_graph.colocate_with(op, ignore_existing)
+    return default_graph._colocate_with_for_gradient(
+        op, gradient_uid=gradient_uid, ignore_existing=ignore_existing)
+
+
+@tf_export("colocate_with")
+def colocate_with(op, ignore_existing=False):
+  return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing)
 
 
 @tf_export("control_dependencies")
index e56ab93..7be8628 100644
@@ -1595,6 +1595,16 @@ class ControlFlowContext(object):
     last_context = self._context_stack.pop()
     graph._set_control_flow_context(last_context)
 
+  def EnterGradientColocation(self, op, gradient_uid):
+    """Start building a gradient colocated with an op."""
+    if self._outer_context:
+      self._outer_context.EnterGradientColocation(op, gradient_uid)
+
+  def ExitGradientColocation(self, op, gradient_uid):
+    """Stop building a gradient colocated with an op."""
+    if self._outer_context:
+      self._outer_context.ExitGradientColocation(op, gradient_uid)
+
   def ExitResult(self, result):
     """Make a list of tensors available in the outer context."""
     if self._outer_context:
index 44473ec..13420b7 100644
@@ -208,7 +208,10 @@ def _AsList(x):
   return x if isinstance(x, (list, tuple)) else [x]
 
 
-def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
+def _DefaultGradYs(grad_ys,
+                   ys,
+                   colocate_gradients_with_ops,
+                   gradient_uid="__unsupported__"):
   """Fill in default values for grad_ys.
 
   Args:
@@ -216,6 +219,9 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
     ys: List of tensors.
     colocate_gradients_with_ops: If True, try colocating gradients with
       the corresponding op.
+    gradient_uid: A unique identifier within the graph indicating
+      which invocation of gradients is being executed. Used to cluster
+      ops for compilation.
 
   Returns:
     A list of gradients to use, without None.
@@ -231,7 +237,7 @@ def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
   for i in xrange(len(grad_ys)):
     grad_y = grad_ys[i]
     y = ys[i]
-    with _maybe_colocate_with(y.op, colocate_gradients_with_ops):
+    with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
       if grad_y is None:
         if y.dtype.is_complex:
           raise TypeError(
@@ -338,10 +344,10 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count):
 
 
 @contextlib.contextmanager
-def _maybe_colocate_with(op, colocate_gradients_with_ops):
+def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):  # pylint: disable=invalid-name
   """Context to colocate with `op` if `colocate_gradients_with_ops`."""
   if colocate_gradients_with_ops:
-    with ops.colocate_with(op):
+    with ops._colocate_with_for_gradient(op, gradient_uid):  # pylint: disable=protected-access
       yield
   else:
     yield
@@ -506,6 +512,9 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
   with ops.name_scope(
       name, "gradients",
       list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
+    # Get a uid for this call to gradients that can be used to help
+    # cluster ops for compilation.
+    gradient_uid = ops.get_default_graph().unique_name("uid")
     ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
     xs = [
         x.handle if resource_variable_ops.is_resource_variable(x) else x
@@ -513,7 +522,8 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     ]
     xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
         xs, name="x", as_ref=True)
-    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
+    grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
+                             gradient_uid)
 
     # The approach we take here is as follows: Create a list of all ops in the
     # subgraph between the ys and xs.  Visit these ops in reverse order of ids
@@ -570,10 +580,11 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
     while queue:
       # generate gradient subgraph for op.
       op = queue.popleft()
-      with _maybe_colocate_with(op, colocate_gradients_with_ops):
+      with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
         if loop_state:
           loop_state.EnterGradWhileContext(op, before=True)
-        out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
+        out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
+                                     aggregation_method)
         if loop_state:
           loop_state.ExitGradWhileContext(op, before=True)
 
@@ -633,7 +644,10 @@ def _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
               if gate_gradients and len([x for x in in_grads
                                          if x is not None]) > 1:
                 with ops.device(None):
-                  with ops.colocate_with(None, ignore_existing=True):
+                  with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
+                      None,
+                      gradient_uid,
+                      ignore_existing=True):
                     in_grads = control_flow_ops.tuple(in_grads)
           _LogOpGradients(op, out_grads, in_grads)
         else:
@@ -789,7 +803,7 @@ def _LogOpGradients(op, out_grads, in_grads):
                ", ".join([x.name for x in in_grads if _FilterGrad(x)]))
 
 
-def _MultiDeviceAddN(tensor_list):
+def _MultiDeviceAddN(tensor_list, gradient_uid):
   """Adds tensors from potentially multiple devices."""
   # Basic function structure comes from control_flow_ops.group().
   # Sort tensors according to their devices.
@@ -808,7 +822,10 @@ def _MultiDeviceAddN(tensor_list):
 
   for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
     tensors = tensors_on_device[dev]
-    with ops.colocate_with(tensors[0].op, ignore_existing=True):
+    with ops._colocate_with_for_gradient(  # pylint: disable=protected-access
+        tensors[0].op,
+        gradient_uid,
+        ignore_existing=True):
       summands.append(math_ops.add_n(tensors))
 
   return math_ops.add_n(summands)
@@ -834,12 +851,19 @@ class AggregationMethod(object):
   EXPERIMENTAL_ACCUMULATE_N = 2
 
 
-def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
+def _AggregatedGrads(grads,
+                     op,
+                     gradient_uid,
+                     loop_state,
+                     aggregation_method=None):
   """Get the aggregated gradients for op.
 
   Args:
     grads: The map of memoized gradients.
     op: The op to get gradients for.
+    gradient_uid: A unique identifier within the graph indicating
+      which invocation of gradients is being executed. Used to cluster
+      ops for compilation.
     loop_state: An object for maintaining the state of the while loops in the
                 graph. It is of type ControlFlowState. None if the graph
                 contains no while loops.
@@ -916,7 +940,7 @@ def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
             out_grads[i] = running_sum
         else:
           used = "add_n"
-          out_grads[i] = _MultiDeviceAddN(out_grad)
+          out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
         logging.vlog(2, "  _AggregatedGrads %d x %s using %s", len(out_grad),
                      tensor_shape, used)
       else: