from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpointable
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
+def get_filtered_grad_fn(grad_fn):
+ # `distributed_context.join()` requires that its arguments are parallel
+ # across threads, and in particular that `grads_and_vars` has the same
+ # variables in the same order.
+
+ # When computing gradients in eager mode with multiple threads, you
+ # can get extra variables with a gradient of `None`. This happens when
+ # those variables are accessed in another thread during the gradient
+ # computation. To get a consistent set of variables, we filter out
+ # those with `None` gradients.
+ def filtered_grad_fn(x=None):
+ return [(g, v) for g, v in grad_fn(x) if g is not None]
+
+ return filtered_grad_fn
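+
+# A minimal illustrative sketch (not part of this change; `g0`, `v0` and
+# `v1` are hypothetical): a gradient function that reports a `None`
+# gradient for a variable touched by another thread has that pair
+# dropped, so every tower passes the same variable list downstream:
+#
+#   grad_fn = lambda _: [(g0, v0), (None, v1)]
+#   get_filtered_grad_fn(grad_fn)()  # -> [(g0, v0)]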
+
+
def _deduplicate_indexed_slices(values, indices):
"""Sums `values` associated with any non-unique `indices`.
# ... }
self._deferred_slot_restorations = {}
+ # TODO(isaprykin): When using a DistributionStrategy, and when an
+ # optimizer is created in each tower, it might be dangerous to
+ # rely on some Optimizer methods. When such methods are called on a
+ # per-tower optimizer, an exception needs to be thrown. We do
+ # allow creation of per-tower optimizers, however, because the
+ # compute_gradients()->apply_gradients() sequence is safe.
+
def get_name(self):
return self._name
if var_list is not None:
tape.watch(var_list)
loss_value = loss()
+
+ # Scale loss if using a "mean" loss reduction and multiple towers.
+ # Have to be careful to call distribute_lib.get_loss_reduction()
+ # *after* loss() is evaluated, so we know what loss reduction it uses.
+ # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+ if distribute_lib.get_loss_reduction() == "mean":
+ num_towers = distribute_lib.get_distribution_strategy().num_towers
+ if num_towers > 1:
+ loss_value *= (1. / num_towers)
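+ # Worked example: with a "mean" loss reduction and num_towers == 4,
+ # each tower contributes loss_value * 0.25, so summing the per-tower
+ # gradients later in apply_gradients() recovers the gradient of the
+ # mean loss rather than four times that gradient.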
+
if var_list is None:
var_list = tape.watched_variables()
grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
+
+ # Non-callable/Tensor loss case
if context.executing_eagerly():
raise RuntimeError(
"`loss` passed to Optimizer.compute_gradients should "
"be a function when eager execution is enabled.")
+
+ # Scale loss if using a "mean" loss reduction and multiple towers.
+ if distribute_lib.get_loss_reduction() == "mean":
+ num_towers = distribute_lib.get_distribution_strategy().num_towers
+ if num_towers > 1:
+ loss *= (1. / num_towers)
+
if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
Optimizer.GATE_GRAPH]:
raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
Raises:
TypeError: If `grads_and_vars` is malformed.
ValueError: If none of the variables have gradients.
+ RuntimeError: If you should use `_distributed_apply()` instead.
"""
# This is a default implementation of apply_gradients() that can be shared
# by most optimizers. It relies on the subclass implementing the following
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
+ # Handle DistributionStrategy case.
+ if distribute_lib.get_cross_tower_context():
+ raise RuntimeError("Use `_distributed_apply()` instead of "
+ "`apply_gradients()` in a cross-tower context.")
+ # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
+ # always calling _distributed_apply(), using the default distribution
+ # as needed.
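+ # Rough picture of the tower-context path below (a sketch; how the
+ # caller enters the tower context, e.g. via the strategy's per-tower
+ # call, is an assumption): each tower reaches apply_gradients() with
+ # its own `grads_and_vars`, and `merge_call` switches to the
+ # cross-tower context so `_distributed_apply()` sees the per-tower
+ # values merged and can reduce them across devices.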
+ if distribute_lib.has_distribution_strategy():
+ grads_and_vars = get_filtered_grad_fn(lambda _: grads_and_vars)()
+ return distribute_lib.get_tower_context().merge_call(
+ self._distributed_apply, grads_and_vars, global_step, name)
+
+ # No DistributionStrategy case.
grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
if not grads_and_vars:
raise ValueError("No variables provided.")
return apply_updates
+ def _distributed_apply(self,
+ distribution,
+ grads_and_vars,
+ global_step=None,
+ name=None):
+ """A version of `apply_gradients` for cross-tower context.
+
+ This is a version of `apply_gradients()` for when you are using a
+ `DistributionStrategy` and are in a cross-tower context. If in a
+ tower context, use `apply_gradients()` as normal.
+
+ Args:
+ distribution: A `DistributionStrategy` object.
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`, and then aggregated across towers.
+ global_step: Optional (mirrored) `Variable` to increment by one
+ after the variables have been updated.
+ name: Optional name for the returned operation. Defaults to the
+ name passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients across all
+ towers. If `global_step` was not None, that operation also
+ increments `global_step`.
+ """
+ reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+ var_list = [v for _, v in grads_and_vars]
+ grads_and_vars = zip(reduced_grads, var_list)
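+ # Illustrative example: with two towers producing gradients ga and gb
+ # for a variable v, batch_reduce("sum", ...) yields ga + gb for v,
+ # which combined with the 1/num_towers scaling in compute_gradients()
+ # matches the gradient of the mean loss.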
+ # Note that this is called in a cross-tower context.
+ self._create_slots(var_list)
+
+ def update(v, g):
+ """Apply gradients to a replica variable."""
+ assert v is not None
+
+ try:
+ # Convert the grad to Tensor or IndexedSlices if necessary.
+ g = ops.convert_to_tensor_or_indexed_slices(g)
+ except TypeError:
+ raise TypeError("Gradient must be convertible to a Tensor"
+ " or IndexedSlices, or None: %s" % g)
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ p = _get_processor(v)
+
+ scope_name = "" if context.executing_eagerly() else v.op.name
+ # device_policy is set because non-mirrored tensors will be read in
+ # `update_op`. `_resource_apply_dense` reading `lr_t`, `beta1_t` and
+ # `beta2_t` is an example.
+ with ops.name_scope(
+ "update_" + scope_name), context.context().device_policy(
+ context.DEVICE_PLACEMENT_SILENT):
+ return p.update_op(self, g)
+
+ with ops.name_scope(name, self._name) as name:
+ self._prepare()
+
+ update_ops = [
+ op
+ for grad, var in grads_and_vars
+ for op in distribution.unwrap(distribution.update(var, update, grad))
+ ]
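+ # For a mirrored variable, `distribution.update(var, update, grad)`
+ # runs `update()` once per device copy, and `unwrap()` flattens the
+ # per-device results, so `update_ops` ends up as a flat list with one
+ # op per (variable, device) pair.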
+
+ def finish(self, update_ops):
+ return self._finish(update_ops, "update")
+
+ non_slot_devices = distribution.non_slot_devices(var_list)
+ # Device policy is needed because hyperparameter tensors (such as
+ # AdamOptimizer's beta1_t) need to be copied across devices in Eager.
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, self, update_ops)
+ if global_step is None:
+ apply_updates = distribution.group(finish_updates, name=name)
+ else:
+ with ops.control_dependencies(distribution.unwrap(finish_updates)):
+ apply_updates = distribution.group(distribution.update(
+ global_step, state_ops.assign_add, 1, name=name))
+
+ if not context.executing_eagerly():
+ if isinstance(apply_updates, ops.Tensor):
+ apply_updates = apply_updates.op
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ if apply_updates not in train_op:
+ train_op.append(apply_updates)
+
+ return apply_updates
+
def get_slot(self, var, name):
"""Return a slot named `name` created for `var` by the Optimizer.
Returns:
The `Variable` for the slot if it was created, `None` otherwise.
"""
+ # pylint: disable=protected-access
named_slots = self._slots.get(name, None)
if not named_slots:
return None
+
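+ # Illustrative sketch (the variable and device names are assumptions):
+ # for a MirroredVariable `mv` with copies on "/gpu:0" and "/gpu:1",
+ # the slot dict holds a single entry keyed by the mirrored container,
+ # and looking up the slot for the "/gpu:1" copy selects the slot that
+ # lives on "/gpu:1" via `mirrored_slot.get(device=var.device)`.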
+ if hasattr(var, "_mirrored_container"):
+ # NOTE: If this isn't patched, then there is no `handle` in
+ # `_resource_apply_dense`.
+ mirrored_container = var._mirrored_container()
+ assert mirrored_container is not None
+ if context.executing_eagerly():
+ key = mirrored_container._unique_id
+ else:
+ key = (mirrored_container.graph, mirrored_container._shared_name)
+ # pylint: enable=protected-access
+ mirrored_slot = named_slots.get(key, None)
+ if mirrored_slot is None: return None
+ return mirrored_slot.get(device=var.device)
+
return named_slots.get(_var_key(var), None)
def get_slot_names(self):
def _create_non_slot_variable(self, initial_value, name, colocate_with):
"""Add an extra variable, not associated with a slot."""
+ # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
eager = context.executing_eagerly()
graph = None if eager else colocate_with.graph
v = self._non_slot_dict.get(key, None)
if v is None:
self._maybe_initialize_checkpointable()
- with ops.colocate_with(colocate_with):
+ distribution_strategy = distribute_lib.get_distribution_strategy()
+ with distribution_strategy.colocate_vars_with(colocate_with):
if eager:
restored_initial_value = self._preload_simple_restoration(
name=name, shape=None)
return self._get_non_slot_variable(name, graph=graph)
def _get_non_slot_variable(self, name, graph=None):
- return self._non_slot_dict.get((name, graph), None)
+ non_slot = self._non_slot_dict.get((name, graph), None)
+ if hasattr(non_slot, "_mirrored_container"):
+ # This is a mirrored non-slot variable. In order to enable code like
+ # `_finish` to assign to a non-slot variable, return the replica for
+ # the current context.
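+ # For instance (an illustrative assumption about the setup), when a
+ # mirrored beta1_power has copies on "/gpu:0" and "/gpu:1",
+ # `non_slot.get()` returns the copy for the current device scope, so
+ # `_finish()` can assign to it directly.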
+ return non_slot.get()
+ else:
+ return non_slot
def _non_slot_variables(self):
"""Additional variables created by the `Optimizer`.