From: Skye Wanderman-Milne
Date: Fri, 6 Apr 2018 01:21:54 +0000 (-0700)
Subject: Lazily evaluate shapes with the C API enabled.
X-Git-Tag: tflite-v0.1.7~16^2^2~126
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9fc9f19428e497f3a297538059804f69996a612e;p=platform%2Fupstream%2Ftensorflow.git

Lazily evaluate shapes with the C API enabled.

This change makes it so shapes are computed only when requested with
_USE_C_API = True. Note that the C API will still raise a shape error if
necessary when the op is created.

In addition, it cleans up the logic for _USE_C_SHAPES = True. In this case, we
lazily fetch and cache shapes directly from the C API. We no longer need
set_shapes_for_outputs at all in this case.

PiperOrigin-RevId: 191830565
---

diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index b48d758..b6481e7 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -629,15 +629,6 @@ void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
   TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
 }
 
-std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
-                                                    TF_Output output,
-                                                    int num_dims,
-                                                    TF_Status* status) {
-  std::vector<int64_t> dims(num_dims);
-  TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status);
-  return dims;
-}
-
 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
     TF_ImportGraphDefResults* results) {
   int num_missing_unused_input_mappings;
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index d2b4abc..cfd27c2 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -229,13 +229,6 @@ void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
                                     const std::vector<int64_t>& dims,
                                     bool unknown_shape, TF_Status* status);
 
-// Return the shape of output. `num_dims` should be the output of
-// TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called.
-std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
-                                                    TF_Output output,
-                                                    int num_dims,
-                                                    TF_Status* status);
-
 // Returns the string representations of the missing unused input mappings.
 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
     TF_ImportGraphDefResults* results);
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 8beb74d..3f8a8c4 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -685,11 +685,10 @@ def import_graph_def(graph_def,
                 ', '.join(x.name for x in op._input_types))))
       # pylint: enable=protected-access
 
-      if not g._is_function(op.type):  # pylint: disable=protected-access
-        # Execute shape inference for this op.
-        # NOTE(mrry): If the graph contains a cycle, the full shape
-        # information may not be available for this op's inputs.
-        ops.set_shapes_for_outputs(op)
+      # Execute shape inference for this op.
+      # NOTE(mrry): If the graph contains a cycle, the full shape
+      # information may not be available for this op's inputs.
+      ops.set_shape_and_handle_data_for_outputs(op)
 
       # For nodes with _output_shapes set, set the output shapes.
       if '_output_shapes' in op.node_def.attr:
         for i, output in enumerate(op.outputs):
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 84366e2..2574fa5 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -289,15 +289,26 @@ class Tensor(_TensorLike):
     self._op = op
     self._value_index = value_index
     self._dtype = dtypes.as_dtype(dtype)
-    self._shape_val = tensor_shape.unknown_shape()
+
+    if _USE_C_API:
+      # This will be set by set_shape_and_handle_data_for_outputs.
+      self._shape_val = None
+    else:
+      # The Python code requires all tensors start with a shape to support shape
+      # inference on imported while loops. This isn't necessary with the C API
+      # enabled because the C API provides the shapes for imported nodes.
+      # TODO(skyewm): remove when _USE_C_API is removed.
+      self._shape_val = tensor_shape.unknown_shape()
+
     # List of operations that use this Tensor as input. We maintain this list
     # to easily navigate a computation graph.
     self._consumers = []
 
-    # Attributes used for C++ shape inference. Not inspected, only forwarded.
-    # If set, will be a HandleData object from cpp_shape_inference.proto.
-    # TODO(b/74620627): remove when _USE_C_SHAPES is removed
-    self._handle_data = None
+    if not _USE_C_SHAPES:
+      # Attributes used for C++ shape inference. Not inspected, only forwarded.
+      # If set, will be a HandleData object from cpp_shape_inference.proto.
+      self._handle_data = None
+
     self._id = uid()
 
   @property
@@ -371,18 +382,45 @@ class Tensor(_TensorLike):
       A `TensorShape` representing the shape of this tensor.
 
     """
-    graph = self._op._graph._c_graph  # pylint: disable=protected-access
-    if graph and _USE_C_SHAPES:
-      num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
-      if num_dims == -1:
-        dim_list = None
+    if self._shape_val is None:
+      if _USE_C_SHAPES:
+        self._shape_val = self._c_api_shape()
       else:
-        dim_list = c_api.TF_GraphGetTensorShape_wrapper(
-            graph, self._as_tf_output(), num_dims)
-        dim_list = [None if i == -1 else i for i in dim_list]
-      return tensor_shape.TensorShape(dim_list)
+        assert _USE_C_API
+        # Call set_shape_and_handle_data_for_outputs in topological order on all
+        # ops that are needed to compute self.op's shape. We do this instead of
+        # having set_shape_and_handle_data_for_outputs recursively call
+        # Operation.shape on self.op.inputs, to avoid overflowing the call stack.
+        need_shapes = self._get_input_ops_without_shapes(self.op)
+        need_shapes.sort(key=lambda op: op._id)
+        for op in need_shapes:
+          set_shape_and_handle_data_for_outputs(op)
     return self._shape_val
 
+  def _get_input_ops_without_shapes(self, target_op):
+    """Returns ops needing shape inference to compute target_op's shape."""
+    result = []
+    stack = [self._op]
+    visited = set()
+    while stack:
+      op = stack.pop()
+      if op in visited: continue
+      result.append(op)
+      stack.extend(t.op for t in op.inputs if t._shape_val is None)
+      visited.add(op)
+    return result
+
+  def _c_api_shape(self):
+    """Returns the TensorShape of this tensor according to the C API."""
+    c_graph = self._op._graph._c_graph  # pylint: disable=protected-access
+    shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+        c_graph, self._as_tf_output())
+    if unknown_shape:
+      return tensor_shape.unknown_shape()
+    else:
+      shape_vector = [None if d == -1 else d for d in shape_vector]
+      return tensor_shape.TensorShape(shape_vector)
+
   @property
   def _shape(self):
     logging.warning("Tensor._shape is private, use Tensor.shape "
@@ -466,8 +504,11 @@ class Tensor(_TensorLike):
       ValueError: If `shape` is not compatible with the current shape of
        this tensor.
     """
-    if not _USE_C_SHAPES:  # pylint: disable=protected-access
-      self._shape_val = self._shape_val.merge_with(shape)
+    if _USE_C_SHAPES:  # pylint: disable=protected-access
+      # Reset cached shape.
+      self._shape_val = None
+    else:
+      self._shape_val = self.shape.merge_with(shape)
 
     if not self._op._graph._c_graph: return
@@ -579,6 +620,16 @@ class Tensor(_TensorLike):
     # Necessary to support Python's collection membership operators
     return id(self) == id(other)
 
+  def __copy__(self):
+    # Make sure _shape_val is computed before we copy.
+    # TODO(b/77597810): get rid of Tensor copies.
+    if self._shape_val is None:
+      set_shape_and_handle_data_for_outputs(self.op)
+    cls = self.__class__
+    result = cls.__new__(cls)
+    result.__dict__.update(self.__dict__)
+    return result
+
   # NOTE(mrry): This enables the Tensor's overloaded "right" binary
   # operators to run when the left operand is an ndarray, because it
   # accords the Tensor class higher priority than an ndarray, or a
@@ -1932,6 +1983,13 @@ class Operation(object):
     if not isinstance(tensor, Tensor):
       raise TypeError("tensor must be a Tensor: %s" % tensor)
     _assert_same_graph(self, tensor)
+
+    # Make sure output shapes are already computed for this op in case we create
+    # a cycle (we cannot compute shapes for cycles). Usually shapes are computed
+    # lazily upon request.
+    if not _USE_C_SHAPES:
+      set_shape_and_handle_data_for_outputs(self)
+
     if self._c_op:
       # Reset cached inputs.
       self._inputs_val = None
@@ -2474,35 +2532,41 @@ class RegisterShape(object):
     return f
 
 
-def _set_shapes_for_outputs_c_api(op):
-  """set_shapes_for_outputs implementation when C API is enabled."""
-  # The C API computes the shapes when the TF_Operation is created. Fetch the
-  # output shapes from the C object.
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def _set_shape_and_handle_data_for_outputs_c_api(op):
+  """Set shapes and resource handle data using info from the C API."""
+  assert not _USE_C_SHAPES
   for output in op.outputs:
-    # pylint: disable=protected-access
-    shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+    output._shape_val = output._c_api_shape()
+    # Set the resource handle data for compatibility with the Python shape
+    # inference code.
+    serialized = c_api.ResourceHandleShapeAndType(
         op._graph._c_graph, output._as_tf_output())
-    # pylint: enable=protected-access
-    if unknown_shape:
-      output.set_shape(tensor_shape.unknown_shape())
-    elif not shape_vector:
-      output.set_shape(tensor_shape.scalar())
-    else:
-      shape_vector = [None if d == -1 else d for d in shape_vector]
-      output.set_shape(tensor_shape.TensorShape(shape_vector))
-
-    serialized = c_api.ResourceHandleShapeAndType(op._graph._c_graph,
-                                                  output._as_tf_output())
     if serialized:
       output._handle_data = (
-          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
-              compat.as_bytes(serialized)))
+          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
+          .FromString(compat.as_bytes(serialized)))
     else:
       output._handle_data = None
 
 
-# TODO(skyewm): remove this when _USE_C_API flag is removed.
-def _set_shapes_for_outputs(op):
-  """set_shapes_for_outputs implementation when C API is disabled."""
+
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def set_shape_and_handle_data_for_outputs(op):
+  """Set the shapes and resource handle data for op's outputs.
+
+  When _USE_C_API = True, this is lazily called when a tensor's shape is first
+  requested. Usually this should work automatically, but some edge cases may
+  require manually calling this first to make sure Tensor._shape_val and
+  Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a
+  Tensor).
+  """
+  if _USE_C_SHAPES: return
+
+  if op.graph._is_function(op.type):
+    for output in op.outputs:
+      output._shape_val = tensor_shape.unknown_shape()
+    return
+
   try:
     shape_func = _shape_registry.lookup(op.type)
   except LookupError:
@@ -2521,8 +2585,10 @@
     shapes = shapes_dict["shapes"]
     handle_datas = shapes_dict["handle_data"]
     for output, handle_data in zip(op.outputs, handle_datas):
+      # Don't override any existing handle data that may have been manually set.
       # pylint: disable=protected-access
-      output._handle_data = handle_data
+      if output._handle_data is None:
+        output._handle_data = handle_data
       # pylint: enable=protected-access
 
   if len(op.outputs) != len(shapes):
@@ -2530,15 +2596,8 @@
         "Shape function for op %s returned %d shapes but expected %d %s %s" %
         (op, len(shapes), len(op.outputs), shape_func.__name__,
          str(shapes)))
   for output, s in zip(op.outputs, shapes):
-    output.set_shape(s)
-
-
-def set_shapes_for_outputs(op):
-  """Set the shapes for op's outputs."""
-  if op._c_op and _USE_C_SHAPES:  # pylint: disable=protected-access
-    return _set_shapes_for_outputs_c_api(op)
-  else:
-    return _set_shapes_for_outputs(op)
+    output._shape_val = tensor_shape.unknown_shape()
+    output._shape_val = output._shape_val.merge_with(s)
 
 
 class OpStats(object):
@@ -3331,18 +3390,14 @@ class Graph(object):
           original_op=self._default_original_op,
           op_def=op_def)
 
-    # TODO(vrv): Instead of eagerly filling in shape property for every op,
-    # only populate the shape when requested.
+    # Note: shapes are lazily computed with the C API enabled.
     #
     # TODO(skyewm): unlike in the original Python implementation, the C API
     # always computes shape information (even for function calls, which the
     # original Python shape inference code doesn't handle). Deprecate the
     # compute_shapes argument.
-    #
-    # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
-    # is removed
-    if (ret._c_op and _USE_C_SHAPES) or compute_shapes:  # pylint: disable=protected-access
-      set_shapes_for_outputs(ret)
+    if not _USE_C_API and compute_shapes:
+      set_shape_and_handle_data_for_outputs(ret)
     self._create_op_helper(ret, compute_shapes=compute_shapes,
                            compute_device=compute_device)
@@ -3484,18 +3539,17 @@
         for c_op in c_api_util.new_tf_operations(self)
     ]
 
+    # pylint: disable=protected-access
     for op in new_ops:
       # Operations created by the C API always retrieve shapes from the C API so
       # we preserve the shapes of ops created in import_graph_def (from the
       # "_output_shapes" attr of the imported NodeDef).
-      # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
-      # is removed.
-      _set_shapes_for_outputs_c_api(op)
+      if not _USE_C_SHAPES:
+        _set_shape_and_handle_data_for_outputs_c_api(op)
       new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
-      # pylint: disable=protected-access
       op._add_control_inputs(new_control_inputs)
       op._control_flow_post_processing()
-      # pylint: enable=protected-access
+    # pylint: enable=protected-access
 
     return new_ops
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 984bcec..64b0fa6 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -22,7 +22,6 @@ import six
 
 from tensorflow.core.framework import tensor_pb2
 from tensorflow.core.framework import tensor_shape_pb2
-from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.util import compat
@@ -828,7 +827,7 @@ def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
   Returns:
     A `TensorShape` based on the constant value of the given `tensor`.
   """
-  if context.executing_eagerly():
+  if isinstance(tensor, ops.EagerTensor):
     return tensor_shape.as_shape(
         [dim if dim != -1 else None for dim in tensor.numpy()])
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 07e25e5..508ba9b 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -72,7 +72,12 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
   # know the shape and dtype of the variable pointed to by a handle. Since
   # shape inference doesn't run in eager mode we copy this data here for when
   # the handle is captured by an eager mode function.
-  handle._handle_data = h._handle_data  # pylint: disable=protected-access
+  # pylint: disable=protected-access
+  if h._handle_data is None:
+    ops.set_shape_and_handle_data_for_outputs(h.op)
+  handle._handle_data = h._handle_data
+  # pylint: enable=protected-access
+
   # Clean up our reference cycles to avoid making the garbage collector run.
   # pylint: disable=protected-access
   # OrderedDict, constructed on Graph creation, makes a simple reference loop
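
A minimal, self-contained sketch of the lazy-shape pattern this commit introduces in Tensor.shape and _get_input_ops_without_shapes: the cached shape starts out as None, and the first access collects every producing op that still lacks output shapes and fills them in in creation order (a stand-in for the topological sort by Operation._id), instead of recursing through op.inputs. FakeOp, FakeTensor, and the toy shape functions below are hypothetical stand-ins for illustration only, not TensorFlow APIs.

# Standalone illustration of the lazy shape computation described above.
# FakeOp and FakeTensor are hypothetical stand-ins, not TensorFlow classes.
import itertools

_ids = itertools.count()


class FakeTensor(object):

  def __init__(self, op):
    self.op = op
    self._shape_val = None  # filled in lazily, like Tensor._shape_val

  @property
  def shape(self):
    if self._shape_val is None:
      # Collect every producing op whose outputs still lack a shape, then run
      # shape inference once per op in creation order. Iterating (rather than
      # recursing through op.inputs) keeps deep graphs from overflowing the
      # call stack.
      need_shapes = _ops_without_shapes(self.op)
      need_shapes.sort(key=lambda op: op._id)
      for op in need_shapes:
        _set_shapes_for_outputs(op)
    return self._shape_val


class FakeOp(object):

  def __init__(self, inputs, shape_fn):
    self._id = next(_ids)
    self.inputs = list(inputs)        # input FakeTensors
    self.shape_fn = shape_fn          # maps input shapes -> output shapes
    self.outputs = [FakeTensor(self)]


def _ops_without_shapes(target_op):
  """Iterative DFS over ops whose output shapes are not yet computed."""
  result, stack, visited = [], [target_op], set()
  while stack:
    op = stack.pop()
    if op in visited:
      continue
    visited.add(op)
    result.append(op)
    stack.extend(t.op for t in op.inputs if t._shape_val is None)
  return result


def _set_shapes_for_outputs(op):
  shapes = op.shape_fn([t._shape_val for t in op.inputs])
  for out, s in zip(op.outputs, shapes):
    out._shape_val = s


# Usage: build a long chain of shape-preserving ops. No shape work happens
# until the final .shape access, and that access uses no recursion.
tensor = FakeOp([], lambda _: [(2, 3)]).outputs[0]
for _ in range(10000):
  tensor = FakeOp([tensor], lambda in_shapes: [in_shapes[0]]).outputs[0]
print(tensor.shape)  # (2, 3)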