Always enter the handle graph before calling convert_to_tensor in resource variables.
author     Alexandre Passos <apassos@google.com>
           Mon, 21 May 2018 20:00:20 +0000 (13:00 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Mon, 21 May 2018 20:03:11 +0000 (13:03 -0700)
This mimics the behavior of a ref variable's assign, which converts its
value to a tensor in the right graph inside op_def_lib.apply_op.

PiperOrigin-RevId: 197441989
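For context, a minimal sketch of the failure mode this fixes, mirroring the
new testDifferentAssignGraph below. It assumes a TF 1.x build containing this
change; tf.get_variable with use_resource=True stands in for the internal
ResourceVariable constructor used by the test:

    import tensorflow as tf  # TF 1.x, graph mode

    # Build a resource variable inside an explicit graph...
    with tf.Graph().as_default():
      v = tf.get_variable("v", initializer=1.0, use_resource=True)

    # ...then return to a fresh default graph. Before this change,
    # convert_to_tensor placed the constant 2.0 in the new default graph,
    # and building assign_variable_op failed with a graph-mismatch error.
    tf.reset_default_graph()
    v.assign(2.0)  # now built in v's own graph instead of raising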

tensorflow/python/framework/ops.py
tensorflow/python/kernel_tests/resource_variable_ops_test.py
tensorflow/python/ops/resource_variable_ops.py

diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 9cd51a3..7b3acc4 100644
@@ -5562,6 +5562,10 @@ def get_default_graph():
   """
   return _default_graph_stack.get_default()
 
+def has_default_graph():
+  """Returns True if there is a default graph."""
+  return len(_default_graph_stack.stack) >= 1
+
 
 def get_name_scope():
   """Returns the current name scope in the default_graph.
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 073799c..846231f 100644
@@ -106,6 +106,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
       v = resource_variable_ops.ResourceVariable(False, name="bool_test")
       self.assertAllEqual(bool(v), False)
 
+  def testDifferentAssignGraph(self):
+    with ops.Graph().as_default():
+      v = resource_variable_ops.ResourceVariable(1.0)
+    ops.reset_default_graph()
+    v.assign(2.0)  # Note: this fails if convert_to_tensor runs in any
+                   # graph other than the variable's graph.
+
   def testFetchHandle(self):
     with self.test_session():
       handle = resource_variable_ops.var_handle_op(
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 93f3912..e5b8020 100644
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import contextlib
+
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import variable_pb2
 from tensorflow.python import pywrap_tensorflow
@@ -115,6 +117,18 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
   return handle
 
 
+@contextlib.contextmanager
+def _handle_graph(handle):
+  # Note: the handle may be an eager tensor even when we are not executing
+  # eagerly, e.g. when building a function.
+  if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
+      or ops.has_default_graph()):
+    yield
+  else:
+    with handle.graph.as_default():
+      yield
+
+
 class EagerResourceDeleter(object):
   """An object which cleans up a resource handle.
 
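Every mutating method below wraps its convert_to_tensor call in
_handle_graph. A hedged sketch of what that buys in graph mode (it calls the
private helper directly, which real code should not do):

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import resource_variable_ops

    g = ops.Graph()
    with g.as_default():
      v = resource_variable_ops.ResourceVariable(1.0)

    # At top level no graph is explicitly pushed, so _handle_graph enters
    # the handle's graph and the converted tensor lands next to the variable.
    with resource_variable_ops._handle_graph(v.handle):
      t = ops.convert_to_tensor(2.0, dtype=v.dtype)
    assert t.graph is g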
@@ -159,7 +173,8 @@ class EagerResourceDeleter(object):
 
 def shape_safe_assign_variable_handle(handle, shape, value, name=None):
   """Helper that checks shape compatibility and assigns variable."""
-  value_tensor = ops.convert_to_tensor(value)
+  with _handle_graph(handle):
+    value_tensor = ops.convert_to_tensor(value)
   shape.assert_is_compatible_with(value_tensor.shape)
   return gen_resource_variable_ops.assign_variable_op(handle,
                                                       value_tensor,
@@ -850,8 +865,10 @@ class ResourceVariable(variables.Variable):
     # TODO(apassos): this here and below is not atomic. Consider making it
     # atomic if there's a way to do so without a performance cost for those who
     # don't need it.
-    assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
-        self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name)
+    with _handle_graph(self.handle):
+      assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
+          self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+          name=name)
     if read_value:
       return self._lazy_read(assign_sub_op)
     return assign_sub_op
@@ -872,8 +889,10 @@ class ResourceVariable(variables.Variable):
       it will return the `Operation` that does the assignment, and when in eager
       mode it will return `None`.
     """
-    assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
-        self.handle, ops.convert_to_tensor(delta, dtype=self.dtype), name=name)
+    with _handle_graph(self.handle):
+      assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
+          self.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+          name=name)
     if read_value:
       return self._lazy_read(assign_add_op)
     return assign_add_op
@@ -902,30 +921,32 @@ class ResourceVariable(variables.Variable):
       it will return the `Operation` that does the assignment, and when in eager
       mode it will return `None`.
     """
-    value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
-    self._shape.assert_is_compatible_with(value_tensor.shape)
-    assign_op = gen_resource_variable_ops.assign_variable_op(
-        self.handle, value_tensor, name=name)
-    if read_value:
-      return self._lazy_read(assign_op)
+    with _handle_graph(self.handle):
+      value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
+      self._shape.assert_is_compatible_with(value_tensor.shape)
+      assign_op = gen_resource_variable_ops.assign_variable_op(
+          self.handle, value_tensor, name=name)
+      if read_value:
+        return self._lazy_read(assign_op)
     return assign_op
 
   def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask,
                             end_mask, ellipsis_mask, new_axis_mask,
                             shrink_axis_mask):
-    return self._lazy_read(
-        gen_array_ops.resource_strided_slice_assign(
-            ref=self.handle,
-            begin=begin,
-            end=end,
-            strides=strides,
-            value=ops.convert_to_tensor(value, dtype=self.dtype),
-            name=name,
-            begin_mask=begin_mask,
-            end_mask=end_mask,
-            ellipsis_mask=ellipsis_mask,
-            new_axis_mask=new_axis_mask,
-            shrink_axis_mask=shrink_axis_mask))
+    with _handle_graph(self.handle):
+      return self._lazy_read(
+          gen_array_ops.resource_strided_slice_assign(
+              ref=self.handle,
+              begin=begin,
+              end=end,
+              strides=strides,
+              value=ops.convert_to_tensor(value, dtype=self.dtype),
+              name=name,
+              begin_mask=begin_mask,
+              end_mask=end_mask,
+              ellipsis_mask=ellipsis_mask,
+              new_axis_mask=new_axis_mask,
+              shrink_axis_mask=shrink_axis_mask))
 
   def __int__(self):
     if self.dtype != dtypes.int32 and self.dtype != dtypes.int64:
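End to end, the change means the other mutating entry points tolerate a stale
default graph too. A hedged user-level sketch, again assuming a TF 1.x build
containing this change:

    import tensorflow as tf

    with tf.Graph().as_default():
      v = tf.get_variable("v", initializer=1.0, use_resource=True)
    tf.reset_default_graph()

    # assign_add and assign_sub convert their delta inside the handle's
    # graph now, so neither raises a graph-mismatch error:
    v.assign_add(1.0)
    v.assign_sub(0.5)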