Fix bug in importing MetaGraphDefs containing nested conds.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 6 Mar 2018 22:43:10 +0000 (14:43 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 22:48:47 +0000 (14:48 -0800)
This change makes CondContext._external_values more consistently store
Tensors external this context. These values are then not added to the
context when it's imported. This also removes the workaround I added
earlier to manually remove the predicate and pivot Tensors from the
context, instead adding them to _external_values were they're
automatically excluded.

PiperOrigin-RevId: 188083780

tensorflow/contrib/testing/python/framework/fake_summary_writer.py
tensorflow/python/ops/control_flow_ops.py
tensorflow/python/training/saver_test.py

index f2065c6..15a415d 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import summary_pb2
+from tensorflow.python.framework import test_util
 from tensorflow.python.summary.writer import writer
 from tensorflow.python.summary.writer import writer_cache
 
@@ -85,7 +86,11 @@ class FakeSummaryWriter(object):
     if expected_added_graphs is not None:
       test_case.assertEqual(expected_added_graphs, self._added_graphs)
     if expected_added_meta_graphs is not None:
-      test_case.assertEqual(expected_added_meta_graphs, self._added_meta_graphs)
+      test_case.assertEqual(len(expected_added_meta_graphs),
+                            len(self._added_meta_graphs))
+      for expected, actual in zip(expected_added_meta_graphs,
+                                  self._added_meta_graphs):
+        test_util.assert_meta_graph_protos_equal(test_case, expected, actual)
     if expected_session_logs is not None:
       test_case.assertEqual(expected_session_logs, self._added_session_logs)
 
index 689f7cd..1fa25a0 100644 (file)
@@ -1499,9 +1499,11 @@ class ControlFlowContext(object):
     if values_def:
       self._init_values_from_proto(values_def, import_scope=import_scope)
     else:
-      # Values that have been already seen in this context.
+      # The names of tensors that have been already seen in this context.
       self._values = set()
-      # Values referenced by but external to this context.
+      # The keys are the names of tensors referenced by but external to this
+      # context. Each value is the Tensor that should be used by this context to
+      # access the key value (e.g. a switch output guarding a cond input value).
       self._external_values = {}
 
   def _init_values_from_proto(self, values_def, import_scope=None):
@@ -1688,9 +1690,12 @@ class CondContext(ControlFlowContext):
       self._pivot = pivot  # The predicate tensor in this branch
       self._branch = branch  # 0 or 1 representing this branch
 
-      # Values considered to have been already seen in this context.
+      # Values considered to have been already seen in this context. They are
+      # not included in this context.
       self._values.add(pred.name)
+      self._external_values[pred.name] = pred
       self._values.add(pivot.name)
+      self._external_values[pivot.name] = pivot
 
   def _init_from_proto(self, context_def, import_scope=None):
     """Creates a new `CondContext` from protocol buffer.
@@ -1710,13 +1715,6 @@ class CondContext(ControlFlowContext):
     self._branch = context_def.branch
     super(CondContext, self).__init__(values_def=context_def.values_def,
                                       import_scope=import_scope)
-    # The predicate and pivot ops appear in self._values, but don't have self
-    # set as their control context. The __init__ call above will set self for
-    # all values, so manually override the predicate and pivot contexts here.
-    # pylint: disable=protected-access
-    self._pred.op._set_control_flow_context(self.outer_context)
-    self._pivot.op._set_control_flow_context(self.outer_context)
-    # pylint: enable=protected-access
 
   @property
   def pred(self):
@@ -1800,6 +1798,7 @@ class CondContext(ControlFlowContext):
       if self._outer_context:
         result = self._outer_context.AddValue(val)
         self._values.add(result.name)
+        self._external_values[result.name] = result
       with ops.control_dependencies(None):
         result = _SwitchRefOrTensor(result, self._pred)[self._branch]
         if self._outer_context:
@@ -1864,6 +1863,7 @@ class CondContext(ControlFlowContext):
       if self._outer_context:
         real_val = self._outer_context.AddValue(val)
         self._values.add(real_val.name)
+        self._external_values[real_val.name] = real_val
       real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
       self._external_values[val.name] = real_val
     else:
index 7947765..4fd3b58 100644 (file)
@@ -2059,20 +2059,25 @@ class MetaGraphTest(test.TestCase):
     self._testGraphExtensionRestore(test_dir)
     self._testRestoreFromTrainGraphWithControlContext(test_dir)
 
-  def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
-    # Build a while loop with `outer_body_fn`, export it, and verify that it can
-    # be imported and the gradient can be built and run correctly.
+  def _testGradientSerDes(self, graph_fn):
+    """Tests that gradients can be computed after exporting and importing.
+
+    Builds a graph, exports it, and verifies that it can be imported and the
+    gradient can be built and run correctly.
 
+    Args:
+      graph_fn: takes a single float Tensor argument as input, outputs a single
+        Tensor
+    """
     test_dir = self._get_test_dir("nested_control_flow")
     filename = os.path.join(test_dir, "metafile")
     saver_ckpt = os.path.join(test_dir, "saver.ckpt")
 
     # Create while loop using `outer_body_fn`.
     with ops_lib.Graph().as_default():
-      var = variables.Variable(0)
+      var = variables.Variable(0.0)
       var_name = var.name
-      _, output = control_flow_ops.while_loop(lambda i, x: i < 5, outer_body_fn,
-                                              [0, var])
+      output = graph_fn(var)
       output_name = output.name
       init_op = variables.global_variables_initializer()
 
@@ -2109,12 +2114,21 @@ class MetaGraphTest(test.TestCase):
         actual_grad_value = sess.run(grad)
         self.assertEqual(expected_grad_value, actual_grad_value)
 
+  def _testWhileLoopAndGradientSerDes(self, outer_body_fn):
+    # Build a while loop with `outer_body_fn`, export it, and verify that it can
+    # be imported and the gradient can be built and run correctly.
+    # pylint: disable=g-long-lambda
+    return self._testGradientSerDes(
+        lambda x: control_flow_ops.while_loop(
+            lambda i, y: i < 5, outer_body_fn, [0, x])[1])
+    # pylint: enable=g-long-lambda
+
   def testNestedWhileLoopsSerDes(self):
     # Test two simple nested while loops.
     def body(i, x):
       _, r = control_flow_ops.while_loop(lambda j, y: j < 3,
                                          lambda j, y: (j + 1, y + x),
-                                         [0, 0])
+                                         [0, 0.0])
       return i + 1, x + r
     self._testWhileLoopAndGradientSerDes(body)
 
@@ -2127,12 +2141,25 @@ class MetaGraphTest(test.TestCase):
           lambda: control_flow_ops.while_loop(
               lambda j, y: j < 3,
               lambda j, y: (j + 1, y + x),
-              [0, 0])[1],
+              [0, 0.0])[1],
           lambda: x)
       return i + 1, cond_result
     # pylint: enable=g-long-lambda
     self._testWhileLoopAndGradientSerDes(body)
 
+  def testNestedCondsSerDes(self):
+    # Test conds in a cond.
+    # pylint: disable=g-long-lambda
+    self._testGradientSerDes(lambda x: control_flow_ops.cond(
+        x > 0,
+        lambda: control_flow_ops.cond(x > 3,
+                                      lambda: array_ops.identity(x),
+                                      lambda: math_ops.multiply(x, 2.0)),
+        lambda: control_flow_ops.cond(x < -3,
+                                      lambda: constant_op.constant(1.0),
+                                      lambda: math_ops.multiply(x, -1.0))))
+    # pylint: enable=g-long-lambda
+
   def testStrippedOpListDef(self):
     with self.test_session():
       # Creates a graph.