tf.map_fn: Improve error messaging when elems consists of scalars.
authorAsim Shankar <ashankar@google.com>
Tue, 3 Apr 2018 17:51:57 +0000 (10:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 3 Apr 2018 17:54:28 +0000 (10:54 -0700)
Fixes #17694
Prior to this change, when tf.map_fn was provided with scalars, the error would
be something like:

Traceback (most recent call last):
  File "/tensorflow/python/kernel_tests/functional_ops_test.py", line 165, in testMapOverScalarErrors
    functional_ops.map_fn(lambda x: x, [1, 2])
  File "/tensorflow/python/ops/functional_ops.py", line 368, in map_fn
    n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
  File "/tensorflow/python/framework/tensor_shape.py", line 609, in __getitem__
    return self._dims[key]
IndexError: list index out of range
PiperOrigin-RevId: 191465183

tensorflow/python/kernel_tests/functional_ops_test.py
tensorflow/python/ops/functional_ops.py

index 10aea89..34fb655 100644 (file)
@@ -160,6 +160,13 @@ class FunctionalOpsTest(test.TestCase):
                 values=constant_op.constant([0, 1, 2]),
                 dense_shape=[2, 2]))
 
+  @test_util.run_in_graph_and_eager_modes()
+  def testMapOverScalarErrors(self):
+    with self.assertRaisesRegexp(ValueError, "not scalars"):
+      functional_ops.map_fn(lambda x: x, [1, 2])
+    with self.assertRaisesRegexp(ValueError, "not a scalar"):
+      functional_ops.map_fn(lambda x: x, 1)
+
   def testMap_Scoped(self):
     with self.test_session() as sess:
 
index 4d95ca2..161f6f3 100644 (file)
@@ -367,7 +367,15 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
     dtype_flat = output_flatten(dtype)
 
     # Convert elems to tensor array. n may be known statically.
-    n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
+    static_shape = elems_flat[0].shape
+    if static_shape.ndims is not None and static_shape.ndims < 1:
+      if len(elems_flat) == 1:
+        raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
+      else:
+        raise ValueError(
+            "elements in elems must be 1+ dimensional Tensors, not scalars"
+        )
+    n = static_shape[0].value or array_ops.shape(elems_flat[0])[0]
 
     # TensorArrays are always flat
     elems_ta = [