from __future__ import print_function
from tensorflow.core.framework import summary_pb2
+from tensorflow.python.framework import test_util
from tensorflow.python.summary.writer import writer
from tensorflow.python.summary.writer import writer_cache
if expected_added_graphs is not None:
test_case.assertEqual(expected_added_graphs, self._added_graphs)
if expected_added_meta_graphs is not None:
- test_case.assertEqual(expected_added_meta_graphs, self._added_meta_graphs)
+ test_case.assertEqual(len(expected_added_meta_graphs),
+ len(self._added_meta_graphs))
+ for expected, actual in zip(expected_added_meta_graphs,
+ self._added_meta_graphs):
+ test_util.assert_meta_graph_protos_equal(test_case, expected, actual)
if expected_session_logs is not None:
test_case.assertEqual(expected_session_logs, self._added_session_logs)
if values_def:
self._init_values_from_proto(values_def, import_scope=import_scope)
else:
- # Values that have been already seen in this context.
+ # The names of tensors that have been already seen in this context.
self._values = set()
- # Values referenced by but external to this context.
+ # The keys are the names of tensors referenced by but external to this
+ # context. Each value is the Tensor that should be used by this context to
+ # access the key value (e.g. a switch output guarding a cond input value).
self._external_values = {}
def _init_values_from_proto(self, values_def, import_scope=None):
self._pivot = pivot # The predicate tensor in this branch
self._branch = branch # 0 or 1 representing this branch
- # Values considered to have been already seen in this context.
+ # Values considered to have been already seen in this context. They are
+ # not included in this context.
self._values.add(pred.name)
+ self._external_values[pred.name] = pred
self._values.add(pivot.name)
+ self._external_values[pivot.name] = pivot
def _init_from_proto(self, context_def, import_scope=None):
"""Creates a new `CondContext` from protocol buffer.
self._branch = context_def.branch
super(CondContext, self).__init__(values_def=context_def.values_def,
import_scope=import_scope)
- # The predicate and pivot ops appear in self._values, but don't have self
- # set as their control context. The __init__ call above will set self for
- # all values, so manually override the predicate and pivot contexts here.
- # pylint: disable=protected-access
- self._pred.op._set_control_flow_context(self.outer_context)
- self._pivot.op._set_control_flow_context(self.outer_context)
- # pylint: enable=protected-access
@property
def pred(self):
if self._outer_context:
result = self._outer_context.AddValue(val)
self._values.add(result.name)
+ self._external_values[result.name] = result
with ops.control_dependencies(None):
result = _SwitchRefOrTensor(result, self._pred)[self._branch]
if self._outer_context:
if self._outer_context:
real_val = self._outer_context.AddValue(val)
self._values.add(real_val.name)
+ self._external_values[real_val.name] = real_val
real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
self._external_values[val.name] = real_val
else:
self._testGraphExtensionRestore(test_dir)
self._testRestoreFromTrainGraphWithControlContext(test_dir)
- def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
- # Build a while loop with `outer_body_fn`, export it, and verify that it can
- # be imported and the gradient can be built and run correctly.
+ def _testGradientSerDes(self, graph_fn):
+ """Tests that gradients can be computed after exporting and importing.
+
+ Builds a graph, exports it, and verifies that it can be imported and the
+ gradient can be built and run correctly.
+ Args:
+ graph_fn: takes a single float Tensor argument as input, outputs a single
+ Tensor
+ """
test_dir = self._get_test_dir("nested_control_flow")
filename = os.path.join(test_dir, "metafile")
saver_ckpt = os.path.join(test_dir, "saver.ckpt")
# Create while loop using `outer_body_fn`.
with ops_lib.Graph().as_default():
- var = variables.Variable(0)
+ var = variables.Variable(0.0)
var_name = var.name
- _, output = control_flow_ops.while_loop(lambda i, x: i < 5, outer_body_fn,
- [0, var])
+ output = graph_fn(var)
output_name = output.name
init_op = variables.global_variables_initializer()
actual_grad_value = sess.run(grad)
self.assertEqual(expected_grad_value, actual_grad_value)
+ def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
+ # Build a while loop with `outer_body_fn`, export it, and verify that it can
+ # be imported and the gradient can be built and run correctly.
+ # pylint: disable=g-long-lambda
+ return self._testGradientSerDes(
+ lambda x: control_flow_ops.while_loop(
+ lambda i, y: i < 5, outer_body_fn, [0, x])[1])
+ # pylint: enable=g-long-lambda
+
def testNestedWhileLoopsSerDes(self):
# Test two simple nested while loops.
def body(i, x):
_, r = control_flow_ops.while_loop(lambda j, y: j < 3,
lambda j, y: (j + 1, y + x),
- [0, 0])
+ [0, 0.0])
return i + 1, x + r
self._testWhileLoopAndGradientSerDes(body)
lambda: control_flow_ops.while_loop(
lambda j, y: j < 3,
lambda j, y: (j + 1, y + x),
- [0, 0])[1],
+ [0, 0.0])[1],
lambda: x)
return i + 1, cond_result
# pylint: enable=g-long-lambda
self._testWhileLoopAndGradientSerDes(body)
+ def testNestedCondsSerDes(self):
+ # Test conds in a cond.
+ # pylint: disable=g-long-lambda
+ self._testGradientSerDes(lambda x: control_flow_ops.cond(
+ x > 0,
+ lambda: control_flow_ops.cond(x > 3,
+ lambda: array_ops.identity(x),
+ lambda: math_ops.multiply(x, 2.0)),
+ lambda: control_flow_ops.cond(x < -3,
+ lambda: constant_op.constant(1.0),
+ lambda: math_ops.multiply(x, -1.0))))
+ # pylint: enable=g-long-lambda
+
def testStrippedOpListDef(self):
with self.test_session():
# Creates a graph.