self._primary_var.op.type)
return self.get().op
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribute_lib.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self.get()._as_graph_element()
+
def _should_act_as_resource_variable(self):
"""Pass resource_variable_ops.is_resource_variable check."""
pass
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
from tensorflow.python.training import saver as saver_lib
save_path = self._save_normal()
self._restore_mirrored(save_path)
+ @test_util.run_in_graph_and_eager_modes(config=config)
+ def testFetchAMirroredVariable(self):
+ if context.num_gpus() < 1 or context.executing_eagerly():
+ self.skipTest("A GPU is not available for this test or it's eager mode.")
+
+ with self.test_session(
+ graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
+ ["/device:GPU:0"]).scope():
+ with ops.device("/device:GPU:0"):
+ v = variable_scope.get_variable(
+ name="v", initializer=1., use_resource=True)
+ mirrored = values.MirroredVariable({"/device:GPU:0": v}, v)
+ sess.run(variables_lib.global_variables_initializer())
+ sess.run({"complicated": mirrored})
+
_devices = ["/device:GPU:0", "/device:CPU:0"]