licenses(["notice"]) # Apache 2.0
-exports_files(["LICENSE"])
+package(default_visibility = [
+ "//learning/brain/google/xla/tests:__subpackages__",
+ "//tensorflow:__subpackages__",
+])
-package(default_visibility = ["//tensorflow:__subpackages__"])
+exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:control_flow_util",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:functional_ops",
    # Mostly a smoke test; also checks output shapes and emitted sequence
    # lengths against maximum_iterations.
time_steps = max_out
+ expected_length = sequence_length
if maximum_iterations is not None:
time_steps = min(max_out, maximum_iterations)
+ expected_length = [min(x, maximum_iterations) for x in expected_length]
self.assertEqual(
_t((batch_size, time_steps, cell_depth)),
sess_results["final_outputs"].rnn_output.shape)
self.assertEqual(
_t((batch_size, time_steps)),
sess_results["final_outputs"].sample_id.shape)
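+ # Emitted lengths should equal sequence_length, clipped by
+ # maximum_iterations when it is set.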
+ self.assertItemsEqual(expected_length,
+ sess_results["final_sequence_length"])
def testDynamicDecodeRNNBatchMajor(self):
self._testDynamicDecodeRNN(time_major=False)
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
type(decoder))
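+ # True iff `tensor` was created inside an XLA compilation context;
+ # values without an `.op` attribute are treated as non-XLA.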
+ def _is_xla_tensor(tensor):
+ try:
+ op = tensor.op
+ except AttributeError:
+ return False
+ return control_flow_util.IsInXLAContext(op)
+
with variable_scope.variable_scope(scope, "decoder") as varscope:
# Properly cache variable values inside the while_loop
if varscope.caching_device is None:
decoder.output_dtype,
decoder.batch_size)
+ is_xla = any(_is_xla_tensor(i) for i in nest.flatten(initial_inputs))
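+ # XLA compilation requires static shapes, including a static bound on
+ # the number of decoding steps.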
+ if is_xla and maximum_iterations is None:
+ raise ValueError("maximum_iterations is required for XLA compilation.")
if maximum_iterations is not None:
initial_finished = math_ops.logical_or(
initial_finished, 0 >= maximum_iterations)
batch_size, name="batch_size"))
return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)
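+ # TensorArrays must be statically sized under XLA, so only use a
+ # dynamically sized array outside an XLA context (or when no bound is given).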
+ dynamic_size = maximum_iterations is None or not is_xla
+
def _create_ta(s, d):
return tensor_array_ops.TensorArray(
dtype=d,
- size=0,
- dynamic_size=True,
+ size=0 if dynamic_size else maximum_iterations,
+ dynamic_size=dynamic_size,
element_shape=_shape(decoder.batch_size, s))
initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
next_finished = decoder_finished
else:
next_finished = math_ops.logical_or(decoder_finished, finished)
- if maximum_iterations is not None:
- next_finished = math_ops.logical_or(
- next_finished, time + 1 >= maximum_iterations)
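+ # Keep advancing an example's recorded length to time + 1 until it
+ # finishes; afterwards the length is frozen.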
next_sequence_lengths = array_ops.where(
- math_ops.logical_and(math_ops.logical_not(finished), next_finished),
+ math_ops.logical_not(finished),
array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
sequence_lengths)
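+ # while_loop's maximum_iterations argument bounds the number of decoding
+ # steps; under XLA this explicit bound is required for compilation.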
res = control_flow_ops.while_loop(
condition,
body,
- loop_vars=[
- initial_time, initial_outputs_ta, initial_state, initial_inputs,
- initial_finished, initial_sequence_lengths,
- ],
+ loop_vars=(
+ initial_time,
+ initial_outputs_ta,
+ initial_state,
+ initial_inputs,
+ initial_finished,
+ initial_sequence_lengths,
+ ),
parallel_iterations=parallel_iterations,
+ maximum_iterations=maximum_iterations,
swap_memory=swap_memory)
final_outputs_ta = res[1]
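For context, a minimal call-site sketch (not part of the patch; `decoder` and `max_len` are assumed to be defined) showing the new requirement:

    # Hypothetical usage: inside an XLA-compiled context, dynamic_decode
    # must be given maximum_iterations so its TensorArrays get a static size.
    final_outputs, final_state, final_sequence_lengths = (
        tf.contrib.seq2seq.dynamic_decode(
            decoder, maximum_iterations=max_len))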