From: Asim Shankar Date: Tue, 3 Apr 2018 17:51:57 +0000 (-0700) Subject: tf.map_fn: Improve error messaging when elems consists of scalars. X-Git-Tag: tflite-v0.1.7~39^2^2~85 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=f654b0d15af364d6f43d22a179fa05d20650fe9a;p=platform%2Fupstream%2Ftensorflow.git tf.map_fn: Improve error messaging when elems consists of scalars. Fixes #17694 Prior to this change, when tf.map_fn was provided with scalars, the error would be something like: Traceback (most recent call last): File "/tensorflow/python/kernel_tests/functional_ops_test.py", line 165, in testMapOverScalarErrors functional_ops.map_fn(lambda x: x, [1, 2]) File "/tensorflow/python/ops/functional_ops.py", line 368, in map_fn n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0] File "/tensorflow/python/framework/tensor_shape.py", line 609, in __getitem__ return self._dims[key] IndexError: list index out of range PiperOrigin-RevId: 191465183 --- diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index 10aea89..34fb655 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -160,6 +160,13 @@ class FunctionalOpsTest(test.TestCase): values=constant_op.constant([0, 1, 2]), dense_shape=[2, 2])) + @test_util.run_in_graph_and_eager_modes() + def testMapOverScalarErrors(self): + with self.assertRaisesRegexp(ValueError, "not scalars"): + functional_ops.map_fn(lambda x: x, [1, 2]) + with self.assertRaisesRegexp(ValueError, "not a scalar"): + functional_ops.map_fn(lambda x: x, 1) + def testMap_Scoped(self): with self.test_session() as sess: diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 4d95ca2..161f6f3 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -367,7 +367,15 @@ def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, dtype_flat = output_flatten(dtype) # Convert elems to tensor array. n may be known statically. - n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0] + static_shape = elems_flat[0].shape + if static_shape.ndims is not None and static_shape.ndims < 1: + if len(elems_flat) == 1: + raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar") + else: + raise ValueError( + "elements in elems must be 1+ dimensional Tensors, not scalars" + ) + n = static_shape[0].value or array_ops.shape(elems_flat[0])[0] # TensorArrays are always flat elems_ta = [