Use fixed sized tensor arrays and max loop iterations in dynamic_decode if the user...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Apr 2018 21:13:52 +0000 (14:13 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 16 Apr 2018 21:16:53 +0000 (14:16 -0700)
PiperOrigin-RevId: 193097293

tensorflow/contrib/seq2seq/BUILD
tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
tensorflow/contrib/seq2seq/python/ops/decoder.py

index a62069a..1a1591d 100644 (file)
@@ -3,9 +3,12 @@
 
 licenses(["notice"])  # Apache 2.0
 
-exports_files(["LICENSE"])
+package(default_visibility = [
+    "//learning/brain/google/xla/tests:__subpackages__",
+    "//tensorflow:__subpackages__",
+])
 
-package(default_visibility = ["//tensorflow:__subpackages__"])
+exports_files(["LICENSE"])
 
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,6 +41,7 @@ tf_custom_op_py_library(
         "//tensorflow/python:check_ops",
         "//tensorflow/python:clip_ops",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:functional_ops",
index ac830ae..b549cbf 100644 (file)
@@ -92,14 +92,18 @@ class DynamicDecodeRNNTest(test.TestCase):
 
       # Mostly a smoke test
       time_steps = max_out
+      expected_length = sequence_length
       if maximum_iterations is not None:
         time_steps = min(max_out, maximum_iterations)
+        expected_length = [min(x, maximum_iterations) for x in expected_length]
       self.assertEqual(
           _t((batch_size, time_steps, cell_depth)),
           sess_results["final_outputs"].rnn_output.shape)
       self.assertEqual(
           _t((batch_size, time_steps)),
           sess_results["final_outputs"].sample_id.shape)
+      self.assertItemsEqual(expected_length,
+                            sess_results["final_sequence_length"])
 
   def testDynamicDecodeRNNBatchMajor(self):
     self._testDynamicDecodeRNN(time_major=False)
index 8984936..e69725f 100644 (file)
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
@@ -181,6 +182,15 @@ def dynamic_decode(decoder,
     raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                     type(decoder))
 
+  def _is_xla_tensor(tensor):
+    try:
+      op = tensor.op
+    except AttributeError:
+      return False
+    if control_flow_util.IsInXLAContext(op):
+      return True
+    return False
+
   with variable_scope.variable_scope(scope, "decoder") as varscope:
     # Properly cache variable values inside the while_loop
     if varscope.caching_device is None:
@@ -198,6 +208,11 @@ def dynamic_decode(decoder,
                                         decoder.output_dtype,
                                         decoder.batch_size)
 
+    is_xla = False
+    if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
+      is_xla = True
+    if is_xla and maximum_iterations is None:
+      raise ValueError("maximum_iterations is required for XLA compilation.")
     if maximum_iterations is not None:
       initial_finished = math_ops.logical_or(
           initial_finished, 0 >= maximum_iterations)
@@ -215,11 +230,13 @@ def dynamic_decode(decoder,
                 batch_size, name="batch_size"))
         return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
 
+    dynamic_size = maximum_iterations is None or not is_xla
+
     def _create_ta(s, d):
       return tensor_array_ops.TensorArray(
           dtype=d,
-          size=0,
-          dynamic_size=True,
+          size=0 if dynamic_size else maximum_iterations,
+          dynamic_size=dynamic_size,
           element_shape=_shape(decoder.batch_size, s))
 
     initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
@@ -251,11 +268,8 @@ def dynamic_decode(decoder,
         next_finished = decoder_finished
       else:
         next_finished = math_ops.logical_or(decoder_finished, finished)
-      if maximum_iterations is not None:
-        next_finished = math_ops.logical_or(
-            next_finished, time + 1 >= maximum_iterations)
       next_sequence_lengths = array_ops.where(
-          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+          math_ops.logical_not(finished),
           array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
           sequence_lengths)
 
@@ -296,11 +310,16 @@ def dynamic_decode(decoder,
     res = control_flow_ops.while_loop(
         condition,
         body,
-        loop_vars=[
-            initial_time, initial_outputs_ta, initial_state, initial_inputs,
-            initial_finished, initial_sequence_lengths,
-        ],
+        loop_vars=(
+            initial_time,
+            initial_outputs_ta,
+            initial_state,
+            initial_inputs,
+            initial_finished,
+            initial_sequence_lengths,
+        ),
         parallel_iterations=parallel_iterations,
+        maximum_iterations=maximum_iterations,
         swap_memory=swap_memory)
 
     final_outputs_ta = res[1]