Lazily evaluate shapes with the C API enabled.
authorSkye Wanderman-Milne <skyewm@google.com>
Fri, 6 Apr 2018 01:21:54 +0000 (18:21 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 01:24:30 +0000 (18:24 -0700)
This change makes it so shapes are computed only when requested with
_USE_C_API = True. Note that the C API will still raise a shape error
if necessary when the op is created.

In addition, it cleans up the logic for _USE_C_SHAPES = True. In this
case, we lazily fetch and cache shapes directly from the C API. We no
longer need set_shapes_for_outputs at all in this case.

PiperOrigin-RevId: 191830565

tensorflow/python/client/tf_session_helper.cc
tensorflow/python/client/tf_session_helper.h
tensorflow/python/framework/importer.py
tensorflow/python/framework/ops.py
tensorflow/python/framework/tensor_util.py
tensorflow/python/ops/resource_variable_ops.py

index b48d758..b6481e7 100644 (file)
@@ -629,15 +629,6 @@ void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
   TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
 }
 
-std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
-                                                    TF_Output output,
-                                                    int num_dims,
-                                                    TF_Status* status) {
-  std::vector<int64_t> dims(num_dims);
-  TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status);
-  return dims;
-}
-
 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
     TF_ImportGraphDefResults* results) {
   int num_missing_unused_input_mappings;
index d2b4abc..cfd27c2 100644 (file)
@@ -229,13 +229,6 @@ void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
                                     const std::vector<int64_t>& dims,
                                     bool unknown_shape, TF_Status* status);
 
-// Return the shape of output. `num_dims` should be the output of
-// TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called.
-std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
-                                                    TF_Output output,
-                                                    int num_dims,
-                                                    TF_Status* status);
-
 // Returns the string representations of the missing unused input mappings.
 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
     TF_ImportGraphDefResults* results);
index 8beb74d..3f8a8c4 100644 (file)
@@ -685,11 +685,10 @@ def import_graph_def(graph_def,
                      ', '.join(x.name for x in op._input_types))))
         # pylint: enable=protected-access
 
-        if not g._is_function(op.type):  # pylint: disable=protected-access
-          # Execute shape inference for this op.
-          # NOTE(mrry): If the graph contains a cycle, the full shape
-          # information may not be available for this op's inputs.
-          ops.set_shapes_for_outputs(op)
+        # Execute shape inference for this op.
+        # NOTE(mrry): If the graph contains a cycle, the full shape
+        # information may not be available for this op's inputs.
+        ops.set_shape_and_handle_data_for_outputs(op)
         # For nodes with _output_shapes set, set the output shapes.
         if '_output_shapes' in op.node_def.attr:
           for i, output in enumerate(op.outputs):
index 84366e2..2574fa5 100644 (file)
@@ -289,15 +289,26 @@ class Tensor(_TensorLike):
     self._op = op
     self._value_index = value_index
     self._dtype = dtypes.as_dtype(dtype)
-    self._shape_val = tensor_shape.unknown_shape()
+
+    if _USE_C_API:
+      # This will be set by set_shape_and_handle_data_for_outputs.
+      self._shape_val = None
+    else:
+      # The Python code requires all tensors start with a shape to support shape
+      # inference on imported while loops. This isn't necessary with the C API
+      # enabled because the C API provides the shapes for imported nodes.
+      # TODO(skyewm): remove when _USE_C_API is removed.
+      self._shape_val = tensor_shape.unknown_shape()
+
     # List of operations that use this Tensor as input.  We maintain this list
     # to easily navigate a computation graph.
     self._consumers = []
 
-    # Attributes used for C++ shape inference. Not inspected, only forwarded.
-    # If set, will be a HandleData object from cpp_shape_inference.proto.
-    # TODO(b/74620627): remove when _USE_C_SHAPES is removed
-    self._handle_data = None
+    if not _USE_C_SHAPES:
+      # Attributes used for C++ shape inference. Not inspected, only forwarded.
+      # If set, will be a HandleData object from cpp_shape_inference.proto.
+      self._handle_data = None
+
     self._id = uid()
 
   @property
@@ -371,18 +382,45 @@ class Tensor(_TensorLike):
       A `TensorShape` representing the shape of this tensor.
 
     """
-    graph = self._op._graph._c_graph # pylint: disable=protected-access
-    if graph and _USE_C_SHAPES:
-      num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
-      if num_dims == -1:
-        dim_list = None
+    if self._shape_val is None:
+      if _USE_C_SHAPES:
+        self._shape_val = self._c_api_shape()
       else:
-        dim_list = c_api.TF_GraphGetTensorShape_wrapper(
-            graph, self._as_tf_output(), num_dims)
-        dim_list = [None if i == -1 else i for i in dim_list]
-      return tensor_shape.TensorShape(dim_list)
+        assert _USE_C_API
+        # Call set_shape_and_handle_data_for_outputs in topological order on all
+        # ops that are needed to compute self.op's shape. We do this instead of
+        # having set_shape_and_handle_data_for_outputs recursively call
+        # Operation.shape on self.op.inputs to overflowing the call stack.
+        need_shapes = self._get_input_ops_without_shapes(self.op)
+        need_shapes.sort(key=lambda op: op._id)
+        for op in need_shapes:
+          set_shape_and_handle_data_for_outputs(op)
     return self._shape_val
 
+  def _get_input_ops_without_shapes(self, target_op):
+    """Returns ops needing shape inference to compute target_op's shape."""
+    result = []
+    stack = [self._op]
+    visited = set()
+    while stack:
+      op = stack.pop()
+      if op in visited: continue
+      result.append(op)
+      stack.extend(t.op for t in op.inputs if t._shape_val is None)
+      visited.add(op)
+    return result
+
+  def _c_api_shape(self):
+    """Returns the TensorShape of this tensor according to the C API."""
+    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
+    shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+        c_graph, self._as_tf_output())
+    if unknown_shape:
+      return tensor_shape.unknown_shape()
+    else:
+      shape_vector = [None if d == -1 else d for d in shape_vector]
+      return tensor_shape.TensorShape(shape_vector)
+
   @property
   def _shape(self):
     logging.warning("Tensor._shape is private, use Tensor.shape "
@@ -466,8 +504,11 @@ class Tensor(_TensorLike):
       ValueError: If `shape` is not compatible with the current shape of
         this tensor.
     """
-    if not _USE_C_SHAPES:  # pylint: disable=protected-access
-      self._shape_val = self._shape_val.merge_with(shape)
+    if _USE_C_SHAPES:  # pylint: disable=protected-access
+      # Reset cached shape.
+      self._shape_val = None
+    else:
+      self._shape_val = self.shape.merge_with(shape)
 
     if not self._op._graph._c_graph: return
 
@@ -579,6 +620,16 @@ class Tensor(_TensorLike):
     # Necessary to support Python's collection membership operators
     return id(self) == id(other)
 
+  def __copy__(self):
+    # Make sure _shape_val is computed before we copy.
+    # TODO(b/77597810): get rid of Tensor copies.
+    if self._shape_val is None:
+      set_shape_and_handle_data_for_outputs(self.op)
+    cls = self.__class__
+    result = cls.__new__(cls)
+    result.__dict__.update(self.__dict__)
+    return result
+
   # NOTE(mrry): This enables the Tensor's overloaded "right" binary
   # operators to run when the left operand is an ndarray, because it
   # accords the Tensor class higher priority than an ndarray, or a
@@ -1932,6 +1983,13 @@ class Operation(object):
     if not isinstance(tensor, Tensor):
       raise TypeError("tensor must be a Tensor: %s" % tensor)
     _assert_same_graph(self, tensor)
+
+    # Make sure output shapes are already computed for this op in case we create
+    # a cycle (we cannot compute shapes for cycles). Usually shapes are computed
+    # lazily upon request.
+    if not _USE_C_SHAPES:
+      set_shape_and_handle_data_for_outputs(self)
+
     if self._c_op:
       # Reset cached inputs.
       self._inputs_val = None
@@ -2474,35 +2532,41 @@ class RegisterShape(object):
     return f
 
 
-def _set_shapes_for_outputs_c_api(op):
-  """set_shapes_for_outputs implementation when C API is enabled."""
-  # The C API computes the shapes when the TF_Operation is created. Fetch the
-  # output shapes from the C object.
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def _set_shape_and_handle_data_for_outputs_c_api(op):
+  """Set shapes and resource handle data using info from the C API."""
+  assert not _USE_C_SHAPES
   for output in op.outputs:
-    # pylint: disable=protected-access
-    shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+    output._shape_val = output._c_api_shape()
+    # Set the resource handle data for compatibility with the Python shape
+    # inference code.
+    serialized = c_api.ResourceHandleShapeAndType(
         op._graph._c_graph, output._as_tf_output())
-    # pylint: enable=protected-access
-    if unknown_shape:
-      output.set_shape(tensor_shape.unknown_shape())
-    elif not shape_vector:
-      output.set_shape(tensor_shape.scalar())
-    else:
-      shape_vector = [None if d == -1 else d for d in shape_vector]
-      output.set_shape(tensor_shape.TensorShape(shape_vector))
-
-    serialized = c_api.ResourceHandleShapeAndType(op._graph._c_graph,
-                                                  output._as_tf_output())
     if serialized:
       output._handle_data = (
-          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
-              compat.as_bytes(serialized)))
+          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
+          .FromString(compat.as_bytes(serialized)))
     else:
       output._handle_data = None
 
-# TODO(skyewm): remove this when _USE_C_API flag is removed.
-def _set_shapes_for_outputs(op):
-  """set_shapes_for_outputs implementation when C API is disabled."""
+
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def set_shape_and_handle_data_for_outputs(op):
+  """Set the shapes and resource handle data for op's outputs.
+
+  When _USE_C_API = True, this is lazily called when a tensor's shape is first
+  requested. Usually this should work automatically, but some edge cases may
+  require manaully calling this first to make sure Tensor._shape_val and
+  Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a
+  Tensor).
+  """
+  if _USE_C_SHAPES: return
+
+  if op.graph._is_function(op.type):
+    for output in op.outputs:
+      output._shape_val = tensor_shape.unknown_shape()
+    return
+
   try:
     shape_func = _shape_registry.lookup(op.type)
   except LookupError:
@@ -2521,8 +2585,10 @@ def _set_shapes_for_outputs(op):
     shapes = shapes_dict["shapes"]
     handle_datas = shapes_dict["handle_data"]
     for output, handle_data in zip(op.outputs, handle_datas):
+      # Don't override any existing handle data that may have been manually set.
       # pylint: disable=protected-access
-      output._handle_data = handle_data
+      if output._handle_data is None:
+        output._handle_data = handle_data
       # pylint: enable=protected-access
 
   if len(op.outputs) != len(shapes):
@@ -2530,15 +2596,8 @@ def _set_shapes_for_outputs(op):
         "Shape function for op %s returned %d shapes but expected %d %s %s" %
         (op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
   for output, s in zip(op.outputs, shapes):
-    output.set_shape(s)
-
-
-def set_shapes_for_outputs(op):
-  """Set the shapes for op's outputs."""
-  if op._c_op and _USE_C_SHAPES:  # pylint: disable=protected-access
-    return _set_shapes_for_outputs_c_api(op)
-  else:
-    return _set_shapes_for_outputs(op)
+    output._shape_val = tensor_shape.unknown_shape()
+    output._shape_val = output._shape_val.merge_with(s)
 
 
 class OpStats(object):
@@ -3331,18 +3390,14 @@ class Graph(object):
           original_op=self._default_original_op,
           op_def=op_def)
 
-      # TODO(vrv): Instead of eagerly filling in shape property for every op,
-      # only populate the shape when requested.
+      # Note: shapes are lazily computed with the C API enabled.
       #
       # TODO(skyewm): unlike in the original Python implementation, the C API
       # always computes shape information (even for function calls, which the
       # original Python shape inference code doesn't handle). Deprecate the
       # compute_shapes argument.
-      #
-      # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
-      # is removed
-      if (ret._c_op and _USE_C_SHAPES) or compute_shapes:  # pylint: disable=protected-access
-        set_shapes_for_outputs(ret)
+      if not _USE_C_API and compute_shapes:
+        set_shape_and_handle_data_for_outputs(ret)
 
       self._create_op_helper(ret, compute_shapes=compute_shapes,
                              compute_device=compute_device)
@@ -3484,18 +3539,17 @@ class Graph(object):
         for c_op in c_api_util.new_tf_operations(self)
     ]
 
+    # pylint: disable=protected-access
     for op in new_ops:
       # Operations created by the C API always retrieve shapes from the C API so
       # we preserve the shapes of ops created in import_graph_def (from the
       # "_output_shapes" attr of the imported NodeDef).
-      # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
-      # is removed.
-      _set_shapes_for_outputs_c_api(op)
+      if not _USE_C_SHAPES:
+        _set_shape_and_handle_data_for_outputs_c_api(op)
       new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
-      # pylint: disable=protected-access
       op._add_control_inputs(new_control_inputs)
       op._control_flow_post_processing()
-      # pylint: enable=protected-access
+    # pylint: enable=protected-access
 
     return new_ops
 
index 984bcec..64b0fa6 100644 (file)
@@ -22,7 +22,6 @@ import six
 
 from tensorflow.core.framework import tensor_pb2
 from tensorflow.core.framework import tensor_shape_pb2
-from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.util import compat
@@ -828,7 +827,7 @@ def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
   Returns:
     A `TensorShape` based on the constant value of the given `tensor`.
   """
-  if context.executing_eagerly():
+  if isinstance(tensor, ops.EagerTensor):
     return tensor_shape.as_shape(
         [dim if dim != -1 else None for dim in tensor.numpy()])
 
index 07e25e5..508ba9b 100644 (file)
@@ -72,7 +72,12 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
     # know the shape and dtype of the variable pointed to by a handle. Since
     # shape inference doesn't run in eager mode we copy this data here for when
     # the handle is captured by an eager mode function.
-    handle._handle_data = h._handle_data  # pylint: disable=protected-access
+    # pylint: disable=protected-access
+    if h._handle_data is None:
+      ops.set_shape_and_handle_data_for_outputs(h.op)
+    handle._handle_data = h._handle_data
+    # pylint: enable=protected-access
+
   # Clean up our reference cycles to avoid making the garbage collector run.
   # pylint: disable=protected-access
   # OrderedDict, constructed on Graph creation, makes a simple reference loop