[TF RNN] Small optimization to rnn: only calculate copy_cond once.
authorEugene Brevdo <ebrevdo@google.com>
Fri, 2 Feb 2018 22:30:50 +0000 (14:30 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 22:35:08 +0000 (14:35 -0800)
PiperOrigin-RevId: 184335231

tensorflow/python/ops/rnn.py

index a10e196..24c6f64 100644 (file)
@@ -171,11 +171,11 @@ def _rnn_step(
   return (final_output, final_state)
 
   Args:
-    time: Python int, the current time step
-    sequence_length: int32 `Tensor` vector of size [batch_size]
-    min_sequence_length: int32 `Tensor` scalar, min of sequence_length
-    max_sequence_length: int32 `Tensor` scalar, max of sequence_length
-    zero_output: `Tensor` vector of shape [output_size]
+    time: int32 `Tensor` scalar.
+    sequence_length: int32 `Tensor` vector of size [batch_size].
+    min_sequence_length: int32 `Tensor` scalar, min of sequence_length.
+    max_sequence_length: int32 `Tensor` scalar, max of sequence_length.
+    zero_output: `Tensor` vector of shape [output_size].
     state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`,
       or a list/tuple of such tensors.
     call_cell: lambda returning tuple of (new_output, new_state) where
@@ -202,6 +202,9 @@ def _rnn_step(
   flat_state = nest.flatten(state)
   flat_zero_output = nest.flatten(zero_output)
 
+  # Vector describing which batch entries are finished.
+  copy_cond = time >= sequence_length
+
   def _copy_one_through(output, new_output):
     # TensorArray and scalar get passed through.
     if isinstance(output, tensor_array_ops.TensorArray):
@@ -209,7 +212,6 @@ def _rnn_step(
     if output.shape.ndims == 0:
       return new_output
     # Otherwise propagate the old or the new value.
-    copy_cond = (time >= sequence_length)
     with ops.colocate_with(new_output):
       return array_ops.where(copy_cond, output, new_output)