expected_sum = 0.0
expected_mean = 0.0
for i, d in enumerate(dist.worker_devices):
- # Test access within a device scope, should see different values.
- with ops.device(d):
- v_sum_value = self.evaluate(ret_v_sum.read_value())
- v_mean_value = self.evaluate(ret_v_mean.read_value())
- expected = i + 3.0
- self.assertEqual(expected, v_sum_value)
- expected_sum += expected
- expected = i * 6.0
- self.assertEqual(expected, v_mean_value)
- expected_mean += expected
-
- # fetch() should return the value you get by applying the
- # reduction across all towers.
- self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
+ # Should see different values on different devices.
+ v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
+ v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
+ expected = i + 3.0
+ self.assertEqual(expected, v_sum_value)
+ expected_sum += expected
+ expected = i * 6.0
+ self.assertEqual(expected, v_mean_value)
+ expected_mean += expected
expected_mean /= len(dist.worker_devices)
+
+ # Without get(device), should return the value you get by
+ # applying the reduction across all towers (whether you use
+ # fetch(), get(), or nothing).
+ self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean)))
+ self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
+ self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
+ if not context.executing_eagerly():
+ self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
+ self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
# NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
# testing this in eager mode.
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.training import checkpointable
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
else:
device = distribute_lib.get_update_device()
if device is None:
- device = device_util.current()
+ return self._get_cross_tower()
device = device_util.canonicalize(device)
try:
return self._index[device]
self._primary_var.op.type)
return self.get().op
# NOTE(review): this definition is deleted here; an identical one is re-added
# later in this patch, after `_get_cross_tower`.
- def _as_graph_element(self):
- # pylint: disable=protected-access
- if distribute_lib.get_cross_tower_context():
- return self._primary_var._as_graph_element()
- return self.get()._as_graph_element()
-
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
# Intentionally empty: per the docstring, the method exists so that the
# object passes that check. NOTE(review): presumably the check only looks for
# this attribute -- confirm against resource_variable_ops.
pass
def assign(self, *args, **kwargs):
# Forwards all arguments to `assign` on the component that `self.get()`
# returns for the device reported by `_get_update_device()` (helper defined
# outside this view).
return self.get(device=_get_update_device()).assign(*args, **kwargs)
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return array_ops.identity(self._index[device])
+ return array_ops.identity(self._primary_var)
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self.get()._as_graph_element()
+
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.
for d, v in six.iteritems(self._tower_local_variable._index)]) # pylint: disable=protected-access
+def _assert_tower_context():
+ if not distribute_lib.get_tower_context():
+ raise RuntimeError(
+ "Tower-local variables may only be assigned in a tower context.")
+
+
class TowerLocalVariable(DistributedVariable, PerDevice,
checkpointable.CheckpointableBase):
"""Holds a map from device to variables whose values are reduced on save."""
super(TowerLocalVariable, self).__init__(index)
def assign_sub(self, *args, **kwargs):
# The guard raises RuntimeError when no tower context is active; after that
# the call is forwarded unchanged to `assign_sub` on the result of `self.get()`.
+ _assert_tower_context()
return self.get().assign_sub(*args, **kwargs)
def assign_add(self, *args, **kwargs):
# The guard raises RuntimeError when no tower context is active; after that
# the call is forwarded unchanged to `assign_add` on the result of `self.get()`.
+ _assert_tower_context()
return self.get().assign_add(*args, **kwargs)
def assign(self, *args, **kwargs):
# The guard raises RuntimeError when no tower context is active; after that
# the call is forwarded unchanged to `assign` on the result of `self.get()`.
+ _assert_tower_context()
return self.get().assign(*args, **kwargs)
# Read-only accessor for the stored reduce method; `_get_cross_tower` compares
# the same attribute against "mean".
@property
def reduce_method(self):
return self._reduce_method
+ def _get_cross_tower(self):
+ all_components = tuple(self._index.values())
+ # TODO(josh11b): Use a strategy-specific method.
+ total = math_ops.add_n(all_components)
+ if self._reduce_method == "mean":
+ return total * (1./ len(all_components))
+ return total
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._get_cross_tower()
+ return self.get()._as_graph_element()
+
def _gather_saveables_for_checkpoint(self):
"""Overrides CheckpointableBase method.