def get_tower_context():
- """Returns the current TowerContext or None.
+ """Returns the current TowerContext or None if in a cross-tower context.
Note that execution:
- 1. starts in the default (single-tower) tower context;
- 2. switches to cross-tower context when entering a
- `with DistributionStrategy.scope():` block;
+ 1. starts in the default (single-tower) tower context (this function
+ will return the default TowerContext object);
+ 2. switches to cross-tower context (in which case this will return
+ None) when entering a `with DistributionStrategy.scope():` block;
3. switches to a (non-default) tower context inside
`call_for_each_tower(fn, ...)`;
4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context.
+ inside `merge_fn` you are back in the cross-tower context (and again
+ this function will return None).
Note that you can also go directly from step 1 to 4 to switch to a
cross-tower context for the default `DistributionStrategy`. You may
def get_distribution_strategy():
"""Returns the current `DistributionStrategy` object.
+ Prefer to use `get_tower_context()` or `get_cross_tower_context()`
+ when possible.
+
Returns:
A `DistributionStrategy` object. Inside a
`with distribution_strategy.scope()` block, it returns
# TODO(josh11b): ClusterSpec/ClusterResolver
# TODO(josh11b): Partitioned computations, state; sharding
# TODO(josh11b): Model parallelism: "towers" with multiple devices; shuffling
- # TODO(josh11b): Tower-local variables
# TODO(josh11b): List of towers with their worker and parameter devices
# (where the parameter devices may overlap in the ps case).
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
+ def tower_local_var_scope(self, reduce_method):
+ """Inside this scope, new variables will not be mirrored.
+
+ There will still be one component variable per tower, but there is
+ no requirement that they stay in sync. Instead, when saving them
+ or calling `fetch()`, we use the value that results when calling
+ `reduce()` on all the towers' variables.
+
+ Note: tower-local implies not trainable. Instead, it is expected
+ that each tower will directly update (using `assign_add()` or
+ similar) its local variable instance, but only the aggregated
+ value (accessible using `fetch()`) will be exported from the
+ model. When it is acceptable to only aggregate on export, we
+ greatly reduce communication overhead by using tower-local
+ variables.
+
+ Note: All component variables will be initialized to the same
+ value, using the initialization expression from the first tower.
+ The values will match even if the initialization expression uses
+ random numbers.
+
+ Args:
+ reduce_method: String used as a `method_string` to `reduce()`
+ to get the value to save when checkpointing.
+
+ Returns:
+ A context manager.
+ """
+ def create_tower_local_variable(next_creator, *args, **kwargs):
+ _require_distribution_strategy_scope(self)
+ kwargs["use_resource"] = True
+ kwargs["tower_local_reduce_method"] = reduce_method
+ return next_creator(*args, **kwargs)
+
+ _require_distribution_strategy_scope(self)
+ return variable_scope.variable_creator_scope(create_tower_local_variable)
+
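To make the intended usage concrete, here is a minimal sketch (illustrative only, not part of this change). It assumes a concrete `DistributionStrategy` subclass such as the contrib `MirroredStrategy`, two available devices, and the checkpoint/`fetch()` behavior described in the docstring above; the variable name and device strings are placeholders:

    import tensorflow as tf

    # Assumption: MirroredStrategy from tf.contrib.distribute is used as the
    # concrete DistributionStrategy; any subclass with these methods works.
    strategy = tf.contrib.distribute.MirroredStrategy(
        ["/device:GPU:0", "/device:GPU:1"])

    def tower_fn():
      # Runs once per tower inside call_for_each_tower(), i.e. in a
      # (non-default) tower context.
      with strategy.tower_local_var_scope("sum"):
        # One component variable per tower, marked non-trainable by the
        # creator above; the components are not kept in sync.
        counter = tf.get_variable("examples_seen", initializer=0)
      # Each tower updates only its own component variable.
      return counter.assign_add(1)

    with strategy.scope():
      update_op = strategy.call_for_each_tower(tower_fn)
      # When checkpointing (or calling fetch()), the per-tower components
      # are combined with reduce("sum"), i.e. the total across towers.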
def colocate_vars_with(self, colocate_with_variable):
"""Scope that controls which devices variables will be created on.
finally:
_pop_per_thread_mode()
+ def tower_local_var_scope(self, reduce_method):
+ """Alias for distribution_strategy.tower_local_var_scope()."""
+ return self._distribution_strategy.tower_local_var_scope(reduce_method)
+
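The alias above lets code running in a tower, which typically only holds the current `TowerContext`, open the scope without a reference to the strategy itself. A short sketch of the context transitions described under `get_tower_context()` (again illustrative; it assumes the helpers in this module are importable as shown, and that `tower_fn` is passed to `call_for_each_tower()` inside a strategy scope as in the previous sketch):

    from tensorflow.python.training import distribute as distribute_lib

    def tower_fn():
      ctx = distribute_lib.get_tower_context()  # non-None inside a tower
      with ctx.tower_local_var_scope("mean"):
        pass  # variables created here would be tower-local, not mirrored

      def merge_fn(strategy):
        # Back in a cross-tower context, so this now returns None.
        assert distribute_lib.get_tower_context() is None
        return strategy

      return ctx.merge_call(merge_fn)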
@property
def is_single_tower(self):
"""Returns whether there is a single tower or multiple."""
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope(self)
+ if kwargs.pop("tower_local_reduce_method", None) is not None:
+ kwargs["trainable"] = False
return next_creator(*args, **kwargs)
return _CurrentDistributionContext(