Tower-local variable support for DistributionStrategy. Each tower has
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 28 Mar 2018 23:16:48 +0000 (16:16 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 28 Mar 2018 23:18:50 +0000 (16:18 -0700)
its own variable, but fetch() and checkpoint apply a reduction to get
a single value.

PiperOrigin-RevId: 190853123

tensorflow/python/training/distribute.py

index 757ba71..f988727 100644 (file)
@@ -126,16 +126,18 @@ class UpdateContext(object):
 
 
 def get_tower_context():
-  """Returns the current TowerContext or None.
+  """Returns the current TowerContext or None if in a cross-tower context.
 
   Note that execution:
-  1. starts in the default (single-tower) tower context;
-  2. switches to cross-tower context when entering a
-     `with DistributionStrategy.scope():` block;
+  1. starts in the default (single-tower) tower context (this function
+     will return the default TowerContext object);
+  2. switches to cross-tower context (in which case this will return
+     None) when entering a `with DistributionStrategy.scope():` block;
   3. switches to a (non-default) tower context inside
      `call_for_each_tower(fn, ...)`;
   4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
-     inside `merge_fn` you are back in the cross-tower context.
+     inside `merge_fn` you are back in the cross-tower context (and again
+     this function will return None).
 
   Note that you can also go directly from step 1 to 4 to switch to a
   cross-tower context for the default `DistributionStrategy`. You may
@@ -188,6 +190,9 @@ def get_cross_tower_context():
 def get_distribution_strategy():
   """Returns the current `DistributionStrategy` object.
 
+  Prefer to use `get_tower_context()` or `get_cross_tower_context()`
+  instead when possible.
+
   Returns:
     A `DistributionStrategy` object. Inside a
     `with distribution_strategy.scope()` block, it returns
@@ -526,7 +531,6 @@ class DistributionStrategy(object):
   # TODO(josh11b): ClusterSpec/ClusterResolver
   # TODO(josh11b): Partitioned computations, state; sharding
   # TODO(josh11b): Model parallelism: "towers" with multiple devices; shuffling
-  # TODO(josh11b): Tower-local variables
   # TODO(josh11b): List of towers with their worker and parameter devices
   #   (where the parameter devices may overlap in the ps case).
 
@@ -556,6 +560,43 @@ class DistributionStrategy(object):
     # Note: should support "colocate_with" argument.
     raise NotImplementedError("must be implemented in descendants")
 
+  def tower_local_var_scope(self, reduce_method):
+    """Inside this scope, new variables will not be mirrored.
+
+    There will still be one component variable per tower, but there is
+    no requirement that they stay in sync. Instead, when saving them
+    or calling `fetch()`, we use the value that results when calling
+    `reduce()` on all the towers' variables.
+
+    Note: tower-local implies not trainable. Instead, it is expected
+    that each tower will directly update (using `assign_add()` or
+    whatever) its local variable instance but only the aggregated
+    value (accessible using `fetch()`) will be exported from the
+    model. When it is acceptable to only aggregate on export, we
+    greatly reduce communication overhead by using tower-local
+    variables.
+
+    Note: All component variables will be initialized to the same
+    value, using the initialization expression from the first tower.
+    The values will match even if the initialization expression uses
+    random numbers.
+
+    Args:
+      reduce_method: String used as a `method_string` to `reduce()`
+        to get the value to save when checkpointing.
+
+    Returns:
+      A context manager.
+    """
+    def create_tower_local_variable(next_creator, *args, **kwargs):
+      _require_distribution_strategy_scope(self)
+      kwargs["use_resource"] = True
+      kwargs["tower_local_reduce_method"] = reduce_method
+      return next_creator(*args, **kwargs)
+
+    _require_distribution_strategy_scope(self)
+    return variable_scope.variable_creator_scope(create_tower_local_variable)
+
   def colocate_vars_with(self, colocate_with_variable):
     """Scope that controls which devices variables will be created on.
 
@@ -984,6 +1025,10 @@ class TowerContext(object):
     finally:
       _pop_per_thread_mode()
 
+  def tower_local_var_scope(self, reduce_method):
+    """Alias for distribution_strategy.tower_local_var_scope()."""
+    return self._distribution_strategy.tower_local_var_scope(reduce_method)
+
   @property
   def is_single_tower(self):
     """Returns whether there is a single tower or multiple."""
@@ -1030,6 +1075,8 @@ class _DefaultDistributionStrategy(DistributionStrategy):
 
     def creator(next_creator, *args, **kwargs):
       _require_distribution_strategy_scope(self)
+      if kwargs.pop("tower_local_reduce_method", None) is not None:
+        kwargs["trainable"] = False
       return next_creator(*args, **kwargs)
 
     return _CurrentDistributionContext(