Fetch C shapes for ops created by import_graph_def with C API enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Thu, 22 Mar 2018 21:09:59 +0000 (14:09 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 21:12:36 +0000 (14:12 -0700)
If _USE_C_API = True, this change makes us always fetch shapes using
the C API after calling TF_ImportGraphDef, even if _USE_C_SHAPES =
False. This is necessary to preserve the shapes specified by the
"_output_shapes" attr on imported NodeDefs (note that this attr isn't
present on the NodeDefs of the imported nodes, so there's no other way
to recover this information after calling TF_ImportGraphDef).

PiperOrigin-RevId: 190122991

tensorflow/python/framework/ops.py

index de222e1..93edaa0 100644 (file)
@@ -3455,12 +3455,12 @@ class Graph(object):
     ]
 
     for op in new_ops:
-      # The Python shape inference code does not support imported functions. It
-      # also needs access to op.inputs, which is why we call it here.
+      # Operations created by the C API always retrieve shapes from the C API so
+      # we preserve the shapes of ops created in import_graph_def (from the
+      # "_output_shapes" attr of the imported NodeDef).
       # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
       # is removed.
-      if not self._is_function(op.type) or _USE_C_SHAPES:
-        set_shapes_for_outputs(op)
+      _set_shapes_for_outputs_c_api(op)
       new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
       # pylint: disable=protected-access
       op._add_control_inputs(new_control_inputs)