self._op = op
self._value_index = value_index
self._dtype = dtypes.as_dtype(dtype)
- self._shape_val = tensor_shape.unknown_shape()
+
+ if _USE_C_API:
+ # This will be set by set_shape_and_handle_data_for_outputs.
+ self._shape_val = None
+ else:
+ # The Python code requires all tensors start with a shape to support shape
+ # inference on imported while loops. This isn't necessary with the C API
+ # enabled because the C API provides the shapes for imported nodes.
+ # TODO(skyewm): remove when _USE_C_API is removed.
+ self._shape_val = tensor_shape.unknown_shape()
+
# List of operations that use this Tensor as input. We maintain this list
# to easily navigate a computation graph.
self._consumers = []
- # Attributes used for C++ shape inference. Not inspected, only forwarded.
- # If set, will be a HandleData object from cpp_shape_inference.proto.
- # TODO(b/74620627): remove when _USE_C_SHAPES is removed
- self._handle_data = None
+ if not _USE_C_SHAPES:
+ # Attributes used for C++ shape inference. Not inspected, only forwarded.
+ # If set, will be a HandleData object from cpp_shape_inference.proto.
+ self._handle_data = None
+
self._id = uid()
@property
A `TensorShape` representing the shape of this tensor.
"""
- graph = self._op._graph._c_graph # pylint: disable=protected-access
- if graph and _USE_C_SHAPES:
- num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output())
- if num_dims == -1:
- dim_list = None
+ if self._shape_val is None:
+ if _USE_C_SHAPES:
+ self._shape_val = self._c_api_shape()
else:
- dim_list = c_api.TF_GraphGetTensorShape_wrapper(
- graph, self._as_tf_output(), num_dims)
- dim_list = [None if i == -1 else i for i in dim_list]
- return tensor_shape.TensorShape(dim_list)
+ assert _USE_C_API
+ # Call set_shape_and_handle_data_for_outputs in topological order on all
+ # ops that are needed to compute self.op's shape. We do this instead of
+ # having set_shape_and_handle_data_for_outputs recursively call
+ # Operation.shape on self.op.inputs to overflowing the call stack.
+ need_shapes = self._get_input_ops_without_shapes(self.op)
+ need_shapes.sort(key=lambda op: op._id)
+ for op in need_shapes:
+ set_shape_and_handle_data_for_outputs(op)
return self._shape_val
+ def _get_input_ops_without_shapes(self, target_op):
+ """Returns ops needing shape inference to compute target_op's shape."""
+ result = []
+ stack = [self._op]
+ visited = set()
+ while stack:
+ op = stack.pop()
+ if op in visited: continue
+ result.append(op)
+ stack.extend(t.op for t in op.inputs if t._shape_val is None)
+ visited.add(op)
+ return result
+
+ def _c_api_shape(self):
+ """Returns the TensorShape of this tensor according to the C API."""
+ c_graph = self._op._graph._c_graph # pylint: disable=protected-access
+ shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+ c_graph, self._as_tf_output())
+ if unknown_shape:
+ return tensor_shape.unknown_shape()
+ else:
+ shape_vector = [None if d == -1 else d for d in shape_vector]
+ return tensor_shape.TensorShape(shape_vector)
+
@property
def _shape(self):
logging.warning("Tensor._shape is private, use Tensor.shape "
ValueError: If `shape` is not compatible with the current shape of
this tensor.
"""
- if not _USE_C_SHAPES: # pylint: disable=protected-access
- self._shape_val = self._shape_val.merge_with(shape)
+ if _USE_C_SHAPES: # pylint: disable=protected-access
+ # Reset cached shape.
+ self._shape_val = None
+ else:
+ self._shape_val = self.shape.merge_with(shape)
if not self._op._graph._c_graph: return
# Necessary to support Python's collection membership operators
return id(self) == id(other)
+ def __copy__(self):
+ # Make sure _shape_val is computed before we copy.
+ # TODO(b/77597810): get rid of Tensor copies.
+ if self._shape_val is None:
+ set_shape_and_handle_data_for_outputs(self.op)
+ cls = self.__class__
+ result = cls.__new__(cls)
+ result.__dict__.update(self.__dict__)
+ return result
+
# NOTE(mrry): This enables the Tensor's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
# accords the Tensor class higher priority than an ndarray, or a
if not isinstance(tensor, Tensor):
raise TypeError("tensor must be a Tensor: %s" % tensor)
_assert_same_graph(self, tensor)
+
+ # Make sure output shapes are already computed for this op in case we create
+ # a cycle (we cannot compute shapes for cycles). Usually shapes are computed
+ # lazily upon request.
+ if not _USE_C_SHAPES:
+ set_shape_and_handle_data_for_outputs(self)
+
if self._c_op:
# Reset cached inputs.
self._inputs_val = None
return f
-def _set_shapes_for_outputs_c_api(op):
- """set_shapes_for_outputs implementation when C API is enabled."""
- # The C API computes the shapes when the TF_Operation is created. Fetch the
- # output shapes from the C object.
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def _set_shape_and_handle_data_for_outputs_c_api(op):
+ """Set shapes and resource handle data using info from the C API."""
+ assert not _USE_C_SHAPES
for output in op.outputs:
- # pylint: disable=protected-access
- shape_vector, unknown_shape = c_api.TF_GraphGetTensorShapeHelper(
+ output._shape_val = output._c_api_shape()
+ # Set the resource handle data for compatibility with the Python shape
+ # inference code.
+ serialized = c_api.ResourceHandleShapeAndType(
op._graph._c_graph, output._as_tf_output())
- # pylint: enable=protected-access
- if unknown_shape:
- output.set_shape(tensor_shape.unknown_shape())
- elif not shape_vector:
- output.set_shape(tensor_shape.scalar())
- else:
- shape_vector = [None if d == -1 else d for d in shape_vector]
- output.set_shape(tensor_shape.TensorShape(shape_vector))
-
- serialized = c_api.ResourceHandleShapeAndType(op._graph._c_graph,
- output._as_tf_output())
if serialized:
output._handle_data = (
- cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
- compat.as_bytes(serialized)))
+ cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
+ .FromString(compat.as_bytes(serialized)))
else:
output._handle_data = None
-# TODO(skyewm): remove this when _USE_C_API flag is removed.
-def _set_shapes_for_outputs(op):
- """set_shapes_for_outputs implementation when C API is disabled."""
+
+# TODO(b/74620627): remove when _USE_C_SHAPES is removed
+def set_shape_and_handle_data_for_outputs(op):
+ """Set the shapes and resource handle data for op's outputs.
+
+ When _USE_C_API = True, this is lazily called when a tensor's shape is first
+ requested. Usually this should work automatically, but some edge cases may
+ require manaully calling this first to make sure Tensor._shape_val and
+ Tensor._handle_data are set (e.g. manually overriding _handle_data, copying a
+ Tensor).
+ """
+ if _USE_C_SHAPES: return
+
+ if op.graph._is_function(op.type):
+ for output in op.outputs:
+ output._shape_val = tensor_shape.unknown_shape()
+ return
+
try:
shape_func = _shape_registry.lookup(op.type)
except LookupError:
shapes = shapes_dict["shapes"]
handle_datas = shapes_dict["handle_data"]
for output, handle_data in zip(op.outputs, handle_datas):
+ # Don't override any existing handle data that may have been manually set.
# pylint: disable=protected-access
- output._handle_data = handle_data
+ if output._handle_data is None:
+ output._handle_data = handle_data
# pylint: enable=protected-access
if len(op.outputs) != len(shapes):
"Shape function for op %s returned %d shapes but expected %d %s %s" %
(op, len(shapes), len(op.outputs), shape_func.__name__, str(shapes)))
for output, s in zip(op.outputs, shapes):
- output.set_shape(s)
-
-
-def set_shapes_for_outputs(op):
- """Set the shapes for op's outputs."""
- if op._c_op and _USE_C_SHAPES: # pylint: disable=protected-access
- return _set_shapes_for_outputs_c_api(op)
- else:
- return _set_shapes_for_outputs(op)
+ output._shape_val = tensor_shape.unknown_shape()
+ output._shape_val = output._shape_val.merge_with(s)
class OpStats(object):
original_op=self._default_original_op,
op_def=op_def)
- # TODO(vrv): Instead of eagerly filling in shape property for every op,
- # only populate the shape when requested.
+ # Note: shapes are lazily computed with the C API enabled.
#
# TODO(skyewm): unlike in the original Python implementation, the C API
# always computes shape information (even for function calls, which the
# original Python shape inference code doesn't handle). Deprecate the
# compute_shapes argument.
- #
- # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
- # is removed
- if (ret._c_op and _USE_C_SHAPES) or compute_shapes: # pylint: disable=protected-access
- set_shapes_for_outputs(ret)
+ if not _USE_C_API and compute_shapes:
+ set_shape_and_handle_data_for_outputs(ret)
self._create_op_helper(ret, compute_shapes=compute_shapes,
compute_device=compute_device)
for c_op in c_api_util.new_tf_operations(self)
]
+ # pylint: disable=protected-access
for op in new_ops:
# Operations created by the C API always retrieve shapes from the C API so
# we preserve the shapes of ops created in import_graph_def (from the
# "_output_shapes" attr of the imported NodeDef).
- # TODO(b/74620627): move this back to _create_op_helper once _USE_C_SHAPES
- # is removed.
- _set_shapes_for_outputs_c_api(op)
+ if not _USE_C_SHAPES:
+ _set_shape_and_handle_data_for_outputs_c_api(op)
new_control_inputs = self._control_dependencies_for_inputs(op.inputs)
- # pylint: disable=protected-access
op._add_control_inputs(new_control_inputs)
op._control_flow_post_processing()
- # pylint: enable=protected-access
+ # pylint: enable=protected-access
return new_ops