Add DistributionStrategy support to Optimizer.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 21:52:25 +0000 (14:52 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 21:55:18 +0000 (14:55 -0700)
PiperOrigin-RevId: 190838314

tensorflow/python/training/optimizer.py

index bf79714..75665fc 100644 (file)
@@ -35,11 +35,28 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import checkpointable
+from tensorflow.python.training import distribute as distribute_lib
 from tensorflow.python.training import slot_creator
 from tensorflow.python.util import nest
 from tensorflow.python.util.tf_export import tf_export
 
 
+def get_filtered_grad_fn(grad_fn):
+  # `distributed_context.join()` requires that its arguments are parallel
+  # across threads, and in particular that `grads_and_vars` has the same
+  # variables in the same order.
+
+  # When computing gradients in eager mode with multiple threads, you
+  # can get extra variables with a gradient of `None`. This happens when
+  # those variables are accessed in another thread during the gradient
+  # computation. To get a consistent set of variables, we filter out
+  # those with `None` gradients.
+  def filtered_grad_fn(x=None):
+    return [(g, v) for g, v in grad_fn(x) if g is not None]
+
+  return filtered_grad_fn
+
+
 def _deduplicate_indexed_slices(values, indices):
   """Sums `values` associated with any non-unique `indices`.
 
@@ -335,6 +352,13 @@ class Optimizer(
     #   ... }
     self._deferred_slot_restorations = {}
 
+    # TODO(isaprykin): When using a DistributionStrategy, and when an
+    # optimizer is created in each tower, it might be dangerous to
+    # rely on some Optimer methods.  When such methods are called on a
+    # per-tower optimizer, an exception needs to be thrown.  We do
+    # allow creation per-tower optimizers however, because the
+    # compute_gradients()->apply_gradients() sequence is safe.
+
   def get_name(self):
     return self._name
 
@@ -447,14 +471,33 @@ class Optimizer(
         if var_list is not None:
           tape.watch(var_list)
         loss_value = loss()
+
+        # Scale loss if using a "mean" loss reduction and multiple towers.
+        # Have to be careful to call distribute_lib.get_loss_reduction()
+        # *after* loss() is evaluated, so we know what loss reduction it uses.
+        # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+        if distribute_lib.get_loss_reduction() == "mean":
+          num_towers = distribute_lib.get_distribution_strategy().num_towers
+          if num_towers > 1:
+            loss_value *= (1. / num_towers)
+
       if var_list is None:
         var_list = tape.watched_variables()
       grads = tape.gradient(loss_value, var_list, grad_loss)
       return list(zip(grads, var_list))
+
+    # Non-callable/Tensor loss case
     if context.executing_eagerly():
       raise RuntimeError(
           "`loss` passed to Optimizer.compute_gradients should "
           "be a function when eager execution is enabled.")
+
+    # Scale loss if using a "mean" loss reduction and multiple towers.
+    if distribute_lib.get_loss_reduction() == "mean":
+      num_towers = distribute_lib.get_distribution_strategy().num_towers
+      if num_towers > 1:
+        loss *= (1. / num_towers)
+
     if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
                               Optimizer.GATE_GRAPH]:
       raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
@@ -510,11 +553,25 @@ class Optimizer(
     Raises:
       TypeError: If `grads_and_vars` is malformed.
       ValueError: If none of the variables have gradients.
+      RuntimeError: If you should use `_distributed_apply()` instead.
     """
     # This is a default implementation of apply_gradients() that can be shared
     # by most optimizers.  It relies on the subclass implementing the following
     # methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
 
+    # Handle DistributionStrategy case.
+    if distribute_lib.get_cross_tower_context():
+      raise RuntimeError("Use `_distributed_apply()` instead of "
+                         "`apply_gradients()` in a cross-tower context.")
+    # TODO(isaprykin): Get rid of `has_distribution_strategy()` check by
+    # always calling _distributed_apply(), using the default distribution
+    # as needed.
+    if distribute_lib.has_distribution_strategy():
+      grads_and_vars = get_filtered_grad_fn(lambda _: grads_and_vars)()
+      return distribute_lib.get_tower_context().merge_call(
+          self._distributed_apply, grads_and_vars, global_step, name)
+
+    # No DistributionStrategy case.
     grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
     if not grads_and_vars:
       raise ValueError("No variables provided.")
@@ -582,6 +639,95 @@ class Optimizer(
 
       return apply_updates
 
+  def _distributed_apply(self,
+                         distribution,
+                         grads_and_vars,
+                         global_step=None,
+                         name=None):
+    """A version of `apply_gradients` for cross-tower context.
+
+    This is a version of `apply_gradients()` for when you are using a
+    `DistributionStrategy` and are in a cross-tower context. If in a
+    tower context, use `apply_gradients()` as normal.
+
+    Args:
+      distribution: A `DistributionStrategy` object.
+      grads_and_vars: List of (gradient, variable) pairs as returned by
+        `compute_gradients()`, and then aggregated across towers.
+      global_step: Optional (mirrored) `Variable` to increment by one
+        after the variables have been updated.
+      name: Optional name for the returned operation.  Default to the
+        name passed to the `Optimizer` constructor.
+
+    Returns:
+      An `Operation` that applies the specified gradients across all
+      towers. If `global_step` was not None, that operation also
+      increments `global_step`.
+    """
+    reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
+    var_list = [v for _, v in grads_and_vars]
+    grads_and_vars = zip(reduced_grads, var_list)
+    # Note that this is called in a cross-tower context.
+    self._create_slots(var_list)
+
+    def update(v, g):
+      """Apply gradients to a replica variable."""
+      assert v is not None
+
+      try:
+        # Convert the grad to Tensor or IndexedSlices if necessary.
+        g = ops.convert_to_tensor_or_indexed_slices(g)
+      except TypeError:
+        raise TypeError("Gradient must be convertible to a Tensor"
+                        " or IndexedSlices, or None: %s" % g)
+      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
+        raise TypeError(
+            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+      p = _get_processor(v)
+
+      scope_name = "" if context.executing_eagerly() else v.op.name
+      # device_policy is set because non-mirrored tensors will be read in
+      # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
+      # is an example.
+      with ops.name_scope(
+          "update_" + scope_name), context.context().device_policy(
+              context.DEVICE_PLACEMENT_SILENT):
+        return p.update_op(self, g)
+
+    with ops.name_scope(name, self._name) as name:
+      self._prepare()
+
+      update_ops = [
+          op
+          for grad, var in grads_and_vars
+          for op in distribution.unwrap(distribution.update(var, update, grad))
+      ]
+
+      def finish(self, update_ops):
+        return self._finish(update_ops, "update")
+
+      non_slot_devices = distribution.non_slot_devices(var_list)
+      # Device policy is needed because hyperparameter tensors (such as
+      # AdamOptimizer's beta1_t) need to be copied across devices in Eager.
+      with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+        finish_updates = distribution.update_non_slot(
+            non_slot_devices, finish, self, update_ops)
+      if global_step is None:
+        apply_updates = distribution.group(finish_updates, name=name)
+      else:
+        with ops.control_dependencies(distribution.unwrap(finish_updates)):
+          apply_updates = distribution.group(distribution.update(
+              global_step, state_ops.assign_add, 1, name=name))
+
+      if not context.executing_eagerly():
+        if isinstance(apply_updates, ops.Tensor):
+          apply_updates = apply_updates.op
+        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+        if apply_updates not in train_op:
+          train_op.append(apply_updates)
+
+      return apply_updates
+
   def get_slot(self, var, name):
     """Return a slot named `name` created for `var` by the Optimizer.
 
@@ -599,9 +745,25 @@ class Optimizer(
     Returns:
       The `Variable` for the slot if it was created, `None` otherwise.
     """
+    # pylint: disable=protected-access
     named_slots = self._slots.get(name, None)
     if not named_slots:
       return None
+
+    if hasattr(var, "_mirrored_container"):
+      # NOTE: If this isn't patched, then there is no `handle` in
+      # `_resource_apply_dense`.
+      mirrored_container = var._mirrored_container()
+      assert mirrored_container is not None
+      if context.executing_eagerly():
+        key = mirrored_container._unique_id
+      else:
+        key = (mirrored_container.graph, mirrored_container._shared_name)
+      # pylint: enable=protected-access
+      mirrored_slot = named_slots.get(key, None)
+      if mirrored_slot is None: return None
+      return mirrored_slot.get(device=var.device)
+
     return named_slots.get(_var_key(var), None)
 
   def get_slot_names(self):
@@ -645,6 +807,7 @@ class Optimizer(
 
   def _create_non_slot_variable(self, initial_value, name, colocate_with):
     """Add an extra variable, not associated with a slot."""
+    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
     eager = context.executing_eagerly()
     graph = None if eager else colocate_with.graph
 
@@ -652,7 +815,8 @@ class Optimizer(
     v = self._non_slot_dict.get(key, None)
     if v is None:
       self._maybe_initialize_checkpointable()
-      with ops.colocate_with(colocate_with):
+      distribution_strategy = distribute_lib.get_distribution_strategy()
+      with distribution_strategy.colocate_vars_with(colocate_with):
         if eager:
           restored_initial_value = self._preload_simple_restoration(
               name=name, shape=None)
@@ -694,7 +858,13 @@ class Optimizer(
     return self._get_non_slot_variable(name, graph=graph)
 
   def _get_non_slot_variable(self, name, graph=None):
-    return self._non_slot_dict.get((name, graph), None)
+    non_slot = self._non_slot_dict.get((name, graph), None)
+    if hasattr(non_slot, "_mirrored_container"):
+      # This is a mirrored non-slot.  In order to enable code like `_finish`
+      # to assign to a non-slot, return the current context replica.
+      return non_slot.get()
+    else:
+      return non_slot
 
   def _non_slot_variables(self):
     """Additional variables created by the `Optimizer`.