From bc410d9c0133673e7b93a49487d7e14758cba280 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 16 Apr 2018 14:13:52 -0700
Subject: [PATCH] Use fixed sized tensor arrays and max loop iterations in
 dynamic_decode if the user supplies it and if the inputs were created in an
 XLA context.

PiperOrigin-RevId: 193097293
---
 tensorflow/contrib/seq2seq/BUILD                  |  8 +++--
 .../seq2seq/python/kernel_tests/decoder_test.py   |  4 +++
 tensorflow/contrib/seq2seq/python/ops/decoder.py  | 39 ++++++++++++++++------
 3 files changed, 39 insertions(+), 12 deletions(-)

diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD
index a62069a..1a1591d 100644
--- a/tensorflow/contrib/seq2seq/BUILD
+++ b/tensorflow/contrib/seq2seq/BUILD
@@ -3,9 +3,12 @@
 
 licenses(["notice"])  # Apache 2.0
 
-exports_files(["LICENSE"])
+package(default_visibility = [
+    "//learning/brain/google/xla/tests:__subpackages__",
+    "//tensorflow:__subpackages__",
+])
 
-package(default_visibility = ["//tensorflow:__subpackages__"])
+exports_files(["LICENSE"])
 
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -38,6 +41,7 @@ tf_custom_op_py_library(
         "//tensorflow/python:check_ops",
         "//tensorflow/python:clip_ops",
         "//tensorflow/python:control_flow_ops",
+        "//tensorflow/python:control_flow_util",
         "//tensorflow/python:embedding_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:functional_ops",
diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
index ac830ae..b549cbf 100644
--- a/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
+++ b/tensorflow/contrib/seq2seq/python/kernel_tests/decoder_test.py
@@ -92,14 +92,18 @@ class DynamicDecodeRNNTest(test.TestCase):
 
     # Mostly a smoke test
     time_steps = max_out
+    expected_length = sequence_length
     if maximum_iterations is not None:
       time_steps = min(max_out, maximum_iterations)
+      expected_length = [min(x, maximum_iterations) for x in expected_length]
     self.assertEqual(
         _t((batch_size, time_steps, cell_depth)),
         sess_results["final_outputs"].rnn_output.shape)
     self.assertEqual(
         _t((batch_size, time_steps)),
         sess_results["final_outputs"].sample_id.shape)
+    self.assertItemsEqual(expected_length,
+                          sess_results["final_sequence_length"])
 
   def testDynamicDecodeRNNBatchMajor(self):
     self._testDynamicDecodeRNN(time_major=False)
diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py
index 8984936..e69725f 100644
--- a/tensorflow/contrib/seq2seq/python/ops/decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py
@@ -28,6 +28,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import rnn
 from tensorflow.python.ops import rnn_cell_impl
@@ -181,6 +182,15 @@ def dynamic_decode(decoder,
     raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                     type(decoder))
 
+  def _is_xla_tensor(tensor):
+    try:
+      op = tensor.op
+    except AttributeError:
+      return False
+    if control_flow_util.IsInXLAContext(op):
+      return True
+    return False
+
   with variable_scope.variable_scope(scope, "decoder") as varscope:
     # Properly cache variable values inside the while_loop
     if varscope.caching_device is None:
@@ -198,6 +208,11 @@ def dynamic_decode(decoder,
                                         decoder.output_dtype,
                                         decoder.batch_size)
 
+    is_xla = False
+    if any([_is_xla_tensor(i) for i in nest.flatten(initial_inputs)]):
+      is_xla = True
+    if is_xla and maximum_iterations is None:
+      raise ValueError("maximum_iterations is required for XLA compilation.")
     if maximum_iterations is not None:
       initial_finished = math_ops.logical_or(
           initial_finished, 0 >= maximum_iterations)
@@ -215,11 +230,13 @@
                 batch_size, name="batch_size"))
         return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
 
+    dynamic_size = maximum_iterations is None or not is_xla
+
     def _create_ta(s, d):
       return tensor_array_ops.TensorArray(
           dtype=d,
-          size=0,
-          dynamic_size=True,
+          size=0 if dynamic_size else maximum_iterations,
+          dynamic_size=dynamic_size,
           element_shape=_shape(decoder.batch_size, s))
 
     initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
@@ -251,11 +268,8 @@
         next_finished = decoder_finished
       else:
         next_finished = math_ops.logical_or(decoder_finished, finished)
-      if maximum_iterations is not None:
-        next_finished = math_ops.logical_or(
-            next_finished, time + 1 >= maximum_iterations)
       next_sequence_lengths = array_ops.where(
-          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+          math_ops.logical_not(finished),
           array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
           sequence_lengths)
 
@@ -296,11 +310,16 @@
     res = control_flow_ops.while_loop(
         condition,
         body,
-        loop_vars=[
-            initial_time, initial_outputs_ta, initial_state, initial_inputs,
-            initial_finished, initial_sequence_lengths,
-        ],
+        loop_vars=(
+            initial_time,
+            initial_outputs_ta,
+            initial_state,
+            initial_inputs,
+            initial_finished,
+            initial_sequence_lengths,
+        ),
         parallel_iterations=parallel_iterations,
+        maximum_iterations=maximum_iterations,
         swap_memory=swap_memory)
 
     final_outputs_ta = res[1]
-- 
2.7.4
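
Usage note (not part of the patch): with this change, dynamic_decode raises
ValueError("maximum_iterations is required for XLA compilation.") when its
initial inputs were created inside an XLA context and no bound is given.
When maximum_iterations is supplied, the bound is now forwarded to
while_loop, and under XLA the output TensorArrays are allocated with a fixed
size instead of dynamic_size=True. A minimal TF 1.x sketch of the call site;
the cell size, batch shape, and TrainingHelper setup below are illustrative
assumptions, not taken from the patch:

  import tensorflow as tf
  from tensorflow.contrib import seq2seq

  # Illustrative shapes only.
  batch_size, max_time, depth = 8, 10, 32
  inputs = tf.random_normal([batch_size, max_time, depth])
  lengths = tf.fill([batch_size], max_time)

  cell = tf.nn.rnn_cell.LSTMCell(depth)
  helper = seq2seq.TrainingHelper(inputs, sequence_length=lengths)
  decoder = seq2seq.BasicDecoder(
      cell, helper,
      initial_state=cell.zero_state(batch_size, tf.float32))

  # Outside XLA, maximum_iterations stays optional; passing it caps the
  # decode loop via while_loop's maximum_iterations argument. Inside an
  # XLA compilation context (e.g. a TPU rewrite), omitting it now raises
  # the ValueError above, since XLA requires statically bounded loops and
  # fixed-size TensorArrays.
  outputs, final_state, final_seq_len = seq2seq.dynamic_decode(
      decoder, maximum_iterations=max_time)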