from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.core.framework import attr_value_pb2
-from tensorflow.python.framework import device as pydev
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
_MAX_WARNING_LINES = 5
_TPU_REPLICATE_ATTR = "_tpu_replicate"
-_OUTSIDE_COMPILATION_ATTR = "_xla_outside_compilation"
def _tpu_system_device_name(job):
outside the replicated computation.
"""
- def __init__(self, name, num_replicas):
+ def __init__(self, name):
super(TPUReplicateContext, self).__init__()
- self._num_replicas = num_replicas
- self._outer_device_function_stack = None
- self._oc_dev_fn_stack = None
- self._outside_compilation_cluster = None
- self._outside_compilation_counter = 0
- self._in_gradient_colocation = None
- self._gradient_colocation_stack = []
- self._host_compute_core = []
self._name = name
self._unsupported_ops = []
logging.warning("... and %d more" %
(len(self._unsupported_ops) - _MAX_WARNING_LINES))
- def EnterGradientColocation(self, op, gradient_uid):
- if op is not None:
- self._gradient_colocation_stack.append(op)
- if not self._outside_compilation_cluster:
- try:
- outside_attr = op.get_attr(_OUTSIDE_COMPILATION_ATTR)
- if self._in_gradient_colocation:
- raise NotImplementedError(
- "Cannot nest gradient colocation operations outside compilation"
- )
- if gradient_uid == "__unsupported__":
- raise NotImplementedError(
- "No gradient_uid calling gradient within outside_compilation")
- # When we take the gradient of an op X in an
- # outside_compilation cluster C in a forward computation we
- # would like to put the ops corresponding to the gradient of
- # X into a new outside_compilation cluster C'. However, if
- # we take the gradient of X twice, the second one should get
- # yet another new outside_compilation cluster C''.
- #
- # The mechanism we adopt is to use a 'root_cluster' which is
- # the cluster that X was in before we took gradients, and a
- # 'gradient_uid' which is different for every invocation of
- # gradients, and put the gradient of X in cluster
- # 'root_cluster.gradient_uid'.
- #
- # When the gradient code adds multiple Ops, it asks them to
- # be colocated either with the original Op X, or with one of
- # the preceding Ops that was added to the gradient. In other
- # words, we want to detect the case where we are colocating
- # with an Op that is in cluster root_cluster.gradient_uid
- # and put the new Op in that same cluster if the
- # gradient_uid is the same (the case that we are in the same
- # invocation of gradients, and just adding new Ops to the
- # cluster); and in a different cluster if the gradient_uids
- # are different (the case that we are in a new invocation of
- # gradients, taking the gradient of a previously-computed
- # gradient).
- self._in_gradient_colocation = op
- parts = outside_attr.split(".")
- if len(parts) > 1:
- uid = parts[-1]
- if uid == gradient_uid:
- # Keep using the same cluster
- cluster = outside_attr
- else:
- # We're taking the gradient of a gradient so make a new
- # cluster attr, adding a new '.uid' on the end to
- # preserve the invariant that the gradient_uid is the
- # suffix after the last '.' in the attr.
- cluster = outside_attr + "." + gradient_uid
- else:
- # We're taking the gradient of an Op in the forward pass, so
- # make a new cluster combining the Op's cluster and the
- # gradient id.
- cluster = outside_attr + "." + gradient_uid
- self._EnterOutsideCompilationScope(cluster=cluster)
- except ValueError:
- # The attr was not present: do nothing.
- pass
-
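The removed comment above describes a small naming rule for gradient clusters. A minimal standalone sketch of that rule follows; `derive_gradient_cluster` is a hypothetical helper written for illustration, not a function that exists in TensorFlow:

def derive_gradient_cluster(outside_attr, gradient_uid):
  """Names the outside_compilation cluster for ops added by gradients().

  outside_attr is the _xla_outside_compilation attr of the op being
  colocated with; gradient_uid identifies the current gradients() call.
  """
  parts = outside_attr.split(".")
  if len(parts) > 1 and parts[-1] == gradient_uid:
    # Same invocation of gradients(): keep adding ops to the same cluster.
    return outside_attr
  # Forward-pass op, or gradient of a gradient: open a new cluster whose
  # suffix after the last '.' is the current gradient_uid.
  return outside_attr + "." + gradient_uid

# derive_gradient_cluster("0", "uid_1")        -> "0.uid_1"
# derive_gradient_cluster("0.uid_1", "uid_1")  -> "0.uid_1"
# derive_gradient_cluster("0.uid_1", "uid_2")  -> "0.uid_1.uid_2"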
- def ExitGradientColocation(self, op, gradient_uid):
- if op is not None:
- if not self._gradient_colocation_stack:
- raise errors.InternalError(
- op.node_def, op,
- "Badly nested gradient colocation: empty stack when popping Op " +
- op.name)
- last_op = self._gradient_colocation_stack.pop()
- if op is last_op:
- if op is self._in_gradient_colocation:
- self._in_gradient_colocation = None
- self._ExitOutsideCompilationScope()
- else:
- raise errors.InternalError(
- op.node_def, op, "Badly nested gradient colocation, expected " +
- last_op + ", got " + op.name)
-
- def _EnterOutsideCompilationScope(self, cluster=None):
-
- class FakeOp(object):
- """A helper class to determine the current device.
-
- Supports only the device set/get methods needed to run the
- graph's _apply_device_function method.
- """
-
- def __init__(self):
- self._device = ""
-
- @property
- def device(self):
- return self._device
-
- def _set_device(self, device):
- self._device = device.to_string()
-
- if self._outside_compilation_cluster:
- raise NotImplementedError("Cannot nest outside_compilation clusters")
- if cluster:
- self._outside_compilation_cluster = cluster
- else:
- self._outside_compilation_cluster = str(self._outside_compilation_counter)
- self._outside_compilation_counter += 1
- graph = ops.get_default_graph()
- fake_op = FakeOp()
- graph._apply_device_functions(fake_op) # pylint: disable=protected-access
- device = pydev.DeviceSpec.from_string(fake_op.device)
- if (device.device_type == "TPU_REPLICATED_CORE" and
- device.device_index is not None):
- self._host_compute_core.append(self._outside_compilation_cluster + ":" +
- str(device.device_index))
- self._oc_dev_fn_stack = graph._device_function_stack # pylint: disable=protected-access
- graph._device_function_stack = self._outer_device_function_stack # pylint: disable=protected-access
-
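The FakeOp trick removed above can be read in isolation: run the graph's device functions against a stand-in object to see which device an op created at this point would be assigned. A condensed sketch under the same assumptions; it leans on the TF 1.x internal Graph._apply_device_functions, and _DeviceProbe / current_device_spec are names invented for the example:

from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import ops


class _DeviceProbe(object):
  """Minimal op-like object accepted by Graph._apply_device_functions."""

  def __init__(self):
    self._device = ""

  @property
  def device(self):
    return self._device

  def _set_device(self, device):
    self._device = device.to_string()


def current_device_spec():
  """Returns the DeviceSpec the enclosing device scopes would assign."""
  probe = _DeviceProbe()
  ops.get_default_graph()._apply_device_functions(probe)  # pylint: disable=protected-access
  return pydev.DeviceSpec.from_string(probe.device)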
- def _ExitOutsideCompilationScope(self):
- if not self._outside_compilation_cluster:
- raise NotImplementedError(
- "Attempted to exit outside_compilation scope when not in scope")
- self._outside_compilation_cluster = None
- graph = ops.get_default_graph()
- graph._device_function_stack = self._oc_dev_fn_stack # pylint: disable=protected-access
-
- def Enter(self):
- if not self._outer_device_function_stack:
- # Capture the device function stack at the time of first entry
- # since that is the stack that will be used outside_compilation.
- graph = ops.get_default_graph()
- self._outer_device_function_stack = list(graph._device_function_stack) # pylint: disable=protected-access
- super(TPUReplicateContext, self).Enter()
-
- def Exit(self):
- super(TPUReplicateContext, self).Exit()
-
- def HostComputeCore(self):
- return self._host_compute_core
-
def AddOp(self, op):
self._AddOpInternal(op)
raise ValueError("TPU computations cannot be nested")
op._set_attr(_TPU_REPLICATE_ATTR,
attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
- if self._outside_compilation_cluster:
- op._set_attr(
- _OUTSIDE_COMPILATION_ATTR,
- attr_value_pb2.AttrValue(
- s=compat.as_bytes(self._outside_compilation_cluster)))
- if self._num_replicas > 1 or not self._outside_compilation_cluster:
- # Prevent feeding or fetching anything that is being compiled,
- # and any replicated outside_compilation Op.
- op.graph.prevent_feeding(op)
- op.graph.prevent_fetching(op)
+ # pylint: enable=protected-access
+ op.graph.prevent_feeding(op)
+ op.graph.prevent_fetching(op)
def AddValue(self, val):
result = val
return None
-def outside_compilation(computation, args=None):
- """Builds part of a computation outside any current TPU replicate scope.
-
- Args:
- computation: A Python function that builds the computation to
- place on the host.
- args: Inputs to pass to computation.
- Returns:
- The Tensors returned by computation.
- """
- graph = ops.get_default_graph()
-
- # If we are in a TPUReplicateContext, signal that we are now
- # outside_compilation
- initial_context = graph._get_control_flow_context() # pylint: disable=protected-access
- context = initial_context
- while context:
- if isinstance(context, TPUReplicateContext):
- context._EnterOutsideCompilationScope() # pylint: disable=protected-access
- context = context.outer_context
-
- retval = computation(*args)
-
- # If we are in a TPUReplicateContext, signal that we are no longer
- # outside_compilation
- final_context = graph._get_control_flow_context() # pylint: disable=protected-access
- if initial_context is not final_context:
- raise NotImplementedError(
- "Control-flow context cannot be different at start and end of an "
- "outside_compilation scope")
- context = initial_context
- while context:
- if isinstance(context, TPUReplicateContext):
- context._ExitOutsideCompilationScope() # pylint: disable=protected-access
- context = context.outer_context
-
- return retval
-
-
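For reference, a minimal graph-building sketch of how the outside_compilation API removed above was used from tf.contrib.tpu in the 1.x releases that ship it. The functions host_log and tpu_step are made up for the example; building the graph needs no TPU, but running it does:

import tensorflow as tf
from tensorflow.contrib import tpu

def host_log(x):
  # tf.Print runs on the host; wrapping it in outside_compilation keeps the
  # un-compilable op out of the TPU cluster.
  return tf.Print(x, [x], message="value on host: ")

def tpu_step(x):
  y = x * 2.0
  y = tpu.outside_compilation(host_log, [y])
  return y + 1.0

outputs = tpu.replicate(tpu_step, inputs=[[tf.constant(3.0)]])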
def replicate(computation,
inputs=None,
infeed_queue=None,
computation_inputs.append(
tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i)))
- context = TPUReplicateContext(
- name=graph.unique_name("cluster"), num_replicas=num_replicas)
+ context = TPUReplicateContext(name=graph.unique_name("cluster"))
try:
context.Enter()
finally:
context.report_unsupported_operations()
context.Exit()
- host_compute_core = context.HostComputeCore()
-
- if host_compute_core:
- attr_value = attr_value_pb2.AttrValue()
- attr_value.list.s.extend([compat.as_bytes(x) for x in host_compute_core])
- metadata._set_attr("host_compute_core", attr_value) # pylint: disable=protected-access
# Fan-out: Builds a TPUReplicatedOutput node for each output.
outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas,
return x if isinstance(x, (list, tuple)) else [x]
-def _DefaultGradYs(grad_ys,
- ys,
- colocate_gradients_with_ops,
- gradient_uid="__unsupported__"):
+def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
"""Fill in default values for grad_ys.
Args:
ys: List of tensors.
colocate_gradients_with_ops: If True, try colocating gradients with
the corresponding op.
- gradient_uid: A unique identifier within the graph indicating
- which invocation of gradients is being executed. Used to cluster
- ops for compilation.
Returns:
A list of gradients to use, without None.
for i in xrange(len(grad_ys)):
grad_y = grad_ys[i]
y = ys[i]
- with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
+ with _maybe_colocate_with(y.op, colocate_gradients_with_ops):
if grad_y is None:
if y.dtype.is_complex:
raise TypeError(
@contextlib.contextmanager
-def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name
+def _maybe_colocate_with(op, colocate_gradients_with_ops):
"""Context to colocate with `op` if `colocate_gradients_with_ops`."""
if colocate_gradients_with_ops:
- with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access
+ with ops.colocate_with(op):
yield
else:
yield
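The rewritten helper simply wraps the public colocation scope. A small TF 1.x graph-mode illustration of what that scope does when colocate_gradients_with_ops is True; the tensors here are only for the example:

import tensorflow as tf

with tf.device("/device:GPU:0"):
  y = tf.square(tf.constant(2.0))

# What _maybe_colocate_with(y.op, True) expands to: new ops share y's device.
with tf.colocate_with(y.op):
  dy = 2.0 * tf.constant(2.0)  # placed on GPU:0 alongside y's op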
with ops.name_scope(
name, "gradients",
list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
- # Get a uid for this call to gradients that can be used to help
- # cluster ops for compilation.
- gradient_uid = ops.get_default_graph().unique_name("uid")
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
xs = [
x.handle if resource_variable_ops.is_resource_variable(x) else x
]
xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
xs, name="x", as_ref=True)
- grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
- gradient_uid)
+ grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops)
# The approach we take here is as follows: Create a list of all ops in the
# subgraph between the ys and xs. Visit these ops in reverse order of ids
while queue:
# generate gradient subgraph for op.
op = queue.popleft()
- with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
+ with _maybe_colocate_with(op, colocate_gradients_with_ops):
if loop_state:
loop_state.EnterGradWhileContext(op, before=True)
- out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
- aggregation_method)
+ out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)
if loop_state:
loop_state.ExitGradWhileContext(op, before=True)
if gate_gradients and len([x for x in in_grads
if x is not None]) > 1:
with ops.device(None):
- with ops._colocate_with_for_gradient( # pylint: disable=protected-access
- None,
- gradient_uid,
- ignore_existing=True):
+ with ops.colocate_with(None, ignore_existing=True):
in_grads = control_flow_ops.tuple(in_grads)
_LogOpGradients(op, out_grads, in_grads)
else:
", ".join([x.name for x in in_grads if _FilterGrad(x)]))
-def _MultiDeviceAddN(tensor_list, gradient_uid):
+def _MultiDeviceAddN(tensor_list):
"""Adds tensors from potentially multiple devices."""
# Basic function structure comes from control_flow_ops.group().
# Sort tensors according to their devices.
for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
tensors = tensors_on_device[dev]
- with ops._colocate_with_for_gradient( # pylint: disable=protected-access
- tensors[0].op,
- gradient_uid,
- ignore_existing=True):
+ with ops.colocate_with(tensors[0].op, ignore_existing=True):
summands.append(math_ops.add_n(tensors))
return math_ops.add_n(summands)
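A compact standalone sketch of the aggregation strategy _MultiDeviceAddN implements: group the incoming gradients by device, run one add_n per device colocated with that device's tensors, then add the per-device partial sums. The helper name is hypothetical, and the real code sorts devices by a canonical DeviceSpec key rather than the raw device string:

import collections
import tensorflow as tf

def multi_device_add_n(tensor_list):
  by_device = collections.defaultdict(list)
  for t in tensor_list:
    by_device[t.device].append(t)
  partial_sums = []
  for dev in sorted(by_device):
    group = by_device[dev]
    with tf.colocate_with(group[0].op, ignore_existing=True):
      partial_sums.append(tf.add_n(group))  # one add_n per device
  return tf.add_n(partial_sums)             # combine the per-device sums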
EXPERIMENTAL_ACCUMULATE_N = 2
-def _AggregatedGrads(grads,
- op,
- gradient_uid,
- loop_state,
- aggregation_method=None):
+def _AggregatedGrads(grads, op, loop_state, aggregation_method=None):
"""Get the aggregated gradients for op.
Args:
grads: The map of memoized gradients.
op: The op to get gradients for.
- gradient_uid: A unique identifier within the graph indicating
- which invocation of gradients is being executed. Used to cluster
- ops for compilation.
loop_state: An object for maintaining the state of the while loops in the
graph. It is of type ControlFlowState. None if the graph
contains no while loops.
out_grads[i] = running_sum
else:
used = "add_n"
- out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
+ out_grads[i] = _MultiDeviceAddN(out_grad)
logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
tensor_shape, used)
else: