When a mirrored variable is fetched in cross-tower mode, fetch its primary variable.
author Igor Saprykin <isaprykin@google.com>
Mon, 30 Apr 2018 19:41:12 +0000 (12:41 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Mon, 30 Apr 2018 19:43:36 +0000 (12:43 -0700)
This prevents errors like:
ValueError: Fetch argument MirroredVariable({'/job:localhost/replica:0/task:0/device:GPU:0': <tf.Variable 'global_step:0' shape=() dtype=int64>, '/job:localhost/replica:0/task:0/device:GPU:1': <tf.Variable 'global_step/replica_1:0' shape=() dtype=int64>}) cannot be interpreted as a Tensor. (Device /job:localhost/replica:0/task:0/device:CPU:0 not found in ['/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1'] (current device ))

I ran distribute/examples/resnet with and without the change; the change fixes the problem.
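
For context, Session.run resolves each fetch through Graph.as_graph_element, which for objects that are not already a Tensor or Operation falls back to the fetch's _as_graph_element() hook; this change implements that hook on DistributedVariable. A minimal sketch of the protocol, using a hypothetical ConstWrapper class that is not part of this change:

    import tensorflow as tf

    class ConstWrapper(object):
      """Toy fetchable object, used only to illustrate the hook."""

      def __init__(self, tensor):
        self._tensor = tensor

      def _as_graph_element(self):
        # Session.run calls Graph.as_graph_element on each fetch; for
        # non-Tensor fetches that resolution invokes this hook to obtain
        # a real Tensor to evaluate.
        return self._tensor

    with tf.Graph().as_default():
      c = tf.constant(42.0)
      with tf.Session() as sess:
        print(sess.run(ConstWrapper(c)))  # prints 42.0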

PiperOrigin-RevId: 194828672

tensorflow/contrib/distribute/python/values.py
tensorflow/contrib/distribute/python/values_test.py

diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8cb5276..466678e 100644
@@ -229,6 +229,12 @@ class DistributedVariable(DistributedDelegate):
                               self._primary_var.op.type)
     return self.get().op
 
+  def _as_graph_element(self):
+    # pylint: disable=protected-access
+    if distribute_lib.get_cross_tower_context():
+      return self._primary_var._as_graph_element()
+    return self.get()._as_graph_element()
+
   def _should_act_as_resource_variable(self):
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
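
With the hook in place, fetching a MirroredVariable from cross-tower context (i.e. outside call_for_each_tower) resolves to the primary variable, while tower code still gets its device-local copy through get(). A rough usage sketch that mirrors the new test below, assuming at least one GPU and the contrib-era API:

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import mirrored_strategy
    from tensorflow.contrib.distribute.python import values

    with tf.Graph().as_default() as g, tf.Session(graph=g) as sess:
      with mirrored_strategy.MirroredStrategy(["/device:GPU:0"]).scope():
        with tf.device("/device:GPU:0"):
          v = tf.get_variable(name="v", initializer=1., use_resource=True)
        mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
        sess.run(tf.global_variables_initializer())
        # Before this change, the fetch below raised the ValueError quoted
        # above; now it resolves to the primary variable and returns 1.0.
        print(sess.run(mirrored))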
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index e96ce54..1d4e801 100644
@@ -34,6 +34,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import saver as saver_lib
 
@@ -582,6 +583,21 @@ class MirroredVariableTest(test.TestCase):
     save_path = self._save_normal()
     self._restore_mirrored(save_path)
 
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testFetchAMirroredVariable(self):
+    if context.num_gpus() < 1 or context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test or it's eager mode.")
+
+    with self.test_session(
+        graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
+            ["/device:GPU:0"]).scope():
+      with ops.device("/device:GPU:0"):
+        v = variable_scope.get_variable(
+            name="v", initializer=1., use_resource=True)
+      mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+      sess.run(variables_lib.global_variables_initializer())
+      sess.run({"complicated": mirrored})
+
 
 _devices = ["/device:GPU:0", "/device:CPU:0"]