Fixes issue where name scope collisions could lead to an invalid variable in the...
author Alexandre Passos <apassos@google.com>
Wed, 11 Apr 2018 22:23:17 +0000 (15:23 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 11 Apr 2018 22:27:09 +0000 (15:27 -0700)
PiperOrigin-RevId: 192518307
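
Before this change, ResourceVariable.from_proto always rebuilt the variable's
graph element from the hard-coded name "<handle>/Read/ReadVariableOp:0". If a
name scope collision had forced the read op to be uniquified at construction
time (e.g. to ".../Read_1/..."), the restored variable pointed at a missing or
wrong tensor. The sketch below shows the round trip the new test guards,
rewritten against the public TF 1.x API; the tf.* aliases are an illustration,
not code from this commit.

    import tensorflow as tf

    with tf.Graph().as_default() as g:
      with tf.variable_scope("foo", use_resource=True):
        a = tf.get_variable("a", initializer=10.0)
      # Colocated gradients exercise the name scopes that used to collide.
      tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.1).minimize(
          a, colocate_gradients_with_ops=True,
          global_step=tf.train.get_or_create_global_step())
      meta_graph_def = tf.train.export_meta_graph(graph=g)

    with tf.Graph().as_default() as g2:
      tf.train.import_meta_graph(meta_graph_def, import_scope="")
      meta_graph_two = tf.train.export_meta_graph(graph=g2)

    # With snapshot_name recorded in each VariableDef, the round trip is
    # stable and the two protos compare equal.
    assert meta_graph_def == meta_graph_two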

tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/ops/resource_variable_ops.py

diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 6d33086..9841922 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -36,6 +36,9 @@ from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
+from tensorflow.python.training import momentum
+from tensorflow.python.training import saver
+from tensorflow.python.training import training_util
 from tensorflow.python.util import compat
 
 
@@ -228,16 +231,36 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
 
   @test_util.run_in_graph_and_eager_modes()
   def testScatterMin(self):
-    handle = resource_variable_ops.var_handle_op(
-        dtype=dtypes.int32, shape=[1, 1])
-    self.evaluate(
-        resource_variable_ops.assign_variable_op(
-            handle, constant_op.constant([[6]], dtype=dtypes.int32)))
-    self.evaluate(
-        resource_variable_ops.resource_scatter_min(
-            handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
-    read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
-    self.assertEqual(self.evaluate(read), [[3]])
+    with ops.device("cpu:0"):
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[1, 1])
+      self.evaluate(
+          resource_variable_ops.assign_variable_op(
+              handle, constant_op.constant([[6]], dtype=dtypes.int32)))
+      self.evaluate(
+          resource_variable_ops.resource_scatter_min(
+              handle, [0], constant_op.constant([[3]], dtype=dtypes.int32)))
+      read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
+      self.assertEqual(self.evaluate(read), [[3]])
+
+  def testMetagraph(self):
+    with ops.Graph().as_default():
+      with variable_scope.variable_scope("foo", use_resource=True):
+        a = variable_scope.get_variable("a", initializer=10.0)
+
+      momentum.MomentumOptimizer(
+          learning_rate=0.001, momentum=0.1).minimize(
+              a,
+              colocate_gradients_with_ops=True,
+              global_step=training_util.get_or_create_global_step())
+
+      graph = ops.get_default_graph()
+      meta_graph_def = saver.export_meta_graph(graph=graph)
+
+    with ops.Graph().as_default() as imported_graph:
+      saver.import_meta_graph(meta_graph_def, import_scope="")
+      meta_graph_two = saver.export_meta_graph(graph=imported_graph)
+    self.assertEqual(meta_graph_def, meta_graph_two)
 
   @test_util.run_in_graph_and_eager_modes()
   def testScatterMax(self):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 508ba9b..c51d1e4 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -525,8 +525,15 @@ class ResourceVariable(variables.Variable):
       self._cached_value = g.as_graph_element(
           ops.prepend_name_scope(
               variable_def.snapshot_name, import_scope=import_scope))
+      self._graph_element = g.as_graph_element(
+          ops.prepend_name_scope(variable_def.snapshot_name,
+                                 import_scope=import_scope))
     else:
       self._cached_value = None
+      # Legacy case: protos written before snapshot_name was always set;
+      # fall back to the conventional read op name.
+      self._graph_element = g.get_tensor_by_name(
+          self._handle.op.name + "/Read/ReadVariableOp:0")
     if variable_def.HasField("save_slice_info_def"):
       self._save_slice_info = variables.Variable.SaveSliceInfo(
           save_slice_info_def=variable_def.save_slice_info_def,
@@ -535,8 +542,6 @@ class ResourceVariable(variables.Variable):
       self._save_slice_info = None
     self._caching_device = None
     self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
-    self._graph_element = g.get_tensor_by_name(
-        self._handle.op.name + "/Read/ReadVariableOp:0")
     self._constraint = None
     self._cached_shape_as_list = None
 
@@ -745,6 +750,10 @@ class ResourceVariable(variables.Variable):
       if self._cached_value is not None:
         var_def.snapshot_name = ops.strip_name_scope(self._cached_value.name,
                                                      export_scope)
+      else:
+        # No cached value; store the graph element's name as the snapshot.
+        var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
+                                                     export_scope)
       var_def.is_resource = True
       if self._save_slice_info:
         var_def.save_slice_info_def.MergeFrom(
@@ -910,7 +919,6 @@ class ResourceVariable(variables.Variable):
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     del name
     if dtype is not None and dtype != self.dtype:
-      print("trying to switch the dtype to ", dtype, " from ", self.dtype)
       return NotImplemented
     if as_ref:
       return self.read_value().op.inputs[0]
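
On the serialization side, to_proto now writes snapshot_name even when the
variable has no cached value, so from_proto can resolve the exact read tensor
by name; the "<handle>/Read/ReadVariableOp:0" convention survives only as a
fallback for older protos. A minimal sketch of that symmetry in graph mode,
assuming a variable named "v" (illustration only, not code from this commit):

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops

    with tf.Graph().as_default():
      v = resource_variable_ops.ResourceVariable(1.0, name="v")
      proto = v.to_proto()
      # Now populated even without a caching device; for a freshly built
      # variable it names the conventional read op.
      assert proto.snapshot_name
      # from_proto resolves the graph element by that name rather than
      # reconstructing it from the handle's op name.
      restored = resource_variable_ops.ResourceVariable.from_proto(proto)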