from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
_MAX_WARNING_LINES = 5
_TPU_REPLICATE_ATTR = "_tpu_replicate"
+_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
def _tpu_system_device_name(job):
outside the replicated computation.
"""
- def __init__(self, name):
+ def __init__(self, name, num_replicas):
super(TPUReplicateContext, self).__init__()
+ self._num_replicas = num_replicas
+ self._outer_device_function_stack = None
+ self._oc_dev_fn_stack = None
+ self._outside_compilation_cluster = None
+ self._outside_compilation_counter = 0
+ self._in_gradient_colocation = None
+ self._gradient_colocation_stack = []
+ self._host_compute_core = []
self._name = name
self._unsupported_ops = []
logging.warning("... and %d more" %
(len(self._unsupported_ops) - _MAX_WARNING_LINES))
+  def EnterGradientColocation(self, op, gradient_uid):
+    """Enters a gradient colocation scope for `op`.
+
+    Pushes `op` onto the gradient colocation stack. If no
+    outside_compilation cluster is currently active and `op` carries the
+    outside_compilation attr, also enters an outside_compilation scope
+    whose cluster name is derived from that attr and `gradient_uid`.
+
+    Args:
+      op: The Operation being colocated with, or None, in which case this
+        call does nothing.
+      gradient_uid: Identifier of the current invocation of gradients. The
+        value "__unsupported__" is rejected when `op` is in an
+        outside_compilation cluster.
+
+    Raises:
+      NotImplementedError: If `op` has the outside_compilation attr and
+        either another gradient colocation that entered
+        outside_compilation is still active, or `gradient_uid` is
+        "__unsupported__".
+    """
+    if op is None:
+      return
+    self._gradient_colocation_stack.append(op)
+    if self._outside_compilation_cluster:
+      return
+    try:
+      # Only the attr lookup is guarded, so that a ValueError raised while
+      # entering the scope below is not mistaken for a missing attr.
+      outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR)
+    except ValueError:
+      # The attr was not present: do nothing.
+      return
+    # AddOp stores the attr via compat.as_bytes, so normalize it to the
+    # native string type before splitting and concatenating below.
+    outside_attr = compat.as_str(outside_attr)
+    if self._in_gradient_colocation:
+      raise NotImplementedError(
+          "Cannot nest gradient colocation operations outside compilation"
+      )
+    if gradient_uid == "__unsupported__":
+      raise NotImplementedError(
+          "No gradient_uid calling gradient within outside_compilation")
+    # When we take the gradient of an op X in an
+    # outside_compilation cluster C in a forward computation we
+    # would like to put the ops corresponding to the gradient of
+    # X into a new outside_compilation cluster C'. However, if
+    # we take the gradient of X twice, the second one should get
+    # yet another new outside_compilation cluster C''.
+    #
+    # The mechanism we adopt is to use a 'root_cluster' which is
+    # the cluster that X was in before we took gradients, and a
+    # 'gradient_uid' which is different for every invocation of
+    # gradients, and put the gradient of X in cluster
+    # 'root_cluster.gradient_uid'.
+    #
+    # When the gradient code adds multiple Ops, it asks them to
+    # be colocated either with the original Op X, or with one of
+    # the preceding Ops that was added to the gradient. In other
+    # words, we want to detect the case where we are colocating
+    # with an Op that is in cluster root_cluster.gradient_uid
+    # and put the new Op in that same cluster if the
+    # gradient_uid is the same (the case that we are in the same
+    # invocation of gradients, and just adding new Ops to the
+    # cluster); and in a different cluster if the gradient_uids
+    # are different (the case that we are in a new invocation of
+    # gradients, taking the gradient of a previously-computed
+    # gradient).
+    self._in_gradient_colocation = op
+    parts = outside_attr.split(".")
+    if len(parts) > 1 and parts[-1] == gradient_uid:
+      # Same invocation of gradients: keep using the same cluster.
+      cluster = outside_attr
+    else:
+      # Either the Op is in the forward pass, or we are taking the
+      # gradient of a gradient. In both cases make a new cluster attr by
+      # adding a new '.uid' on the end, preserving the invariant that
+      # the gradient_uid is the suffix after the last '.' in the attr.
+      cluster = outside_attr + "." + gradient_uid
+    self._EnterOutsideCompilationScope(cluster=cluster)
+
+  def ExitGradientColocation(self, op, gradient_uid):
+    """Exits the gradient colocation scope previously entered for `op`.
+
+    Pops the gradient colocation stack and checks that the popped Op is
+    `op`. If `op` is the Op whose colocation entered an
+    outside_compilation scope, that scope is exited as well.
+
+    Args:
+      op: The Operation whose colocation scope is being exited, or None,
+        in which case this call does nothing.
+      gradient_uid: Identifier of the current invocation of gradients.
+        Not used by this method.
+
+    Raises:
+      errors.InternalError: If the stack is empty, or the Op on top of
+        the stack is not `op`.
+    """
+    if op is None:
+      return
+    if not self._gradient_colocation_stack:
+      raise errors.InternalError(
+          op.node_def, op,
+          "Badly nested gradient colocation: empty stack when popping Op " +
+          op.name)
+    last_op = self._gradient_colocation_stack.pop()
+    if op is not last_op:
+      # last_op is an Operation, so its name (not the object itself) has
+      # to be concatenated into the message.
+      raise errors.InternalError(
+          op.node_def, op, "Badly nested gradient colocation, expected " +
+          last_op.name + ", got " + op.name)
+    if op is self._in_gradient_colocation:
+      self._in_gradient_colocation = None
+      self._ExitOutsideCompilationScope()
+
+  def _EnterOutsideCompilationScope(self, cluster=None):
+    """Marks the start of an outside_compilation cluster.
+
+    Records the cluster name, notes which TPU_REPLICATED_CORE index (if
+    any) the current device functions resolve to, and swaps the graph's
+    device function stack for the one captured in Enter().
+
+    Args:
+      cluster: Name to use for the cluster. If empty or None, the next
+        value of an internal counter is used as the name.
+
+    Raises:
+      NotImplementedError: If an outside_compilation cluster is already
+        active.
+    """
+
+    class _DeviceProbe(object):
+      """Stand-in op used to find out the current device.
+
+      Implements only the device getter/setter that the graph's
+      _apply_device_functions method needs.
+      """
+
+      def __init__(self):
+        self._device = ""
+
+      @property
+      def device(self):
+        return self._device
+
+      def _set_device(self, device):
+        self._device = device.to_string()
+
+    if self._outside_compilation_cluster:
+      raise NotImplementedError("Cannot nest outside_compilation clusters")
+
+    if not cluster:
+      # No explicit name: allocate one from the counter.
+      cluster = str(self._outside_compilation_counter)
+      self._outside_compilation_counter += 1
+    self._outside_compilation_cluster = cluster
+
+    default_graph = ops.get_default_graph()
+    probe = _DeviceProbe()
+    default_graph._apply_device_functions(probe)  # pylint: disable=protected-access
+    spec = pydev.DeviceSpec.from_string(probe.device)
+    on_replicated_core = spec.device_type == "TPU_REPLICATED_CORE"
+    if on_replicated_core and spec.device_index is not None:
+      self._host_compute_core.append(
+          ":".join((self._outside_compilation_cluster,
+                    str(spec.device_index))))
+
+    # Remember the stack in use so _ExitOutsideCompilationScope can put it
+    # back, then switch to the stack captured in Enter().
+    self._oc_dev_fn_stack = default_graph._device_function_stack  # pylint: disable=protected-access
+    default_graph._device_function_stack = self._outer_device_function_stack  # pylint: disable=protected-access
+
+  def _ExitOutsideCompilationScope(self):
+    """Marks the end of the active outside_compilation cluster.
+
+    Clears the active cluster name and puts back the device function
+    stack that was saved when the scope was entered.
+
+    Raises:
+      NotImplementedError: If no outside_compilation cluster is active.
+    """
+    if self._outside_compilation_cluster:
+      self._outside_compilation_cluster = None
+      default_graph = ops.get_default_graph()
+      default_graph._device_function_stack = self._oc_dev_fn_stack  # pylint: disable=protected-access
+      return
+    raise NotImplementedError(
+        "Attempted to exit outside_compilation scope when not in scope")
+
+  def Enter(self):
+    """Enters the context, capturing the device function stack if needed.
+
+    A copy of the graph's device function stack is stored whenever no
+    non-empty stack has been stored yet; that copy is the stack used
+    inside outside_compilation scopes.
+    """
+    if not self._outer_device_function_stack:
+      # Capture the device function stack at the time of first entry
+      # since that is the stack that will be used outside_compilation.
+      current_stack = ops.get_default_graph()._device_function_stack  # pylint: disable=protected-access
+      self._outer_device_function_stack = list(current_stack)
+    super(TPUReplicateContext, self).Enter()
+
+  def Exit(self):
+    """Exits the context by delegating to the parent class."""
+    super(TPUReplicateContext, self).Exit()
+
+  def HostComputeCore(self):
+    """Returns the list of "cluster:core_index" strings recorded so far.
+
+    The returned list is the internal one, not a copy.
+    """
+    recorded_cores = self._host_compute_core
+    return recorded_cores
+
def AddOp(self, op):
self._AddOpInternal(op)
raise ValueError("TPU computations cannot be nested")
op._set_attr(_TPU_REPLICATE_ATTR,
attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
- # pylint: enable=protected-access
- op.graph.prevent_feeding(op)
- op.graph.prevent_fetching(op)
+ if self._outside_compilation_cluster:
+ op._set_attr(
+ _OUTSIDE_COMPILATION_ATTR,
+ attr_value_pb2.AttrValue(
+ s=compat.as_bytes(self._outside_compilation_cluster)))
+ if self._num_replicas > 1 or not self._outside_compilation_cluster:
+ # Prevent feeding or fetching anything that is being compiled,
+ # and any replicated outside_compilation Op.
+ op.graph.prevent_feeding(op)
+ op.graph.prevent_fetching(op)
def AddValue(self, val):
result = val
return None
+def outside_compilation(computation, args=None):
+  """Builds part of a computation outside any current TPU replicate scope.
+
+  Args:
+    computation: A Python function that builds the computation to
+      place on the host.
+    args: Inputs to pass to computation. If None, computation is called
+      with no arguments.
+
+  Returns:
+    The Tensors returned by computation.
+
+  Raises:
+    NotImplementedError: If the control-flow context at the end of
+      computation differs from the one at the start, or if an enclosing
+      TPUReplicateContext is already inside an outside_compilation scope.
+  """
+  # Unpacking the default None with computation(*args) would raise a
+  # TypeError, so treat it as "no arguments".
+  args = [] if args is None else args
+  graph = ops.get_default_graph()
+
+  # If we are in a TPUReplicateContext, signal that we are now
+  # outside_compilation
+  initial_context = graph._get_control_flow_context()  # pylint: disable=protected-access
+  context = initial_context
+  while context:
+    if isinstance(context, TPUReplicateContext):
+      context._EnterOutsideCompilationScope()  # pylint: disable=protected-access
+    context = context.outer_context
+
+  try:
+    retval = computation(*args)
+
+    final_context = graph._get_control_flow_context()  # pylint: disable=protected-access
+    if initial_context is not final_context:
+      raise NotImplementedError(
+          "Control-flow context cannot be different at start and end of an "
+          "outside_compilation scope")
+  finally:
+    # If we are in a TPUReplicateContext, signal that we are no longer
+    # outside_compilation. This runs even when computation raises, so the
+    # contexts and the graph's device function stack are not left in
+    # outside_compilation state.
+    context = initial_context
+    while context:
+      if isinstance(context, TPUReplicateContext):
+        context._ExitOutsideCompilationScope()  # pylint: disable=protected-access
+      context = context.outer_context
+
+  return retval
+
+
def replicate(computation,
inputs=None,
infeed_queue=None,
computation_inputs.append(
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
- context = TPUReplicateContext(name=graph.unique_name("cluster"))
+ context = TPUReplicateContext(
+ name=graph.unique_name("cluster"), num_replicas=num_replicas)
try:
context.Enter()
finally:
context.report_unsupported_operations()
context.Exit()
+ host_compute_core = context.HostComputeCore()
+
+ if host_compute_core:
+ attr_value = attr_value_pb2.AttrValue()
+ attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
+ metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access
# Fan-out: Builds a TPUReplicatedOutput node for each output.
outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
return x if isinstance(x, (list, tuple)) else [x]
-def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
+def _DefaultGradYs(grad_ys,
+ ys,
+ colocate_gradients_with_ops,
+ gradient_uid="__unsupported__"):
"""Fill in default values for grad_ys.
Args:
ys: List of tensors.
colocate_gradients_with_ops: If True, try colocating gradients with
the corresponding op.
+ gradient_uid: A unique identifier within the graph indicating
+ which invocation of gradients is being executed. Used to cluster
+ ops for compilation.
Returns:
A list of gradients to use, without None.
for i in xrange(len(grad_ys)):
grad_y = grad_ys[i]
y = ys[i]
- with _maybe_colocate_with(y.op, colocate_gradients_with_ops):
+ with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
if grad_y is None:
if y.dtype.is_complex:
raise TypeError(
@contextlib.contextmanager
-def _maybe_colocate_with(op, colocate_gradients_with_ops):
+def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name
"""Context to colocate with `op` if `colocate_gradients_with_ops`."""
if colocate_gradients_with_ops:
- with ops.colocate_with(op):
+ with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access
yield
else:
yield
with ops.name_scope(
name, "gradients",
list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
+ # Get a uid for this call to gradients that can be used to help
+ # cluster ops for compilation.
+ gradient_uid = ops.get_default_graph().unique_name("uid")
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
xs = [
x.handle if resource_variable_ops.is_resource_variable(x) else x
]
xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
xs, name="x", as_ref=True)
- grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
+ grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
+ gradient_uid)
# The approach we take here is as follows: Create a list of all ops in the
# subgraph between the ys and xs. Visit these ops in reverse order of ids
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
- with _maybe_colocate_with(op, colocate_gradients_with_ops):
+ with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
if loop_state:
loop_state.EnterGradWhileContext(op, before=True)
- out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
+ out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
+ aggregation_method)
if loop_state:
loop_state.ExitGradWhileContext(op, before=True)
if gate_gradients and len([x for x in in_grads
if x is not None]) > 1:
with ops.device(None):
- with ops.colocate_with(None, ignore_existing=True):
+ with ops._colocate_with_for_gradient( # pylint: disable=protected-access
+ None,
+ gradient_uid,
+ ignore_existing=True):
in_grads = control_flow_ops.tuple(in_grads)
_LogOpGradients(op, out_grads, in_grads)
else:
", ".join([x.name for x in in_grads if _FilterGrad(x)]))
-def _MultiDeviceAddN(tensor_list):
+def _MultiDeviceAddN(tensor_list, gradient_uid):
"""Adds tensors from potentially multiple devices."""
# Basic function structure comes from control_flow_ops.group().
# Sort tensors according to their devices.
for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
tensors = tensors_on_device[dev]
- with ops.colocate_with(tensors[0].op, ignore_existing=True):
+ with ops._colocate_with_for_gradient( # pylint: disable=protected-access
+ tensors[0].op,
+ gradient_uid,
+ ignore_existing=True):
summands.append(math_ops.add_n(tensors))
return math_ops.add_n(summands)
EXPERIMENTAL_ACCUMULATE_N = 2
-def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
+def _AggregatedGrads(grads,
+ op,
+ gradient_uid,
+ loop_state,
+ aggregation_method=None):
"""Get the aggregated gradients for op.
Args:
grads: The map of memoized gradients.
op: The op to get gradients for.
+ gradient_uid: A unique identifier within the graph indicating
+ which invocation of gradients is being executed. Used to cluster
+ ops for compilation.
loop_state: An object for maintaining the state of the while loops in the
graph. It is of type ControlFlowState. None if the graph
contains no while loops.
out_grads[i] = running_sum
else:
used = "add_n"
- out_grads[i] = _MultiDeviceAddN(out_grad)
+ out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
tensor_shape, used)
else: