Refactoring and bug-fixes for _build_initializer_expr.
author A. Unique TensorFlower <gardener@tensorflow.org>
Fri, 5 Jan 2018 02:53:50 +0000 (18:53 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Fri, 5 Jan 2018 02:57:24 +0000 (18:57 -0800)
- Rename _build_initializer_expr to _try_guard_against_uninitialized_dependencies so as to clarify what it does.

- Avoid invoking the logic in _try_guard_against_uninitialized_dependencies for cyclic graphs. This currently results in infinite recursion which blows the stack.

- Use memoization to reduce the number of redundant operations created by _try_guard_against_uninitialized_dependencies when it encounters initial values with diamond-shaped dependencies.

- Refactoring: Remove unnecessary logic in _try_guard_against_uninitialized_dependencies for dealing with types other than Tensor or Operation. The dependency graph of a Variable's _initial_value should only ever comprise these two types.

- Refactoring: Added some filtering logic to _try_guard_against_uninitialized_dependencies to avoid initial_values with cyclic dependencies.

- Refactoring: Moved the recursive traversal of `initial_value`'s dependencies into _safe_initial_value_from_tensor and _safe_initial_value_from_op.

- Refactoring: Made it so _find_initialized_value_for_variable will return None when it can't find the initialized_value. Currently it returns a Tensor when it finds the initialized_value and an Operation when it can't. This makes the logic in the caller a bit more consistent and explicit.

Future changes will address more of the shortcomings of _try_guard_against_uninitialized_dependencies (formerly _build_initializer_expr).

PiperOrigin-RevId: 180876754

tensorflow/python/BUILD
tensorflow/python/ops/resource_variable_ops.py
tensorflow/python/ops/variables.py
tensorflow/python/training/session_manager_test.py

index 54af3071bc5d79a8fce5ee5326dc3bbb84f35449..c62ff10828bf77204e4734e9819d6d1e04131e77 100644 (file)
@@ -3717,6 +3717,7 @@ cuda_py_test(
     srcs = ["training/session_manager_test.py"],
     additional_deps = [
         ":array_ops",
+        ":control_flow_ops",
         ":client",
         ":client_testlib",
         ":errors",
index 58ede027477667a9d5f821dbf42d8a3fdab50b1a..60a32b1dbc9900f140cbc896eecca8b2d64bb254 100644 (file)
@@ -276,10 +276,6 @@ class ResourceVariable(variables.Variable):
           dtype=dtype,
           constraint=constraint)
 
-  # LINT.IfChange
-  # _VariableFromResource inherits from ResourceVariable but
-  # doesn't call the constructor, so changes here might need to be reflected
-  # there.
   # pylint: disable=unused-argument
   def _init_from_args(self,
                       initial_value=None,
@@ -438,7 +434,8 @@ class ResourceVariable(variables.Variable):
               self._initializer_op = (
                   gen_resource_variable_ops.assign_variable_op(
                       self._handle,
-                      self._build_initializer_expr(initial_value),
+                      self._try_guard_against_uninitialized_dependencies(
+                          initial_value),
                       name=n))
           with ops.name_scope("Read"), ops.colocate_with(self._handle):
             # Manually assign reads to the handle's device to avoid log
@@ -522,7 +519,6 @@ class ResourceVariable(variables.Variable):
     self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
     self._graph_element = self.value()
     self._constraint = None
-  # LINT.ThenChange(//tensorflow/python/eager/graph_callable.py)
 
   def __nonzero__(self):
     return self.__bool__()
index e0748d87e2d6ef2c2f8565669357f881334fa737..b25855633ed4ce485090fb47b09e1b5ce0ff2228 100644 (file)
@@ -362,7 +362,8 @@ class Variable(object):
         # using their initialized_value() method.
         self._initializer_op = state_ops.assign(
             self._variable,
-            self._build_initializer_expr(self._initial_value),
+            self._try_guard_against_uninitialized_dependencies(
+                self._initial_value),
             validate_shape=validate_shape).op
 
         # TODO(vrv): Change this class to not take caching_device, but
@@ -781,88 +782,142 @@ class Variable(object):
 
     setattr(Variable, operator, _run_op)
 
-  def _build_initializer_expr(self, initial_value):
-    """Build an expression suitable to initialize a variable.
+  def _try_guard_against_uninitialized_dependencies(self, initial_value):
+    """Attempt to guard against dependencies on uninitialized variables.
 
-    Replace references to variables in initial_value with references to the
-    variable initial values instead.
+    Replace references to variables in `initial_value` with references to the
+    variable's initialized values. The initialized values are essentially
+    conditional TensorFlow graphs that return a variable's value if it is
+    initialized or its `initial_value` if it hasn't been initialized. This
+    replacement is done on a best effort basis:
+
+    - If the `initial_value` graph contains cycles, we don't do any
+      replacements for that graph.
+    - If the variables that `initial_value` depends on are not present in the
+      `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them.
+
+    In these cases, it is up to the caller to ensure that the `initial_value`
+    graph uses initialized variables or that they guard access to variables
+    using their `initialized_value` method.
 
     Args:
-      initial_value: original expression
+      initial_value: `Tensor`. The initial value.
     Returns:
-      A tensorflow expression suitable to initialize a variable.
+      A `Tensor` suitable to initialize a variable.
+    Raises:
+      TypeError: If `initial_value` is not a `Tensor`.
     """
-    if isinstance(initial_value, Variable):
-      return initial_value.initialized_value()
-    elif isinstance(initial_value, ops.Tensor):
-      new_op = self._build_initializer_expr(initial_value.op)
-      if new_op != initial_value.op:
-        if isinstance(new_op, ops.Tensor):
-          return new_op
-        else:
-          return ops.Tensor(new_op, initial_value.value_index,
-                            initial_value.dtype)
-      else:
-        return initial_value
-    elif isinstance(initial_value, ops.Operation):
-      if initial_value.node_def.op in [
-          "IsVariableInitialized", "VarIsInitializedOp", "ReadVariableOp"
-      ]:
-        return initial_value
-      if initial_value.node_def.op in ["Variable", "VariableV2", "VarHandleOp"]:
-        return self._find_initialized_value_for_variable(initial_value)
-      modified = False
-      new_inputs = []
-      for tensor in initial_value.inputs:
-        new_tensor = self._build_initializer_expr(tensor)
-        new_inputs.append(new_tensor)
-        if new_tensor != tensor:
-          modified = True
-
-      if modified:
-        new_name = initial_value.node_def.name + "_" + self.name
-        new_name = new_name.replace(":", "_")
-        new_op = initial_value.node_def.op
-        new_op = new_op.replace("RefSwitch", "Switch")
-        new_value = self.graph.create_op(
-            new_op,
-            new_inputs,
-            # pylint: disable=protected-access
-            initial_value._output_types,
-            # pylint: enable=protected-access
-            name=new_name,
-            attrs=initial_value.node_def.attr)
-        return new_value
-      else:
-        return initial_value
-    else:
+    if not isinstance(initial_value, ops.Tensor):
+      raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)
+
+    # Don't modify initial_value if it contains any cyclic dependencies.
+    def has_cycle(op, path):
+      """Detect cycles in the dependencies of `initial_value`."""
+      if op.name in path:
+        return True
+      path.add(op.name)
+      for op_input in op.inputs:
+        if has_cycle(op_input.op, path):
+          return True
+      for op_control_input in op.control_inputs:
+        if has_cycle(op_control_input, path):
+          return True
+      path.remove(op.name)
+      return False
+    if has_cycle(initial_value.op, path=set()):
       return initial_value
 
+    return self._safe_initial_value_from_tensor(initial_value, op_cache={})
+
+  def _safe_initial_value_from_tensor(self, tensor, op_cache):
+    """Replace dependencies on variables with their initialized values.
+
+    Args:
+      tensor: A `Tensor`. The tensor to replace.
+      op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+        the results so as to avoid creating redundant operations.
+    Returns:
+      A `Tensor` compatible with `tensor`. Any inputs that lead to variable
+      values will be replaced with a corresponding graph that uses the
+      variable's initialized values. This is done on a best-effort basis. If no
+      modifications need to be made then `tensor` will be returned unchanged.
+    """
+    op = tensor.op
+    new_op = op_cache.get(op.name)
+    if new_op is None:
+      new_op = self._safe_initial_value_from_op(op, op_cache)
+      op_cache[op.name] = new_op
+    return new_op.outputs[tensor.value_index]
+
+  def _safe_initial_value_from_op(self, op, op_cache):
+    """Replace dependencies on variables with their initialized values.
+
+    Args:
+      op: An `Operation`. The operation to replace.
+      op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+        the results so as to avoid creating redundant operations.
+    Returns:
+      An `Operation` compatible with `op`. Any inputs that lead to variable
+      values will be replaced with a corresponding graph that uses the
+      variable's initialized values. This is done on a best-effort basis. If no
+      modifications need to be made then `op` will be returned unchanged.
+    """
+    op_type = op.node_def.op
+    if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
+                   "ReadVariableOp"):
+      return op
+
+    # Attempt to find the initialized_value of any variable reference / handles.
+    # TODO(b/70206927): Fix handling of ResourceVariables.
+    if op_type in ("Variable", "VariableV2", "VarHandleOp"):
+      initialized_value = self._find_initialized_value_for_variable(op)
+      return op if initialized_value is None else initialized_value.op
+
+    # Recursively build initializer expressions for inputs.
+    modified = False
+    new_op_inputs = []
+    for op_input in op.inputs:
+      new_op_input = self._safe_initial_value_from_tensor(op_input, op_cache)
+      new_op_inputs.append(new_op_input)
+      modified = modified or (new_op_input != op_input)
+
+    # If at least one input was modified, replace the op.
+    if modified:
+      new_op_type = op_type
+      if new_op_type == "RefSwitch":
+        new_op_type = "Switch"
+      new_op_name = op.node_def.name + "_" + self.name
+      new_op_name = new_op_name.replace(":", "_")
+      return self.graph.create_op(
+          new_op_type, new_op_inputs,
+          op._output_types,  # pylint: disable=protected-access
+          name=new_op_name, attrs=op.node_def.attr)
+
+    return op
+
   def _find_initialized_value_for_variable(self, variable_op):
-    """Find the initial value for a variable op.
+    """Find the initialized value for a variable op.
 
     To do so, lookup the variable op in the variables collection.
 
     Args:
-      variable_op: a TensorFlow variable Operation
+      variable_op: A variable `Operation`.
     Returns:
-      The initial value for the variable.
+      A `Tensor` representing the initialized value for the variable or `None`
+      if the initialized value could not be found.
     """
     try:
       var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
-      global_vars = self.graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
-      for var in global_vars:
-        if var.name in var_names:
-          return var.initialized_value()
-      local_vars = self.graph.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
-      for var in local_vars:
-        if var.name == var_names:
-          return var.initialized_value()
+      for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
+                              ops.GraphKeys.LOCAL_VARIABLES):
+        for var in self.graph.get_collection(collection_name):
+          if var.name in var_names:
+            return var.initialized_value()
     except AttributeError:
-      # Return the variable itself when an incomplete user defined variable type
-      # was put in the collection.
-      return variable_op
-    return variable_op
+      # Return None when an incomplete user-defined variable type was put in
+      # the collection.
+      return None
+    return None
 
   # NOTE(mrry): This enables the Variable's overloaded "right" binary
   # operators to run when the left operand is an ndarray, because it
index 5879fd330adec58dde45f3da8ae16c9a297f3b24..6670d9365f2994a70b7228170179f97d314041c9 100644 (file)
@@ -26,6 +26,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -504,6 +505,7 @@ class SessionManagerTest(test.TestCase):
           trainable=False,
           collections=[ops.GraphKeys.LOCAL_VARIABLES],
           name="x")
+      # TODO(b/70206927): Use ResourceVariables once they are handled properly.
       v_res = variables.Variable(1, name="v_res")
       w_res = variables.Variable(
           v_res,
@@ -556,6 +558,24 @@ class SessionManagerTest(test.TestCase):
       self.assertEquals(1, sess.run(w_res))
       self.assertEquals(3, sess.run(x_res))
 
+  def testPrepareSessionWithCyclicInitializer(self):
+    # Regression test. Previously Variable._build_initializer_expr would enter
+    # into an infinite recursion when the variable's initial_value involved
+    # cyclic dependencies.
+    with ops.Graph().as_default():
+      i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
+      v = variables.Variable(array_ops.identity(i), name="v")
+      with self.test_session():
+        self.assertEqual(False, variables.is_variable_initialized(v).eval())
+      sm = session_manager.SessionManager(
+          ready_op=variables.report_uninitialized_variables())
+      sess = sm.prepare_session("", init_op=v.initializer)
+      self.assertEqual(1, sess.run(v))
+      self.assertEqual(
+          True,
+          variables.is_variable_initialized(
+              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+
   def testPrepareSessionDidNotInitLocalVariable(self):
     with ops.Graph().as_default():
       v = variables.Variable(1, name="v")