From: Igor Saprykin Date: Mon, 30 Apr 2018 19:41:12 +0000 (-0700) Subject: When a mirrored variable is fetched in cross-tower mode, fetch its primary variable. X-Git-Tag: upstream/v1.9.0_rc1~179^2^2~71 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8609ef4db1a2af0da0c2c20b26756031637de3ff;p=platform%2Fupstream%2Ftensorflow.git When a mirrored variable is fetched in cross-tower mode, fetch its primary variable. This prevents errors like ValueError: Fetch argument MirroredVariable({'/job:localhost/replica:0/task:0/device:GPU:0': <tf.Variable ...>, '/job:localhost/replica:0/task:0/device:GPU:1': <tf.Variable ...>}) cannot be interpreted as a Tensor. (Device /job:localhost/replica:0/task:0/device:CPU:0 not found in ['/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1'] (current device )) I ran distribute/examples/resnet with and without the change and it fixed the problem. PiperOrigin-RevId: 194828672 --- diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 8cb5276..466678e 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -229,6 +229,12 @@ class DistributedVariable(DistributedDelegate): self._primary_var.op.type) return self.get().op + def _as_graph_element(self): + # pylint: disable=protected-access + if distribute_lib.get_cross_tower_context(): + return self._primary_var._as_graph_element() + return self.get()._as_graph_element() + def _should_act_as_resource_variable(self): """Pass resource_variable_ops.is_resource_variable check.""" pass diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index e96ce54..1d4e801 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -34,6 +34,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from 
tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training import device_util from tensorflow.python.training import saver as saver_lib @@ -582,6 +583,21 @@ class MirroredVariableTest(test.TestCase): save_path = self._save_normal() self._restore_mirrored(save_path) + @test_util.run_in_graph_and_eager_modes(config=config) + def testFetchAMirroredVariable(self): + if context.num_gpus() < 1 or context.executing_eagerly(): + self.skipTest("A GPU is not available for this test or it's eager mode.") + + with self.test_session( + graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy( + ["/device:GPU:0"]).scope(): + with ops.device("/device:GPU:0"): + v = variable_scope.get_variable( + name="v", initializer=1., use_resource=True) + mirrored = values.MirroredVariable({"/device:GPU:0": v}, v) + sess.run(variables_lib.global_variables_initializer()) + sess.run({"complicated": mirrored}) + _devices = ["/device:GPU:0", "/device:CPU:0"]