# using their initialized_value() method.
self._initializer_op = state_ops.assign(
self._variable,
- self._build_initializer_expr(self._initial_value),
+ self._try_guard_against_uninitialized_dependencies(
+ self._initial_value),
validate_shape=validate_shape).op
# TODO(vrv): Change this class to not take caching_device, but
setattr(Variable, operator, _run_op)
- def _build_initializer_expr(self, initial_value):
- """Build an expression suitable to initialize a variable.
+ def _try_guard_against_uninitialized_dependencies(self, initial_value):
+ """Attempt to guard against dependencies on uninitialized variables.
- Replace references to variables in initial_value with references to the
- variable initial values instead.
+ Replace references to variables in `initial_value` with references to the
+ variable's initialized values. The initialized values are essentially
+ conditional TensorFlow graphs that return a variable's value if it is
+ initialized or its `initial_value` if it hasn't been initialized. This
+ replacement is done on a best effort basis:
+
+ - If the `initial_value` graph contains cycles, we don't do any
+ replacements for that graph.
+ - If the variables that `initial_value` depends on are not present in the
+ `GLOBAL_VARIABLES` or `LOCAL_VARIABLES` we don't replace them.
+
+ In these cases, it is up to the caller to ensure that the `initial_value`
+ graph uses initialized variables or that they guard access to variables
+ using their `initialized_value` method.
Args:
- initial_value: original expression
+ initial_value: `Tensor`. The initial value.
Returns:
- A tensorflow expression suitable to initialize a variable.
+ A `Tensor` suitable to initialize a variable.
+ Raises:
+ TypeError: If `initial_value` is not a `Tensor`.
"""
- if isinstance(initial_value, Variable):
- return initial_value.initialized_value()
- elif isinstance(initial_value, ops.Tensor):
- new_op = self._build_initializer_expr(initial_value.op)
- if new_op != initial_value.op:
- if isinstance(new_op, ops.Tensor):
- return new_op
- else:
- return ops.Tensor(new_op, initial_value.value_index,
- initial_value.dtype)
- else:
- return initial_value
- elif isinstance(initial_value, ops.Operation):
- if initial_value.node_def.op in [
- "IsVariableInitialized", "VarIsInitializedOp", "ReadVariableOp"
- ]:
- return initial_value
- if initial_value.node_def.op in ["Variable", "VariableV2", "VarHandleOp"]:
- return self._find_initialized_value_for_variable(initial_value)
- modified = False
- new_inputs = []
- for tensor in initial_value.inputs:
- new_tensor = self._build_initializer_expr(tensor)
- new_inputs.append(new_tensor)
- if new_tensor != tensor:
- modified = True
-
- if modified:
- new_name = initial_value.node_def.name + "_" + self.name
- new_name = new_name.replace(":", "_")
- new_op = initial_value.node_def.op
- new_op = new_op.replace("RefSwitch", "Switch")
- new_value = self.graph.create_op(
- new_op,
- new_inputs,
- # pylint: disable=protected-access
- initial_value._output_types,
- # pylint: enable=protected-access
- name=new_name,
- attrs=initial_value.node_def.attr)
- return new_value
- else:
- return initial_value
- else:
+ if not isinstance(initial_value, ops.Tensor):
+ raise TypeError("initial_value needs to be a Tensor: %s" % initial_value)
+
+ # Don't modify initial_value if it contains any cyclic dependencies.
+ def has_cycle(op, path):
+ """Detect cycles in the dependencies of `initial_value`."""
+ if op.name in path:
+ return True
+ path.add(op.name)
+ for op_input in op.inputs:
+ if has_cycle(op_input.op, path):
+ return True
+ for op_control_input in op.control_inputs:
+ if has_cycle(op_control_input, path):
+ return True
+ path.remove(op.name)
+ return False
+ if has_cycle(initial_value.op, path=set()):
return initial_value
+ return self._safe_initial_value_from_tensor(initial_value, op_cache={})
+
+ def _safe_initial_value_from_tensor(self, tensor, op_cache):
+ """Replace dependencies on variables with their initialized values.
+
+ Args:
+ tensor: A `Tensor`. The tensor to replace.
+ op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+ the results so as to avoid creating redundant operations.
+ Returns:
+ A `Tensor` compatible with `tensor`. Any inputs that lead to variable
+ values will be replaced with a corresponding graph that uses the
+ variable's initialized values. This is done on a best-effort basis. If no
+ modifications need to be made then `tensor` will be returned unchanged.
+ """
+ op = tensor.op
+ new_op = op_cache.get(op.name)
+ if new_op is None:
+ new_op = self._safe_initial_value_from_op(op, op_cache)
+ op_cache[op.name] = new_op
+ return new_op.outputs[tensor.value_index]
+
+ def _safe_initial_value_from_op(self, op, op_cache):
+ """Replace dependencies on variables with their initialized values.
+
+ Args:
+ op: An `Operation`. The operation to replace.
+ op_cache: A dict mapping operation names to `Operation`s. Used to memoize
+ the results so as to avoid creating redundant operations.
+ Returns:
+ An `Operation` compatible with `op`. Any inputs that lead to variable
+ values will be replaced with a corresponding graph that uses the
+ variable's initialized values. This is done on a best-effort basis. If no
+ modifications need to be made then `op` will be returned unchanged.
+ """
+ op_type = op.node_def.op
+ if op_type in ("IsVariableInitialized", "VarIsInitializedOp",
+ "ReadVariableOp"):
+ return op
+
+ # Attempt to find the initialized_value of any variable reference / handles.
+ # TODO(b/70206927): Fix handling of ResourceVariables.
+ if op_type in ("Variable", "VariableV2", "VarHandleOp"):
+ initialized_value = self._find_initialized_value_for_variable(op)
+ return op if initialized_value is None else initialized_value.op
+
+ # Recursively build initializer expressions for inputs.
+ modified = False
+ new_op_inputs = []
+ for op_input in op.inputs:
+ new_op_input = self._safe_initial_value_from_tensor(op_input, op_cache)
+ new_op_inputs.append(new_op_input)
+ modified = modified or (new_op_input != op_input)
+
+ # If at least one input was modified, replace the op.
+ if modified:
+ new_op_type = op_type
+ if new_op_type == "RefSwitch":
+ new_op_type = "Switch"
+ new_op_name = op.node_def.name + "_" + self.name
+ new_op_name = new_op_name.replace(":", "_")
+ return self.graph.create_op(
+ new_op_type, new_op_inputs,
+ op._output_types, # pylint: disable=protected-access
+ name=new_op_name, attrs=op.node_def.attr)
+
+ return op
+
def _find_initialized_value_for_variable(self, variable_op):
- """Find the initial value for a variable op.
+ """Find the initialized value for a variable op.
To do so, lookup the variable op in the variables collection.
Args:
- variable_op: a TensorFlow variable Operation
+ variable_op: A variable `Operation`.
Returns:
- The initial value for the variable.
+ A `Tensor` representing the initialized value for the variable or `None`
+ if the initialized value could not be found.
"""
try:
var_names = [variable_op.node_def.name, variable_op.node_def.name + ":0"]
- global_vars = self.graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- for var in global_vars:
- if var.name in var_names:
- return var.initialized_value()
- local_vars = self.graph.get_collection(ops.GraphKeys.LOCAL_VARIABLES)
- for var in local_vars:
- if var.name == var_names:
- return var.initialized_value()
+ for collection_name in (ops.GraphKeys.GLOBAL_VARIABLES,
+ ops.GraphKeys.LOCAL_VARIABLES):
+ for var in self.graph.get_collection(collection_name):
+ if var.name in var_names:
+ return var.initialized_value()
except AttributeError:
- # Return the variable itself when an incomplete user defined variable type
- # was put in the collection.
- return variable_op
- return variable_op
+ # Return None when an incomplete user-defined variable type was put in
+ # the collection.
+ return None
+ return None
# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x")
+ # TODO(b/70206927): Use ResourceVariables once they are handled properly.
v_res = variables.Variable(1, name="v_res")
w_res = variables.Variable(
v_res,
self.assertEquals(1, sess.run(w_res))
self.assertEquals(3, sess.run(x_res))
+ def testPrepareSessionWithCyclicInitializer(self):
+ # Regression test. Previously Variable._build_initializer_expr would enter
+ # into an infinite recursion when the variable's initial_value involved
+ # cyclic dependencies.
+ with ops.Graph().as_default():
+ i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
+ v = variables.Variable(array_ops.identity(i), name="v")
+ with self.test_session():
+ self.assertEqual(False, variables.is_variable_initialized(v).eval())
+ sm = session_manager.SessionManager(
+ ready_op=variables.report_uninitialized_variables())
+ sess = sm.prepare_session("", init_op=v.initializer)
+ self.assertEqual(1, sess.run(v))
+ self.assertEqual(
+ True,
+ variables.is_variable_initialized(
+ sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
+
def testPrepareSessionDidNotInitLocalVariable(self):
with ops.Graph().as_default():
v = variables.Variable(1, name="v")