dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
dtype_flat = output_flatten(dtype)
- # Convert elems to tensor array.
- n = array_ops.shape(elems_flat[0])[0]
+ # Convert elems to tensor array. n may be known statically.
+ n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
# TensorArrays are always flat
elems_ta = [
elems_flat = [
ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]
- n = array_ops.shape(elems_flat[0])[0]
+ # Convert elems to tensor array. n may be known statically.
+ n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
# TensorArrays are always flat
elems_ta = [
_, _, r_a = control_flow_ops.while_loop(
lambda i, _1, _2: i < n, compute, (i, a_flat, accs_ta),
parallel_iterations=parallel_iterations,
- back_prop=back_prop, swap_memory=swap_memory)
+ back_prop=back_prop, swap_memory=swap_memory,
+ maximum_iterations=n)
results_flat = [r.stack() for r in r_a]