From 390e19ab990f5656e09d98624c92b3c80e52937d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 28 Mar 2018 16:16:48 -0700
Subject: [PATCH] Tower-local variable support for DistributionStrategy.

Each tower has its own variable, but fetch() and checkpoint apply a
reduction to get a single value.

PiperOrigin-RevId: 190853123
---
 tensorflow/python/training/distribute.py | 59 ++++++++++++++++++++++++++++----
 1 file changed, 53 insertions(+), 6 deletions(-)

diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 757ba71..f988727 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -126,16 +126,18 @@ class UpdateContext(object):
 
 
 def get_tower_context():
-  """Returns the current TowerContext or None.
+  """Returns the current TowerContext or None if in a cross-tower context.
 
   Note that execution:
-  1. starts in the default (single-tower) tower context;
-  2. switches to cross-tower context when entering a
-     `with DistributionStrategy.scope():` block;
+  1. starts in the default (single-tower) tower context (this function
+     will return the default TowerContext object);
+  2. switches to cross-tower context (in which case this will return
+     None) when entering a `with DistributionStrategy.scope():` block;
   3. switches to a (non-default) tower context inside
      `call_for_each_tower(fn, ...)`;
   4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
-     inside `merge_fn` you are back in the cross-tower context.
+     inside `merge_fn` you are back in the cross-tower context (and again
+     this function will return None).
 
   Note that you can also go directly from step 1 to 4 to switch to a
   cross-tower context for the default `DistributionStrategy`. You may
@@ -188,6 +190,9 @@ def get_cross_tower_context():
 def get_distribution_strategy():
   """Returns the current `DistributionStrategy` object.
 
+  Prefer to use `get_tower_context()` or `get_cross_tower_context()`
+  instead when possible.
+
   Returns:
     A `DistributionStrategy` object. Inside a
     `with distribution_strategy.scope()` block, it returns
@@ -526,7 +531,6 @@ class DistributionStrategy(object):
   # TODO(josh11b): ClusterSpec/ClusterResolver
   # TODO(josh11b): Partitioned computations, state; sharding
   # TODO(josh11b): Model parallelism: "towers" with multiple devices; shuffling
-  # TODO(josh11b): Tower-local variables
   # TODO(josh11b): List of towers with their worker and parameter devices
   #                (where the parameter devices may overlap in the ps case).
 
@@ -556,6 +560,43 @@ class DistributionStrategy(object):
     # Note: should support "colocate_with" argument.
     raise NotImplementedError("must be implemented in descendants")
 
+  def tower_local_var_scope(self, reduce_method):
+    """Inside this scope, new variables will not be mirrored.
+
+    There will still be one component variable per tower, but there is
+    no requirement that they stay in sync. Instead, when saving them
+    or calling `fetch()`, we use the value that results when calling
+    `reduce()` on all the towers' variables.
+
+    Note: tower-local implies not trainable. Instead, it is expected
+    that each tower will directly update (using `assign_add()` or
+    whatever) its local variable instance but only the aggregated
+    value (accessible using `fetch()`) will be exported from the
+    model. When it is acceptable to only aggregate on export, we
+    greatly reduce communication overhead by using tower-local
+    variables.
+
+    Note: All component variables will be initialized to the same
+    value, using the initialization expression from the first tower.
+    The values will match even if the initialization expression uses
+    random numbers.
+
+    Args:
+      reduce_method: String used as a `method_string` to `reduce()`
+        to get the value to save when checkpointing.
+
+    Returns:
+      A context manager.
+    """
+    def create_tower_local_variable(next_creator, *args, **kwargs):
+      _require_distribution_strategy_scope(self)
+      kwargs["use_resource"] = True
+      kwargs["tower_local_reduce_method"] = reduce_method
+      return next_creator(*args, **kwargs)
+
+    _require_distribution_strategy_scope(self)
+    return variable_scope.variable_creator_scope(create_tower_local_variable)
+
   def colocate_vars_with(self, colocate_with_variable):
     """Scope that controls which devices variables will be created on.
 
@@ -984,6 +1025,10 @@ class TowerContext(object):
     finally:
       _pop_per_thread_mode()
 
+  def tower_local_var_scope(self, reduce_method):
+    """Alias for distribution_strategy.tower_local_var_scope()."""
+    return self._distribution_strategy.tower_local_var_scope(reduce_method)
+
   @property
   def is_single_tower(self):
     """Returns whether there is a single tower or multiple."""
@@ -1030,6 +1075,8 @@ class _DefaultDistributionStrategy(DistributionStrategy):
 
     def creator(next_creator, *args, **kwargs):
       _require_distribution_strategy_scope(self)
+      if kwargs.pop("tower_local_reduce_method", None) is not None:
+        kwargs["trainable"] = False
       return next_creator(*args, **kwargs)
 
     return _CurrentDistributionContext(
-- 
2.7.4
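
Usage sketch (illustrative only, not part of the patch): how a tower-local
variable might be created with the new scope. This assumes a concrete
`DistributionStrategy` implementation that honors `tower_local_reduce_method`,
that "sum" is an accepted `method_string` for `reduce()`, and that variable
creation inside the scope is routed through the registered variable creator;
the names `tower_fn`, `features`, and `examples_seen` are hypothetical.

    import tensorflow as tf
    from tensorflow.python.training import distribute

    def tower_fn(features):
      """Per-tower step that counts examples in a tower-local variable."""
      tower_ctx = distribute.get_tower_context()
      # Variables created inside this scope get one component per tower,
      # are never synchronized, and are reduced with "sum" when fetched
      # or checkpointed.
      with tower_ctx.tower_local_var_scope("sum"):
        examples_seen = tf.get_variable(
            "examples_seen", shape=[], dtype=tf.int32,
            initializer=tf.zeros_initializer())
      # Each tower updates only its own component variable; only the
      # aggregated (summed) value is exported from the model.
      return examples_seen.assign_add(tf.shape(features)[0])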