From 8609ef4db1a2af0da0c2c20b26756031637de3ff Mon Sep 17 00:00:00 2001
From: Igor Saprykin
Date: Mon, 30 Apr 2018 12:41:12 -0700
Subject: [PATCH] When a mirrored variable is fetched in cross-tower mode,
 fetch its primary variable.

This prevents errors like

ValueError: Fetch argument MirroredVariable({'/job:localhost/replica:0/task:0/device:GPU:0': , '/job:localhost/replica:0/task:0/device:GPU:1': }) cannot be interpreted as a Tensor. (Device /job:localhost/replica:0/task:0/device:CPU:0 not found in ['/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1'] (current device ))

I ran distribute/examples/resnet with and without the change and it fixed
the problem.

PiperOrigin-RevId: 194828672
---
 tensorflow/contrib/distribute/python/values.py      |  6 ++++++
 tensorflow/contrib/distribute/python/values_test.py | 16 ++++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 8cb5276..466678e 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -229,6 +229,12 @@ class DistributedVariable(DistributedDelegate):
                               self._primary_var.op.type)
     return self.get().op
 
+  def _as_graph_element(self):
+    # pylint: disable=protected-access
+    if distribute_lib.get_cross_tower_context():
+      return self._primary_var._as_graph_element()
+    return self.get()._as_graph_element()
+
   def _should_act_as_resource_variable(self):
     """Pass resource_variable_ops.is_resource_variable check."""
     pass
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index e96ce54..1d4e801 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -34,6 +34,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.training import device_util
 from tensorflow.python.training import saver as saver_lib
 
@@ -582,6 +583,21 @@ class MirroredVariableTest(test.TestCase):
     save_path = self._save_normal()
     self._restore_mirrored(save_path)
 
+  @test_util.run_in_graph_and_eager_modes(config=config)
+  def testFetchAMirroredVariable(self):
+    if context.num_gpus() < 1 or context.executing_eagerly():
+      self.skipTest("A GPU is not available for this test or it's eager mode.")
+
+    with self.test_session(
+        graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
+            ["/device:GPU:0"]).scope():
+      with ops.device("/device:GPU:0"):
+        v = variable_scope.get_variable(
+            name="v", initializer=1., use_resource=True)
+      mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+      sess.run(variables_lib.global_variables_initializer())
+      sess.run({"complicated": mirrored})
+
 
 _devices = ["/device:GPU:0", "/device:CPU:0"]
 
-- 
2.7.4
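
For context, a minimal sketch of the fetch path this patch addresses. It
assumes the TF 1.8-era tf.contrib.distribute API; the device list, the
variable name "v", and the presence of two GPUs are illustrative
assumptions, not part of the patch:

    # Sketch: fetching a MirroredVariable from cross-tower context.
    import tensorflow as tf

    # Illustrative device list; assumes a machine with two GPUs.
    strategy = tf.contrib.distribute.MirroredStrategy(
        ["/device:GPU:0", "/device:GPU:1"])

    with tf.Graph().as_default(), strategy.scope():
      # Created under the strategy scope with use_resource=True, `v`
      # becomes a MirroredVariable with one component per device.
      v = tf.get_variable("v", initializer=1., use_resource=True)

      with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # session.run() resolves each fetch through _as_graph_element().
        # We are in cross-tower context here, so before this patch the
        # fetch raised "ValueError: Fetch argument MirroredVariable(...)
        # cannot be interpreted as a Tensor"; with the patch, the primary
        # component variable is fetched instead.
        print(sess.run(v))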