Make _USE_C_API = True and_USE_C_SHAPES = False work with import_graph_def.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 20 Mar 2018 00:34:47 +0000 (17:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 20 Mar 2018 00:39:00 +0000 (17:39 -0700)
Without this change, shapes wouldn't be correctly computed for
operations created via import_graph_def.

PiperOrigin-RevId: 189670312

tensorflow/python/client/session_test.py
tensorflow/python/framework/importer.py
tensorflow/python/framework/importer_test.py
tensorflow/python/framework/meta_graph_test.py
tensorflow/python/framework/ops.py
tensorflow/python/training/saver_test.py

index 44ff440..6e2640e 100644 (file)
@@ -62,8 +62,7 @@ from tensorflow.python.util import compat
 ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
 
 
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
 class SessionTest(test_util.TensorFlowTestCase):
 
   def setUp(self):
index 783e925..a9e399f 100644 (file)
@@ -489,23 +489,25 @@ def import_graph_def(graph_def,
           # Convert to ValueError for backwards compatibility.
           raise ValueError(str(e))
 
-      _ProcessNewOps(graph)
+      # Create _DefinedFunctions for any imported functions.
+      #
+      # We do this by creating _DefinedFunctions directly from `graph_def`, and
+      # adding them to `graph`. Adding an existing function to a TF_Graph is a
+      # no-op, so this only has the effect of updating the Python state (usually
+      # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
+      #
+      # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
+      # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
+      # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
+      # _USE_C_SHAPES is removed.
+      if graph_def.library and graph_def.library.function:
+        # pylint: disable=protected-access
+        functions = function._from_library(graph_def.library)
+        for f in functions:
+          f.add_to_graph(graph)
+        # pylint: enable=protected-access
 
-    # Create _DefinedFunctions for any imported functions.
-    #
-    # We do this by creating _DefinedFunctions directly from `graph_def`, and
-    # adding them to `graph`. Adding an existing function to a TF_Graph is a
-    # no-op, so this only has the effect of updating the Python state (usually
-    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
-    #
-    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
-    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
-    if graph_def.library and graph_def.library.function:
-      # pylint: disable=protected-access
-      functions = function._from_library(graph_def.library)
-      for f in functions:
-        f.add_to_graph(graph)
-      # pylint: enable=protected-access
+      _ProcessNewOps(graph)
 
     # Treat input mappings that don't appear in the graph as an error, because
     # they are likely to be due to a typo.
index c39191e..bf5d9fe 100644 (file)
@@ -31,6 +31,7 @@ from tensorflow.python.framework import function
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_ops  # pylint: disable=unused-import
+from tensorflow.python.framework import test_util
 from tensorflow.python.framework import versions
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -43,8 +44,7 @@ import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
 
 
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
 class ImportGraphDefTest(test.TestCase):
 
   def _MakeGraphDef(self,
index 06cec50..21963d0 100644 (file)
@@ -285,8 +285,7 @@ class SimpleMetaGraphTest(test.TestCase):
       self.assertIs(global_vars[0], trainable_vars[0])
 
 
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
 class ScopedMetaGraphTest(test.TestCase):
 
   def _testScopedExport(self, test_dir, exported_filenames):
index f1cd341..4be2e2c 100644 (file)
@@ -3303,6 +3303,20 @@ class Graph(object):
           input_types=input_types,
           original_op=self._default_original_op,
           op_def=op_def)
+
+      # TODO(vrv): Instead of eagerly filling in shape property for every op,
+      # only populate the shape when requested.
+      #
+      # TODO(skyewm): unlike in the original Python implementation, the C API
+      # always computes shape information (even for function calls, which the
+      # original Python shape inference code doesn't handle). Deprecate the
+      # compute_shapes argument.
+      #
+      # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
+      # is removed
+      if (ret._c_op and _USE_C_SHAPES) or compute_shapes:  # pylint: disable=protected-access
+        set_shapes_for_outputs(ret)
+
       self._create_op_helper(ret, compute_shapes=compute_shapes,
                              compute_device=compute_device)
     return ret
@@ -3336,15 +3350,6 @@ class Graph(object):
 
   def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
     """Common logic for creating an op in this graph."""
-    # TODO(vrv): Instead of eagerly filling in shape property for every op, only
-    # populate the shape when requested.
-    #
-    # TODO(skyewm): unlike in the original Python implementation, the C API
-    # always computes shape information (even for function calls, which the
-    # original Python shape inference code doesn't handle). Deprecate the
-    # compute_shapes argument.
-    if (op._c_op and _USE_C_SHAPES) or compute_shapes:  # pylint: disable=protected-access
-      set_shapes_for_outputs(op)
     # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
     self._add_op(op)
 
@@ -3449,6 +3454,12 @@ class Graph(object):
     ]
 
     for op in new_ops:
+      # The Python shape inference code does not support imported functions. It
+      # also needs access to op.inputs, which is why we call it here.
+      # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
+      # is removed.
+      if not self._is_function(op.type) or _USE_C_SHAPES:
+        set_shapes_for_outputs(op)
       new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
       # pylint: disable=protected-access
       op._add_control_inputs(new_control_inputs)
index 787582a..7de778f 100644 (file)
@@ -1739,8 +1739,7 @@ class CheckpointStateTest(test.TestCase):
                      os.path.join(save_dir, "./model.ckpt-687529"))
 
 
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
 class MetaGraphTest(test.TestCase):
 
   def _get_test_dir(self, dirname):