Fixes for accessing variables with a MirroredStrategy in a
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 May 2018 20:24:12 +0000 (13:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 8 May 2018 00:05:34 +0000 (17:05 -0700)
cross-tower context:

* only provide read-only access to variables via get()

* don't fail if use the variable isn't copied to the current device in
  get()

* make _as_graph_element() return the aggregate value for tower-local
  variables (instead of the incorrect previous behavior of returning
  the primary)

PiperOrigin-RevId: 195711474

tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
tensorflow/contrib/distribute/python/values.py

index 6c5c055..3635bd2 100644 (file)
@@ -370,22 +370,27 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
       expected_sum = 0.0
       expected_mean = 0.0
       for i, d in enumerate(dist.worker_devices):
-        # Test access within a device scope, should see different values.
-        with ops.device(d):
-          v_sum_value = self.evaluate(ret_v_sum.read_value())
-          v_mean_value = self.evaluate(ret_v_mean.read_value())
-          expected = i + 3.0
-          self.assertEqual(expected, v_sum_value)
-          expected_sum += expected
-          expected = i * 6.0
-          self.assertEqual(expected, v_mean_value)
-          expected_mean += expected
-
-      # fetch() should return the value you get by applying the
-      # reduction across all towers.
-      self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
+        # Should see different values on different devices.
+        v_sum_value = self.evaluate(ret_v_sum.get(d).read_value())
+        v_mean_value = self.evaluate(ret_v_mean.get(d).read_value())
+        expected = i + 3.0
+        self.assertEqual(expected, v_sum_value)
+        expected_sum += expected
+        expected = i * 6.0
+        self.assertEqual(expected, v_mean_value)
+        expected_mean += expected
       expected_mean /= len(dist.worker_devices)
+
+      # Without get(device), should return the value you get by
+      # applying the reduction across all towers (whether you use
+      # fetch(), get(), or nothing).
+      self.assertEqual(expected_sum, self.evaluate(dist.fetch(ret_v_sum)))
       self.assertEqual(expected_mean, self.evaluate(dist.fetch(ret_v_mean)))
+      self.assertEqual(expected_sum, self.evaluate(ret_v_sum.get()))
+      self.assertEqual(expected_mean, self.evaluate(ret_v_mean.get()))
+      if not context.executing_eagerly():
+        self.assertEqual(expected_sum, self.evaluate(ret_v_sum))
+        self.assertEqual(expected_mean, self.evaluate(ret_v_mean))
 
   # NOTE(priyag): Names and name scopes are ignored in eager, hence we are not
   # testing this in eager mode.
index b04734f..759f3c3 100644 (file)
@@ -34,6 +34,7 @@ from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.training import checkpointable
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribute as distribute_lib
@@ -60,7 +61,7 @@ class DistributedValues(object):
       else:
         device = distribute_lib.get_update_device()
         if device is None:
-          device = device_util.current()
+          return self._get_cross_tower()
     device = device_util.canonicalize(device)
     try:
       return self._index[device]
@@ -231,12 +232,6 @@ class DistributedVariable(DistributedDelegate):
                               self._primary_var.op.type)
     return self.get().op
 
-  def _as_graph_element(self):
-    # pylint: disable=protected-access
-    if distribute_lib.get_cross_tower_context():
-      return self._primary_var._as_graph_element()
-    return self.get()._as_graph_element()
-
   def _should_act_as_resource_variable(self):
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
@@ -320,6 +315,18 @@ class MirroredVariable(DistributedVariable, Mirrored,
   def assign(self, *args, **kwargs):
     return self.get(device=_get_update_device()).assign(*args, **kwargs)
 
+  def _get_cross_tower(self):
+    device = device_util.canonicalize(device_util.current())
+    if device in self._index:
+      return array_ops.identity(self._index[device])
+    return array_ops.identity(self._primary_var)
+
+  def _as_graph_element(self):
+    # pylint: disable=protected-access
+    if distribute_lib.get_cross_tower_context():
+      return self._primary_var._as_graph_element()
+    return self.get()._as_graph_element()
+
   def _gather_saveables_for_checkpoint(self):
     """Overrides CheckpointableBase method.
 
@@ -364,6 +371,12 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
         for d, v in six.iteritems(self._tower_local_variable._index)])  # pylint: disable=protected-access
 
 
+def _assert_tower_context():
+  if not distribute_lib.get_tower_context():
+    raise RuntimeError(
+        "Tower-local variables may only be assigned in a tower context.")
+
+
 class TowerLocalVariable(DistributedVariable, PerDevice,
                          checkpointable.CheckpointableBase):
   """Holds a map from device to variables whose values are reduced on save."""
@@ -374,18 +387,35 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
     super(TowerLocalVariable, self).__init__(index)
 
   def assign_sub(self, *args, **kwargs):
+    _assert_tower_context()
     return self.get().assign_sub(*args, **kwargs)
 
   def assign_add(self, *args, **kwargs):
+    _assert_tower_context()
     return self.get().assign_add(*args, **kwargs)
 
   def assign(self, *args, **kwargs):
+    _assert_tower_context()
     return self.get().assign(*args, **kwargs)
 
   @property
   def reduce_method(self):
     return self._reduce_method
 
+  def _get_cross_tower(self):
+    all_components = tuple(self._index.values())
+    # TODO(josh11b): Use a strategy-specific method.
+    total = math_ops.add_n(all_components)
+    if self._reduce_method == "mean":
+      return total * (1./ len(all_components))
+    return total
+
+  def _as_graph_element(self):
+    # pylint: disable=protected-access
+    if distribute_lib.get_cross_tower_context():
+      return self._get_cross_tower()
+    return self.get()._as_graph_element()
+
   def _gather_saveables_for_checkpoint(self):
     """Overrides CheckpointableBase method.