Specify `maximum_iterations` for the tf.while_loop used by tf.scan, to be compatible with...
author: A. Unique TensorFlower <gardener@tensorflow.org>
Sun, 11 Mar 2018 17:44:02 +0000 (10:44 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Sun, 11 Mar 2018 17:47:47 +0000 (10:47 -0700)
PiperOrigin-RevId: 188652533

tensorflow/python/ops/functional_ops.py

index 8f56735..a840b1e 100644 (file)
@@ -364,8 +364,8 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
     dtype = dtype or input_pack([elem.dtype for elem in elems_flat])
     dtype_flat = output_flatten(dtype)
 
-    # Convert elems to tensor array.
-    n = array_ops.shape(elems_flat[0])[0]
+    # Convert elems to tensor array. n may be known statically.
+    n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
 
     # TensorArrays are always flat
     elems_ta = [
@@ -555,7 +555,8 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
     elems_flat = [
         ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]
 
-    n = array_ops.shape(elems_flat[0])[0]
+    # Convert elems to tensor array. n may be known statically.
+    n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
 
     # TensorArrays are always flat
     elems_ta = [
@@ -615,7 +616,8 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
     _, _, r_a = control_flow_ops.while_loop(
         lambda i, _1, _2: i < n, compute, (i, a_flat, accs_ta),
         parallel_iterations=parallel_iterations,
-        back_prop=back_prop, swap_memory=swap_memory)
+        back_prop=back_prop, swap_memory=swap_memory,
+        maximum_iterations=n)
 
     results_flat = [r.stack() for r in r_a]