return (final_output, final_state)
Args:
- time: Python int, the current time step
- sequence_length: int32 `Tensor` vector of size [batch_size]
- min_sequence_length: int32 `Tensor` scalar, min of sequence_length
- max_sequence_length: int32 `Tensor` scalar, max of sequence_length
- zero_output: `Tensor` vector of shape [output_size]
+ time: int32 `Tensor` scalar.
+ sequence_length: int32 `Tensor` vector of size [batch_size].
+ min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
+ max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
+ zero_output: `Tensor` vector of shape [output_size].
state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
or a list/tuple of such tensors.
call_cell: lambda returning tuple of (new_output, new_state) where
  new_output is a `Tensor` matrix of shape `[batch_size, output_size]`,
  and new_state is a `Tensor` matrix of shape `[batch_size, state_size]`.
flat_state = nest.flatten(state)
flat_zero_output = nest.flatten(zero_output)
+ # Vector describing which batch entries are finished.
+ copy_cond = time >= sequence_length
+
def _copy_one_through(output, new_output):
# TensorArray and scalar get passed through.
if isinstance(output, tensor_array_ops.TensorArray):
if output.shape.ndims == 0:
return new_output
# Otherwise propagate the old or the new value.
- copy_cond = (time >= sequence_length)
with ops.colocate_with(new_output):
return array_ops.where(copy_cond, output, new_output)