ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape)
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class SessionTest(test_util.TensorFlowTestCase):
def setUp(self):
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
- _ProcessNewOps(graph)
+ # Create _DefinedFunctions for any imported functions.
+ #
+ # We do this by creating _DefinedFunctions directly from `graph_def`, and
+ # adding them to `graph`. Adding an existing function to a TF_Graph is a
+ # no-op, so this only has the effect of updating the Python state (usually
+ # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
+ #
+ # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
+ # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
+ # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
+ # _USE_C_SHAPES is removed.
+ if graph_def.library and graph_def.library.function:
+ # pylint: disable=protected-access
+ functions = function._from_library(graph_def.library)
+ for f in functions:
+ f.add_to_graph(graph)
+ # pylint: enable=protected-access
- # Create _DefinedFunctions for any imported functions.
- #
- # We do this by creating _DefinedFunctions directly from `graph_def`, and
- # adding them to `graph`. Adding an existing function to a TF_Graph is a
- # no-op, so this only has the effect of updating the Python state (usually
- # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
- #
- # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
- # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
- if graph_def.library and graph_def.library.function:
- # pylint: disable=protected-access
- functions = function._from_library(graph_def.library)
- for f in functions:
- f.add_to_graph(graph)
- # pylint: enable=protected-access
+ _ProcessNewOps(graph)
# Treat input mappings that don't appear in the graph as an error, because
# they are likely to be due to a typo.
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
+from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class ImportGraphDefTest(test.TestCase):
def _MakeGraphDef(self,
self.assertIs(global_vars[0], trainable_vars[0])
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class ScopedMetaGraphTest(test.TestCase):
def _testScopedExport(self, test_dir, exported_filenames):
input_types=input_types,
original_op=self._default_original_op,
op_def=op_def)
+
+ # TODO(vrv): Instead of eagerly filling in shape property for every op,
+ # only populate the shape when requested.
+ #
+ # TODO(skyewm): unlike in the original Python implementation, the C API
+ # always computes shape information (even for function calls, which the
+ # original Python shape inference code doesn't handle). Deprecate the
+ # compute_shapes argument.
+ #
+ # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
+ # is removed
+ if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
+ set_shapes_for_outputs(ret)
+
self._create_op_helper(ret, compute_shapes=compute_shapes,
compute_device=compute_device)
return ret
def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
"""Common logic for creating an op in this graph."""
- # TODO(vrv): Instead of eagerly filling in shape property for every op, only
- # populate the shape when requested.
- #
- # TODO(skyewm): unlike in the original Python implementation, the C API
- # always computes shape information (even for function calls, which the
- # original Python shape inference code doesn't handle). Deprecate the
- # compute_shapes argument.
- if (op._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
- set_shapes_for_outputs(op)
# TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
self._add_op(op)
]
for op in new_ops:
+ # The Python shape inference code does not support imported functions. It
+ # also needs access to op.inputs, which is why we call it here.
+ # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
+ # is removed.
+ if not self._is_function(op.type) or _USE_C_SHAPES:
+ set_shapes_for_outputs(op)
new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
# pylint: disable=protected-access
op._add_control_inputs(new_control_inputs)
os.path.join(save_dir, "./model.ckpt-687529"))
-# TODO(skyewm): reenable when this works with _USE_C_SHAPES=False
-# @test_util.with_c_api
+@test_util.with_c_api
class MetaGraphTest(test.TestCase):
def _get_test_dir(self, dirname):