Fixes #17694
Prior to this change, when tf.map_fn was provided with scalars, the error would
be something like:
Traceback (most recent call last):
File "/tensorflow/python/kernel_tests/functional_ops_test.py", line 165, in testMapOverScalarErrors
functional_ops.map_fn(lambda x: x, [1, 2])
File "/tensorflow/python/ops/functional_ops.py", line 368, in map_fn
n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
File "/tensorflow/python/framework/tensor_shape.py", line 609, in __getitem__
return self._dims[key]
IndexError: list index out of range
PiperOrigin-RevId:
191465183
values=constant_op.constant([0, 1, 2]),
dense_shape=[2, 2]))
+ @test_util.run_in_graph_and_eager_modes()
+ def testMapOverScalarErrors(self):
+ with self.assertRaisesRegexp(ValueError, "not scalars"):
+ functional_ops.map_fn(lambda x: x, [1, 2])
+ with self.assertRaisesRegexp(ValueError, "not a scalar"):
+ functional_ops.map_fn(lambda x: x, 1)
+
def testMap_Scoped(self):
with self.test_session() as sess:
dtype_flat = output_flatten(dtype)
# Convert elems to tensor array. n may be known statically.
- n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
+ static_shape = elems_flat[0].shape
+ if static_shape.ndims is not None and static_shape.ndims < 1:
+ if len(elems_flat) == 1:
+ raise ValueError("elems must be a 1+ dimensional Tensor, not a scalar")
+ else:
+ raise ValueError(
+ "elements in elems must be 1+ dimensional Tensors, not scalars"
+ )
+ n = static_shape[0].value or array_ops.shape(elems_flat[0])[0]
# TensorArrays are always flat
elems_ta = [