raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v in grads_and_vars],))
return distribute_lib.get_tower_context().merge_call(
- self.distributed_apply, filtered, global_step=global_step, name=name)
+ self._distributed_apply, filtered, global_step=global_step, name=name)
def _get_or_create_state(self, var_list=None):
"""Either looks up or creates `_OptimizerV2State`.
self._per_graph_state[graph_key] = per_graph_state
return per_graph_state
- def distributed_apply(self, distribution, grads_and_vars, global_step, name):
+ def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
"""`apply_gradients` for use with a `DistributionStrategy`."""
reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
var_list = [v for _, v in grads_and_vars]