From: A. Unique TensorFlower
Date: Wed, 28 Mar 2018 21:52:25 +0000 (-0700)
Subject: Add DistributionStrategy support to Optimizer.
X-Git-Tag: tflite-v0.1.7~67^2^2~35
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=15908d912ed26f2517207e0a0bea6cd5768476ee;p=platform%2Fupstream%2Ftensorflow.git

Add DistributionStrategy support to Optimizer.

PiperOrigin-RevId: 190838314
---

diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index bf79714..75665fc 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -35,11 +35,28 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import checkpointable
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import slot_creator
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
+def get_filtered_grad_fn(grad_fn):
+  # `distributed_context.join()` requires that its arguments are parallel
+  # across threads, and in particular that `grads_and_vars` has the same
+  # variables in the same order.
+
+  # When computing gradients in eager mode with multiple threads, you
+  # can get extra variables with a gradient of `None`. This happens when
+  # those variables are accessed in another thread during the gradient
+  # computation. To get a consistent set of variables, we filter out
+  # those with `None` gradients.
+  def filtered_grad_fn(x=None):
+    return [(g, v) for g, v in grad_fn(x) if g is not None]
+
+  return filtered_grad_fn
+
+
 def _deduplicate_indexed_slices(values, indices):
   """Sums `values` associated with any non-unique `indices`.
 
@@ -335,6 +352,13 @@ class Optimizer(
     #   ... }
     self._deferred_slot_restorations = {}
 
+    # TODO(isaprykin): When using a DistributionStrategy, and when an
+    # optimizer is created in each tower, it might be dangerous to
+    # rely on some Optimizer methods. When such methods are called on a
+    # per-tower optimizer, an exception needs to be thrown. We do
+    # allow creation of per-tower optimizers, however, because the
+    # compute_gradients()->apply_gradients() sequence is safe.
+
   def get_name(self):
     return self._name
 
@@ -447,14 +471,33 @@ class Optimizer(
         if var_list is not None:
           tape.watch(var_list)
         loss_value = loss()
+
+        # Scale loss if using a "mean" loss reduction and multiple towers.
+        # Have to be careful to call distribute_lib.get_loss_reduction()
+        # *after* loss() is evaluated, so we know what loss reduction it uses.
+        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+        if distribute_lib.get_loss_reduction() == "mean":
+          num_towers = distribute_lib.get_distribution_strategy().num_towers
+          if num_towers > 1:
+            loss_value *= (1. / num_towers)
+
       if var_list is None:
         var_list = tape.watched_variables()
       grads = tape.gradient(loss_value, var_list, grad_loss)
       return list(zip(grads, var_list))
+
+    # Non-callable/Tensor loss case
     if context.executing_eagerly():
       raise RuntimeError(
           "`loss` passed to Optimizer.compute_gradients should "
           "be a function when eager execution is enabled.")
+
+    # Scale loss if using a "mean" loss reduction and multiple towers.
+    if distribute_lib.get_loss_reduction() == "mean":
+      num_towers = distribute_lib.get_distribution_strategy().num_towers
+      if num_towers > 1:
+        loss *= (1. / num_towers)
+
     if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                               Optimizer.GATE_GRAPH]:
       raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
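
Note on the scaling above: with a "mean" loss reduction, each tower's loss
(and therefore its gradients) is pre-multiplied by 1/num_towers, so that the
later "sum" reduction of gradients across towers produces an average rather
than an overshoot. A minimal pure-Python sketch of that arithmetic (toy
numbers only; nothing below is TensorFlow API):

    # One scalar gradient per tower for a single weight.
    per_tower_grads = [1.0, 2.0, 3.0, 6.0]
    num_towers = len(per_tower_grads)

    # A plain "sum" reduction across towers overshoots the mean gradient.
    summed = sum(per_tower_grads)                                     # 12.0

    # Scaling each tower's contribution by 1/num_towers first (as the patch
    # does to the loss) makes the same "sum" reduction yield the mean.
    scaled_sum = sum(g * (1. / num_towers) for g in per_tower_grads)  # 3.0
    assert scaled_sum == summed / num_towers
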
@@ -510,11 +553,25 @@ class Optimizer(
     Raises:
       TypeError: If `grads_and_vars` is malformed.
       ValueError: If none of the variables have gradients.
+      RuntimeError: If you should use `_distributed_apply()` instead.
     """
     # This is a default implementation of apply_gradients() that can be shared
     # by most optimizers. It relies on the subclass implementing the following
     # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
+
+    # Handle DistributionStrategy case.
+    if distribute_lib.get_cross_tower_context():
+      raise RuntimeError("Use `_distributed_apply()` instead of "
+                         "`apply_gradients()` in a cross-tower context.")
+    # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
+    # always calling _distributed_apply(), using the default distribution
+    # as needed.
+    if distribute_lib.has_distribution_strategy():
+      grads_and_vars = get_filtered_grad_fn(lambda _: grads_and_vars)()
+      return distribute_lib.get_tower_context().merge_call(
+          self._distributed_apply, grads_and_vars, global_step, name)
+
+    # No DistributionStrategy case.
     grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
     if not grads_and_vars:
       raise ValueError("No variables provided.")
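
For reference, the `get_filtered_grad_fn` call above exists so that every
thread hands `merge_call` the same variables in the same order: pairs whose
gradient is `None` are dropped before the merge. A standalone sketch of that
filtering with toy (non-tensor) values:

    def get_filtered_grad_fn(grad_fn):
      # Same shape as the helper added in this patch: wrap `grad_fn` and
      # drop any (gradient, variable) pair whose gradient is None.
      def filtered_grad_fn(x=None):
        return [(g, v) for g, v in grad_fn(x) if g is not None]
      return filtered_grad_fn

    grads_and_vars = [(0.5, "var_a"), (None, "var_b"), (1.5, "var_c")]
    filtered = get_filtered_grad_fn(lambda _: grads_and_vars)()
    assert filtered == [(0.5, "var_a"), (1.5, "var_c")]
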
@@ -582,6 +639,95 @@ class Optimizer(
     return apply_updates
 
+  def _distributed_apply(self,
+                         distribution,
+                         grads_and_vars,
+                         global_step=None,
+                         name=None):
+    """A version of `apply_gradients` for cross-tower context.
+
+    This is a version of `apply_gradients()` for when you are using a
+    `DistributionStrategy` and are in a cross-tower context. If in a
+    tower context, use `apply_gradients()` as normal.
+
+    Args:
+      distribution: A `DistributionStrategy` object.
+      grads_and_vars: List of (gradient, variable) pairs as returned by
+        `compute_gradients()`, and then aggregated across towers.
+      global_step: Optional (mirrored) `Variable` to increment by one
+        after the variables have been updated.
+      name: Optional name for the returned operation. Defaults to the
+        name passed to the `Optimizer` constructor.
+
+    Returns:
+      An `Operation` that applies the specified gradients across all
+      towers. If `global_step` was not None, that operation also
+      increments `global_step`.
+    """
+    reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+    var_list = [v for _, v in grads_and_vars]
+    grads_and_vars = zip(reduced_grads, var_list)
+    # Note that this is called in a cross-tower context.
+    self._create_slots(var_list)
+
+    def update(v, g):
+      """Apply gradients to a replica variable."""
+      assert v is not None
+
+      try:
+        # Convert the grad to Tensor or IndexedSlices if necessary.
+        g = ops.convert_to_tensor_or_indexed_slices(g)
+      except TypeError:
+        raise TypeError("Gradient must be convertible to a Tensor"
+                        " or IndexedSlices, or None: %s" % g)
+      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
+        raise TypeError(
+            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+      p = _get_processor(v)
+
+      scope_name = "" if context.executing_eagerly() else v.op.name
+      # device_policy is set because non-mirrored tensors will be read in
+      # `update_op`. An example is `_resource_apply_dense` reading `lr_t`,
+      # `beta1_t` and `beta2_t`.
+      with ops.name_scope(
+          "update_" + scope_name), context.context().device_policy(
+              context.DEVICE_PLACEMENT_SILENT):
+        return p.update_op(self, g)
+
+    with ops.name_scope(name, self._name) as name:
+      self._prepare()
+
+      update_ops = [
+          op
+          for grad, var in grads_and_vars
+          for op in distribution.unwrap(distribution.update(var, update, grad))
+      ]
+
+      def finish(self, update_ops):
+        return self._finish(update_ops, "update")
+
+      non_slot_devices = distribution.non_slot_devices(var_list)
+      # Device policy is needed because hyperparameter tensors (such as
+      # AdamOptimizer's beta1_t) need to be copied across devices in Eager.
+      with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+        finish_updates = distribution.update_non_slot(
+            non_slot_devices, finish, self, update_ops)
+      if global_step is None:
+        apply_updates = distribution.group(finish_updates, name=name)
+      else:
+        with ops.control_dependencies(distribution.unwrap(finish_updates)):
+          apply_updates = distribution.group(distribution.update(
+              global_step, state_ops.assign_add, 1, name=name))
+
+      if not context.executing_eagerly():
+        if isinstance(apply_updates, ops.Tensor):
+          apply_updates = apply_updates.op
+        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+        if apply_updates not in train_op:
+          train_op.append(apply_updates)
+
+      return apply_updates
+
   def get_slot(self, var, name):
     """Return a slot named `name` created for `var` by the Optimizer.
 
@@ -599,9 +745,25 @@ class Optimizer(
     Returns:
       The `Variable` for the slot if it was created, `None` otherwise.
     """
+    # pylint: disable=protected-access
     named_slots = self._slots.get(name, None)
     if not named_slots:
       return None
+
+    if hasattr(var, "_mirrored_container"):
+      # NOTE: If this isn't patched, then there is no `handle` in
+      # `_resource_apply_dense`.
+      mirrored_container = var._mirrored_container()
+      assert mirrored_container is not None
+      if context.executing_eagerly():
+        key = mirrored_container._unique_id
+      else:
+        key = (mirrored_container.graph, mirrored_container._shared_name)
+      # pylint: enable=protected-access
+      mirrored_slot = named_slots.get(key, None)
+      if mirrored_slot is None: return None
+      return mirrored_slot.get(device=var.device)
+
     return named_slots.get(_var_key(var), None)
 
   def get_slot_names(self):
@@ -645,6 +807,7 @@ class Optimizer(
 
   def _create_non_slot_variable(self, initial_value, name, colocate_with):
     """Add an extra variable, not associated with a slot."""
+    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
     graph = None if eager else colocate_with.graph
 
@@ -652,7 +815,8 @@ class Optimizer(
     v = self._non_slot_dict.get(key, None)
     if v is None:
       self._maybe_initialize_checkpointable()
-      with ops.colocate_with(colocate_with):
+      distribution_strategy = distribute_lib.get_distribution_strategy()
+      with distribution_strategy.colocate_vars_with(colocate_with):
         if eager:
           restored_initial_value = self._preload_simple_restoration(
               name=name, shape=None)
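
To make the cross-tower flow in `_distributed_apply` easier to follow, here is
a rough sketch using plain-Python stand-ins (dicts and a toy SGD step; none of
the names below are the real DistributionStrategy API): gradients from every
tower are sum-reduced per variable, and each variable then receives exactly
one update.

    from collections import defaultdict

    def apply_update(var_name, grad, state, lr=0.1):
      # Stand-in for `p.update_op(self, g)`: plain SGD on a dict of floats.
      state[var_name] -= lr * grad

    state = {"w": 1.0, "b": 0.0}
    tower_grads_and_vars = [
        [(0.2, "w"), (0.1, "b")],   # gradients computed on tower 0
        [(0.4, "w"), (0.3, "b")],   # gradients computed on tower 1
    ]

    # Analogue of batch_reduce("sum", ...): combine the per-tower gradients
    # for each variable into one reduced gradient.
    reduced = defaultdict(float)
    for pairs in tower_grads_and_vars:
      for g, v in pairs:
        reduced[v] += g

    # Analogue of distribution.update(var, update, grad): one update per
    # variable, applied to the shared variable state.
    for v, g in reduced.items():
      apply_update(v, g, state)

    assert abs(state["w"] - (1.0 - 0.1 * (0.2 + 0.4))) < 1e-12
    assert abs(state["b"] - (0.0 - 0.1 * (0.1 + 0.3))) < 1e-12
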
@@ -694,7 +858,13 @@ class Optimizer(
     return self._get_non_slot_variable(name, graph=graph)
 
   def _get_non_slot_variable(self, name, graph=None):
-    return self._non_slot_dict.get((name, graph), None)
+    non_slot = self._non_slot_dict.get((name, graph), None)
+    if hasattr(non_slot, "_mirrored_container"):
+      # This is a mirrored non-slot. In order to enable code like `_finish`
+      # to assign to a non-slot, return the current context replica.
+      return non_slot.get()
+    else:
+      return non_slot
 
   def _non_slot_variables(self):
     """Additional variables created by the `Optimizer`.
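
Finally, a hypothetical toy model of the per-device indirection that both the
`get_slot` and `_get_non_slot_variable` changes rely on. The class and values
below are illustrative stand-ins only, not the real mirrored-variable
implementation: a mirrored value keeps one copy per device, and `get()`
returns the copy for the current device so code like `_finish` reads and
assigns the right replica.

    class ToyMirrored(object):
      """Toy stand-in for a mirrored value: one copy per device."""

      def __init__(self, values_by_device, current_device):
        self._values = dict(values_by_device)
        self._current_device = current_device

      def get(self, device=None):
        # Return the copy for the requested device, defaulting to the
        # "current" one, mimicking the per-replica lookup in the patch.
        return self._values[device or self._current_device]

    # Values differ here only to show which copy is selected; a real
    # mirrored variable keeps the same value on every device.
    beta1_power = ToyMirrored({"/gpu:0": 0.9, "/gpu:1": 0.81}, "/gpu:0")
    assert beta1_power.get() == 0.9
    assert beta1_power.get("/gpu:1") == 0.81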