Remove C API staging from importer.py.
authorSkye Wanderman-Milne <skyewm@google.com>
Thu, 17 May 2018 18:33:36 +0000 (11:33 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 17 May 2018 18:36:15 +0000 (11:36 -0700)
PiperOrigin-RevId: 197024708

tensorflow/python/framework/importer.py

index 5112bea..72eb7e0 100644 (file)
@@ -17,78 +17,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import contextlib
-import copy
 
-from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import graph_pb2
-from tensorflow.core.framework import types_pb2
 from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import device as pydev
-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.util import compat
 from tensorflow.python.util.deprecation import deprecated_args
 from tensorflow.python.util.tf_export import tf_export
 
 
-# TODO(josh11b): SWIG the code from node_def_util instead of duplicating
-# the logic here.
-def _GetNodeAttr(node_def, attr_name):
-  if attr_name not in node_def.attr:
-    raise ValueError('Expected one attr with name %r in %s.' % (attr_name,
-                                                                str(node_def)))
-  return node_def.attr[attr_name]
-
-
-def _ArgToTypesNoRef(node_def, arg_def):
-  if arg_def.number_attr:
-    repeats = _GetNodeAttr(node_def, arg_def.number_attr).i
-    if arg_def.type_attr:
-      dtype = _GetNodeAttr(node_def, arg_def.type_attr).type
-    else:
-      assert arg_def.type != types_pb2.DT_INVALID
-      dtype = arg_def.type
-    return [dtype] * repeats
-  elif arg_def.type_attr:
-    return [_GetNodeAttr(node_def, arg_def.type_attr).type]
-  elif arg_def.type_list_attr:
-    return _GetNodeAttr(node_def, arg_def.type_list_attr).list.type
-  else:
-    assert arg_def.type != types_pb2.DT_INVALID
-    return [arg_def.type]
-
-
-def _SingleArgToTypes(node_def, arg_def):
-  types = _ArgToTypesNoRef(node_def, arg_def)
-  if arg_def.is_ref:
-    return [dtypes.as_dtype(dt)._as_ref.as_datatype_enum for dt in types]  # pylint: disable=protected-access
-  return types
-
-
-def _ArgsToTypes(node_def, arg_list):
-  types = []
-  for arg_def in arg_list:
-    types.extend(_SingleArgToTypes(node_def, arg_def))
-  return types
-
-
-def _InputTypes(node_def, op_dict):
-  op_def = op_dict[node_def.op]
-  return _ArgsToTypes(node_def, op_def.input_arg)
-
-
-def _OutputTypes(node_def, op_dict):
-  op_def = op_dict[node_def.op]
-  return _ArgsToTypes(node_def, op_def.output_arg)
-
-
 def _IsControlInput(input_name):
   # Expected format: '^operation_name' (control input).
   return input_name.startswith('^')
@@ -128,18 +71,6 @@ def _ParseTensorName(tensor_name):
     raise ValueError('Cannot convert %r to a tensor name.' % (tensor_name,))
 
 
-def _CanonicalInputName(input_name):
-  input_name = compat.as_str(input_name)
-  if _IsControlInput(input_name):
-    return input_name
-  input_op_name, output_index = _ParseTensorName(input_name)
-  return '%s:%d' % (input_op_name, output_index)
-
-
-def _InvalidNodeMessage(node, message):
-  return 'graph_def is invalid at node %r: %s.' % (node.name, message)
-
-
 @contextlib.contextmanager
 def _MaybeDevice(device):
   """Applies the given device only if device is not None or empty."""
@@ -460,351 +391,70 @@ def import_graph_def(graph_def,
     _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
 
   graph = ops.get_default_graph()
-
-  if graph._c_graph:  # pylint: disable=protected-access
-    with ops.name_scope(name, 'import', input_map.values()) as scope:
-      # Save unique prefix generated by name_scope
-      if scope:
-        assert scope.endswith('/')
-        prefix = scope[:-1]
-      else:
-        prefix = ''
-
-      # Generate any input map tensors inside name scope
-      input_map = _ConvertInputMapValues(name, input_map)
-
-    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
-    options = scoped_options.options
-    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
-                                     return_elements)
-
-    # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
-    # call cannot occur between creating the TF_Operations in the
-    # TF_GraphImportGraphDefWithResults call and mutating the them in
-    # _ProcessNewOps.
-    with graph._lock:  # pylint: disable=protected-access
-      with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
-        try:
-          results = c_api.TF_GraphImportGraphDefWithResults(
-              graph._c_graph, serialized, options)  # pylint: disable=protected-access
-          results = c_api_util.ScopedTFImportGraphDefResults(results)
-        except errors.InvalidArgumentError as e:
-          # Convert to ValueError for backwards compatibility.
-          raise ValueError(str(e))
-
-      # Create _DefinedFunctions for any imported functions.
-      #
-      # We do this by creating _DefinedFunctions directly from `graph_def`, and
-      # adding them to `graph`. Adding an existing function to a TF_Graph is a
-      # no-op, so this only has the effect of updating the Python state (usually
-      # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
-      #
-      # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
-      # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
-      # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
-      # _USE_C_SHAPES is removed.
-      if graph_def.library and graph_def.library.function:
-        # pylint: disable=protected-access
-        functions = function._from_library(graph_def.library)
-        for f in functions:
-          f.add_to_graph(graph)
-        # pylint: enable=protected-access
-
-      _ProcessNewOps(graph)
-
-    # Treat input mappings that don't appear in the graph as an error, because
-    # they are likely to be due to a typo.
-    missing_unused_input_keys = (
-        c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
-            results.results))
-    if missing_unused_input_keys:
-      missing_unused_input_keys = [
-          compat.as_str(s) for s in missing_unused_input_keys
-      ]
-      raise ValueError(
-          'Attempted to map inputs that were not found in graph_def: [%s]' %
-          ', '.join(missing_unused_input_keys))
-
-    if return_elements is None:
-      return None
+  with ops.name_scope(name, 'import', input_map.values()) as scope:
+    # Save unique prefix generated by name_scope
+    if scope:
+      assert scope.endswith('/')
+      prefix = scope[:-1]
     else:
-      return _GatherReturnElements(return_elements, graph, results.results)
-
-  else:
-    g = graph
-
-    # Use a canonical representation for all tensor names.
-    input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
-    used_input_keys = set()
-    name_to_op = {}
-
-    # Add any functions defined in `graph_def` to `g`
+      prefix = ''
+
+    # Generate any input map tensors inside name scope
+    input_map = _ConvertInputMapValues(name, input_map)
+
+  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
+  options = scoped_options.options
+  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
+                                   return_elements)
+
+  # _ProcessNewOps mutates the new operations. _lock ensures a Session.run
+  # call cannot occur between creating the TF_Operations in the
+  # TF_GraphImportGraphDefWithResults call and mutating the them in
+  # _ProcessNewOps.
+  with graph._lock:  # pylint: disable=protected-access
+    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
+      try:
+        results = c_api.TF_GraphImportGraphDefWithResults(
+            graph._c_graph, serialized, options)  # pylint: disable=protected-access
+        results = c_api_util.ScopedTFImportGraphDefResults(results)
+      except errors.InvalidArgumentError as e:
+        # Convert to ValueError for backwards compatibility.
+        raise ValueError(str(e))
+
+    # Create _DefinedFunctions for any imported functions.
+    #
+    # We do this by creating _DefinedFunctions directly from `graph_def`, and
+    # adding them to `graph`. Adding an existing function to a TF_Graph is a
+    # no-op, so this only has the effect of updating the Python state (usually
+    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
+    #
+    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
+    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
+    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
+    # _USE_C_SHAPES is removed.
     if graph_def.library and graph_def.library.function:
-      # Copy op_dict so we don't clobber the original
-      op_dict = copy.copy(op_dict)
       # pylint: disable=protected-access
-      # Note that we do not prepend `name` to the function name. The reasoning
-      # is that function names are similar to op definition names, which
-      # currently do not have a scoped name or namespace scheme.
       functions = function._from_library(graph_def.library)
       for f in functions:
-        f.add_to_graph(g)
-        op_dict[f.name] = f.definition.signature
+        f.add_to_graph(graph)
       # pylint: enable=protected-access
 
-    # LINT.IfChange
-    with ops.name_scope(name, 'import', input_map.values()) as scope:
-      # TODO(ashankar): Should this just copy over or should it do some
-      # more nuanced merging? For example, the graph may already have some
-      # marked "bad versions" and we don't want to lose those because of
-      # what's in graph_def.versions? The C++ ImporGraphDef does something
-      # more nuanced.
-      g.graph_def_versions.CopyFrom(graph_def.versions)
-
-      input_map = _ConvertInputMapValues(name, input_map)
-
-      # NOTE(mrry): We do this in two passes, because there may be a cycle in
-      # `graph_def`.
-
-      # 1. Add operations without their inputs.
-      for node in graph_def.node:
-        # Check to see if this op's name matches a previously seen op
-        if node.name in name_to_op:
-          raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
-        if node.op not in op_dict:
-          raise ValueError(
-              'No op named %s in defined operations. If the Graph you are '
-              'importing uses custom ops or any parts of tf.contrib, you '
-              'should explicitly import the libraries defining those ops '
-              'before loading the Graph. Note that tf.contrib is lazily loaded '
-              'when accessed, so simply referencing (e.g.) '
-              '`tf.contrib.resampler` will cause those ops to be made '
-              'available.' % node.op)
-        op_def = op_dict[node.op]
-
-        output_types = _OutputTypes(node, op_dict)
-        name_to_op[node.name] = g.create_op(
-            node.op, [], output_types, name=node.name, attrs=node.attr,
-            compute_shapes=False, compute_device=False,
-            op_def=op_def)
-
-      # Maps from a node to the ops it is colocated with, if colocation
-      # is specified in the attributes.
-      colocation_pairs = collections.defaultdict(list)
-
-      # 2. Add inputs to the operations.
-      for node in graph_def.node:
-        op = name_to_op[node.name]
-        input_types = _InputTypes(node, op_dict)
-        apply_device_function = True
-
-        # Rewrite the colocation attributes in the graph, since the
-        # names of new ops may have changed.
-        for key, value in op.node_def.attr.items():
-          if key == '_class':
-            class_values = value.list
-            new_class_values = []
-            for class_value in class_values.s:
-              if class_value.startswith(b'loc:@'):
-                op_to_bind_to = class_value[5:].decode()
-                # Find the op by its original name.
-                if op_to_bind_to not in name_to_op:
-                  raise ValueError('Specified colocation to an op that '
-                                   'does not exist during import: %s in %s' % (
-                                       op_to_bind_to, node.name))
-                original_op = name_to_op[op_to_bind_to]
-                new_class_values.append(compat.as_bytes(
-                    'loc:@' + original_op.name))
-                if op_to_bind_to != node.name:
-                  # Keep track of this mapping for a later phase.
-                  colocation_pairs[op].append(original_op)
-                  # Don't apply this op's device function,
-                  # the colocation constraint will ensure
-                  # the proper device gets assigned at runtime.
-                  apply_device_function = False
-
-              else:
-                new_class_values.append(class_value)
-            value.list.CopyFrom(attr_value_pb2.AttrValue.ListValue(
-                s=new_class_values))
-
-        # NOTE(mrry): We cannot use zip here because control inputs do not
-        # appear in the list of input_types.
-        for i, input_name in enumerate(
-            [_CanonicalInputName(x) for x in node.input]):
-
-          if _IsControlInput(input_name):
-            # (a) Input is a control input that should be taken from an op
-            #     in "graph_def".
-            try:
-              source_op = name_to_op[input_name[1:]]
-            except KeyError:
-              raise ValueError(
-                  _InvalidNodeMessage(
-                      node,
-                      'Control input %r not found in graph_def.'
-                      % (input_name,)))
-            # pylint: disable=protected-access
-            op._add_control_input(source_op)
-            # pylint: enable=protected-access
-
-          else:
-            try:
-              input_type = input_types[i]
-            except IndexError:
-              raise ValueError(_InvalidNodeMessage(
-                  node, 'More inputs specified (%r) than the op expects.'
-                  % (input_name,)))
-
-            if input_name in input_map:
-              # (b) Input should be replaced by a tensor from the caller.
-              source_tensor = input_map[input_name]
-              used_input_keys.add(input_name)
-
-            else:
-              # (c) Input should be taken from an op in `graph_def`.
-              operation_name, output_index = _ParseTensorName(input_name)
-              try:
-                source_op = name_to_op[operation_name]
-                source_tensor = list(source_op.values())[output_index]
-              except (KeyError, IndexError):
-                raise ValueError(
-                    _InvalidNodeMessage(
-                        node,
-                        'Input tensor %r not found in graph_def.'
-                        % (input_name,)))
-
-            try:
-              # pylint: disable=protected-access
-              op._add_input(source_tensor, dtype=input_type)
-              # pylint: enable=protected-access
-            except TypeError as te:
-              raise ValueError(_InvalidNodeMessage(
-                  node, 'Input tensor %r %s' % (input_name, te)))
-
-        # pylint: disable=protected-access
-        if op._input_types != input_types:
-          raise ValueError(
-              _InvalidNodeMessage(
-                  node,
-                  'Input types mismatch (expected %r but got %r)'
-                  % (', '.join(dtypes.as_dtype(x).name for x in input_types),
-                     ', '.join(x.name for x in op._input_types))))
-        # pylint: enable=protected-access
-
-        # Execute shape inference for this op.
-        # NOTE(mrry): If the graph contains a cycle, the full shape
-        # information may not be available for this op's inputs.
-        ops.set_shape_and_handle_data_for_outputs(op)
-        # For nodes with _output_shapes set, set the output shapes.
-        if '_output_shapes' in op.node_def.attr:
-          for i, output in enumerate(op.outputs):
-            dims = op.node_def.attr['_output_shapes'].list.shape[i]
-            output_shape = tensor_shape.TensorShape(
-                None if dims.unknown_rank else
-                [dim.size if dim.size >= 0 else None for dim in dims.dim])
-
-            try:
-              output.set_shape(output_shape)
-            except ValueError as e:
-              # If the output shape is incompatible with what is inferred
-              # by the graph for a very specific whitelist of ops, then we
-              # ignore this output shape.  This can happen if there is a
-              # bug in the shape function for some operation, and the
-              # serialized graph def has the incorrect shape set when
-              # running on a newer binary with the fixed shape function.
-              # This is an escape hatch that allows us to correct shape
-              # functions that are not critical to correct execution but
-              # would cause graphs to fail if imported after correcting.
-              #
-              # This can be removed after 2017/03/08.
-              if op.type in ['RandomShuffleQueue', 'PaddingFIFOQueue',
-                             'FIFOQueue', 'PriorityQueue', 'QueueSize',
-                             'Stack', 'Barrier', 'BarrierReadySize',
-                             'BarrierIncompleteSize', 'HashTable',
-                             'MutableHashTable',
-                             'MutableHashTableOfTensors', 'Mutex',
-                             'CuckooTable', 'IndexTable',
-                             'WholeFileReader', 'TextLineReader',
-                             'FixedLengthRecordReader',
-                             'TFRecordReader', 'IdentityReader',
-                             'LMDBReader',
-                             'RefSwitch', 'RefEnter', 'RefNextIteration',
-                             'RefMerge', 'RefIdentity']:
-                pass
-              elif op.type in [
-                  'ConditionalAccumulator', 'SparseConditionalAccumulator',
-                  'Table'
-              ]:
-                # This can be removed after 2017/04/24.
-                pass
-              else:
-                raise e
-
-          del op.node_def.attr['_output_shapes']
-
-        # NOTE(mrry): We do this after configuring the inputs, because
-        # the result of the device functions may depend on the inputs.
-        if apply_device_function:
-          with _MaybeDevice(node.device):
-            g._apply_device_functions(op)  # pylint: disable=protected-access
-
-      # The following loop populates the device field of ops that are
-      # colocated with another op.  This is implied by the colocation
-      # attribute, but we propagate the device field for completeness.
-      for op, coloc_op_list in colocation_pairs.items():
-        coloc_device = None
-        # Find any device in the list of colocated ops that have a
-        # device, if it exists.  We assume that if multiple ops
-        # have devices, they refer to the same device.  Otherwise, a
-        # runtime error will occur since the colocation property
-        # cannot be guaranteed.
-        #
-        # One possible improvement is to try to check for compatibility
-        # of all devices in this list at import time here, which would
-        # require implementing a compatibility function for device specs
-        # in python.
-        for coloc_op in coloc_op_list:
-          if coloc_op.device:
-            coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
-            break
-        if coloc_device:
-          op._set_device(coloc_device)  # pylint: disable=protected-access
-
-      # Treat input mappings that don't appear in the graph as an error,
-      # because they are likely to be due to a typo.
-      def _IsImportedNodeOutput(tensor_name):
-        operation_name, output_index = _ParseTensorName(tensor_name)
-        try:
-          return output_index < len(name_to_op[operation_name].outputs)
-        except KeyError:
-          return False
-      absent_input_keys = [
-          k for k in frozenset(input_map.keys()).difference(used_input_keys)
-          if not _IsImportedNodeOutput(k)]
-      if absent_input_keys:
-        raise ValueError(
-            'Attempted to map inputs that were not found in graph_def: [%s]'
-            % ', '.join(absent_input_keys))
-
-      if return_elements is None:
-        return None
-      else:
-        ret = []
-        for name in return_elements:
-          name = compat.as_str(name)
-          if ':' in name:
-            try:
-              operation_name, output_index = _ParseTensorName(name)
-              ret.append(name_to_op[operation_name].outputs[output_index])
-            except (ValueError, KeyError, IndexError):
-              raise ValueError(
-                  'Requested return_element %r not found in graph_def.' % name)
-          else:
-            try:
-              ret.append(name_to_op[name])
-            except KeyError:
-              raise ValueError(
-                  'Requested return_element %r not found in graph_def.' % name)
-        return ret
-    # LINT.ThenChange(//tensorflow/core/graph/graph_constructor.cc)
+    _ProcessNewOps(graph)
+
+  # Treat input mappings that don't appear in the graph as an error, because
+  # they are likely to be due to a typo.
+  missing_unused_input_keys = (
+      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
+          results.results))
+  if missing_unused_input_keys:
+    missing_unused_input_keys = [
+        compat.as_str(s) for s in missing_unused_input_keys
+    ]
+    raise ValueError(
+        'Attempted to map inputs that were not found in graph_def: [%s]' %
+        ', '.join(missing_unused_input_keys))
+
+  if return_elements is None:
+    return None
+  else:
+    return _GatherReturnElements(return_elements, graph, results.results)