From: Skye Wanderman-Milne Date: Tue, 20 Mar 2018 00:34:47 +0000 (-0700) Subject: Make _USE_C_API = True and_USE_C_SHAPES = False work with import_graph_def. X-Git-Tag: tflite-v0.1.7~145^2^2~42 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2714c07c93c2fd84480f816e0da44030a0a2bd45;p=platform%2Fupstream%2Ftensorflow.git Make _USE_C_API = True and_USE_C_SHAPES = False work with import_graph_def. Without this change, shapes wouldn't be correctly computed for operations created via import_graph_def. PiperOrigin-RevId: 189670312 --- diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 44ff440..6e2640e 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -62,8 +62,7 @@ from tensorflow.python.util import compat ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) -# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False -# @test_util.with_c_api +@test_util.with_c_api class SessionTest(test_util.TensorFlowTestCase): def setUp(self): diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 783e925..a9e399f 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -489,23 +489,25 @@ def import_graph_def(graph_def, # Convert to ValueError for backwards compatibility. raise ValueError(str(e)) - _ProcessNewOps(graph) + # Create _DefinedFunctions for any imported functions. + # + # We do this by creating _DefinedFunctions directly from `graph_def`, and + # adding them to `graph`. Adding an existing function to a TF_Graph is a + # no-op, so this only has the effect of updating the Python state (usually + # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). + # + # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph + # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph + # TODO(b/74620627): move this after _ProcessNewOps outside the lock once + # _USE_C_SHAPES is removed. + if graph_def.library and graph_def.library.function: + # pylint: disable=protected-access + functions = function._from_library(graph_def.library) + for f in functions: + f.add_to_graph(graph) + # pylint: enable=protected-access - # Create _DefinedFunctions for any imported functions. - # - # We do this by creating _DefinedFunctions directly from `graph_def`, and - # adding them to `graph`. Adding an existing function to a TF_Graph is a - # no-op, so this only has the effect of updating the Python state (usually - # _DefinedFunction.add_to_graph also adds the function to the TF_Graph). - # - # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph - # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph - if graph_def.library and graph_def.library.function: - # pylint: disable=protected-access - functions = function._from_library(graph_def.library) - for f in functions: - f.add_to_graph(graph) - # pylint: enable=protected-access + _ProcessNewOps(graph) # Treat input mappings that don't appear in the graph as an error, because # they are likely to be due to a typo. diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index c39191e..bf5d9fe 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import importer from tensorflow.python.framework import ops from tensorflow.python.framework import test_ops # pylint: disable=unused-import +from tensorflow.python.framework import test_util from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -43,8 +44,7 @@ import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test -# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False -# @test_util.with_c_api +@test_util.with_c_api class ImportGraphDefTest(test.TestCase): def _MakeGraphDef(self, diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index 06cec50..21963d0 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -285,8 +285,7 @@ class SimpleMetaGraphTest(test.TestCase): self.assertIs(global_vars[0], trainable_vars[0]) -# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False -# @test_util.with_c_api +@test_util.with_c_api class ScopedMetaGraphTest(test.TestCase): def _testScopedExport(self, test_dir, exported_filenames): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index f1cd341..4be2e2c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3303,6 +3303,20 @@ class Graph(object): input_types=input_types, original_op=self._default_original_op, op_def=op_def) + + # TODO(vrv): Instead of eagerly filling in shape property for every op, + # only populate the shape when requested. + # + # TODO(skyewm): unlike in the original Python implementation, the C API + # always computes shape information (even for function calls, which the + # original Python shape inference code doesn't handle). Deprecate the + # compute_shapes argument. + # + # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES + # is removed + if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access + set_shapes_for_outputs(ret) + self._create_op_helper(ret, compute_shapes=compute_shapes, compute_device=compute_device) return ret @@ -3336,15 +3350,6 @@ class Graph(object): def _create_op_helper(self, op, compute_shapes=True, compute_device=True): """Common logic for creating an op in this graph.""" - # TODO(vrv): Instead of eagerly filling in shape property for every op, only - # populate the shape when requested. - # - # TODO(skyewm): unlike in the original Python implementation, the C API - # always computes shape information (even for function calls, which the - # original Python shape inference code doesn't handle). Deprecate the - # compute_shapes argument. - if (op._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access - set_shapes_for_outputs(op) # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed. self._add_op(op) @@ -3449,6 +3454,12 @@ class Graph(object): ] for op in new_ops: + # The Python shape inference code does not support imported functions. It + # also needs access to op.inputs, which is why we call it here. + # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES + # is removed. + if not self._is_function(op.type) or _USE_C_SHAPES: + set_shapes_for_outputs(op) new_control_inputs = self._control_dependencies_for_inputs(op.inputs) # pylint: disable=protected-access op._add_control_inputs(new_control_inputs) diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 787582a..7de778f 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -1739,8 +1739,7 @@ class CheckpointStateTest(test.TestCase): os.path.join(save_dir, "./model.ckpt-687529")) -# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False -# @test_util.with_c_api +@test_util.with_c_api class MetaGraphTest(test.TestCase): def _get_test_dir(self, dirname):