if isinstance(ret, list):
ret = tuple(ret)
- # Convert any `SparseTensorValue`s to `SparseTensor`s.
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
ret = nest.pack_sequence_as(ret, [
sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
])
self._state_classes = sparse.get_classes(ret)
self._state_types = nest.pack_sequence_as(
ret, [t.dtype for t in nest.flatten(ret)])
- # Serialize any sparse tensors and convert result to tensors.
- ret = nest.pack_sequence_as(ret, [
- ops.convert_to_tensor(t)
- for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
- ])
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
return nest.flatten(ret)
self._init_func = tf_init_func
if isinstance(ret, list):
ret = tuple(ret)
- # Convert any `SparseTensorValue`s to `SparseTensor`s.
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
ret = nest.pack_sequence_as(ret, [
sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
])
self._output_classes = sparse.get_classes(ret)
self._output_types = nest.pack_sequence_as(
ret, [t.dtype for t in nest.flatten(ret)])
- # Serialize any sparse tensors and convert result to tensors.
- ret = nest.pack_sequence_as(ret, [
- ops.convert_to_tensor(t)
- for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
- ])
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
return nest.flatten(ret)
self._next_func = tf_next_func
if isinstance(ret, list):
ret = tuple(ret)
- # Convert any `SparseTensorValue`s to `SparseTensor`s.
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
ret = nest.pack_sequence_as(ret, [
sparse_tensor_lib.SparseTensor.from_value(t)
- if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(t)
+ for t in nest.flatten(ret)
])
self._output_classes = sparse.get_classes(ret)
self._output_types = nest.pack_sequence_as(
ret, [t.dtype for t in nest.flatten(ret)])
- # Serialize any sparse tensors and convert result to tensors.
- ret = nest.pack_sequence_as(ret, [
- ops.convert_to_tensor(t)
- for t in nest.flatten(sparse.serialize_sparse_tensors(ret))
- ])
+ # Serialize any sparse tensors.
+ ret = nest.pack_sequence_as(
+ ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
return nest.flatten(ret)
self._map_func = tf_map_func