Makes ResourceVariables return correct initialized_value and initial_value for object...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 6 Feb 2018 11:11:59 +0000 (03:11 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Feb 2018 11:16:08 +0000 (03:16 -0800)
Previously self._initial_value wasn't set in such cases which causes accessing var.initial_value to fail for variables in the imported meta graphs.

PiperOrigin-RevId: 184657191

tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/ops/resource_variable_ops.py

index cd94579..dc6e73b 100644 (file)
@@ -276,6 +276,32 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
     v.load(2.0)
     self.assertEqual(2.0, self.evaluate(v.value()))
 
+  def testVariableDefInitializedInstances(self):
+    with ops.Graph().as_default(), self.test_session() as sess:
+      v_def = resource_variable_ops.ResourceVariable(
+          initial_value=constant_op.constant(3.0)).to_proto()
+
+    with ops.Graph().as_default(), self.test_session() as sess:
+      # v describes a VariableDef-based variable without an initial value.
+      v = resource_variable_ops.ResourceVariable(variable_def=v_def)
+      self.assertEqual(3.0, sess.run(v.initialized_value()))
+
+      # initialized_value should not rerun the initializer_op if the variable
+      # has already been initialized elsewhere.
+      sess.run(v.assign(1.0))
+      self.assertEqual(1.0, v.initialized_value().eval())
+
+    v_def.ClearField("initial_value_name")
+    with ops.Graph().as_default(), self.test_session() as sess:
+      # Restoring a legacy VariableDef proto that does not have
+      # initial_value_name set should still work.
+      v = resource_variable_ops.ResourceVariable(variable_def=v_def)
+      # We should also be able to re-export the variable to a new meta graph.
+      self.assertProtoEquals(v_def, v.to_proto())
+      # But attempts to use initialized_value will result in errors.
+      with self.assertRaises(ValueError):
+        sess.run(v.initialized_value())
+
   @test_util.run_in_graph_and_eager_modes()
   def testSparseRead(self):
     with self.test_session():
index cc9f798..75cb57f 100644 (file)
@@ -503,6 +503,14 @@ class ResourceVariable(variables.Variable):
     self._initializer_op = g.as_graph_element(
         ops.prepend_name_scope(
             variable_def.initializer_name, import_scope=import_scope))
+    # Check whether initial_value_name exists for backwards compatibility.
+    if (hasattr(variable_def, "initial_value_name") and
+        variable_def.initial_value_name):
+      self._initial_value = g.as_graph_element(
+          ops.prepend_name_scope(variable_def.initial_value_name,
+                                 import_scope=import_scope))
+    else:
+      self._initial_value = None
     if variable_def.snapshot_name:
       self._cached_value = g.as_graph_element(
           ops.prepend_name_scope(
@@ -699,6 +707,12 @@ class ResourceVariable(variables.Variable):
       var_def = variable_pb2.VariableDef()
       var_def.variable_name = ops.strip_name_scope(self.handle.name,
                                                    export_scope)
+      if self._initial_value is not None:
+        # This is inside an if-statement for backwards compatibility, since
+        # self._initial_value might be None for variables constructed from old
+        # protos.
+        var_def.initial_value_name = ops.strip_name_scope(
+            self._initial_value.name, export_scope)
       var_def.initializer_name = ops.strip_name_scope(self.initializer.name,
                                                       export_scope)
       if self._cached_value is not None: