body = lambda i, lv: (i + 1, orig_body(*lv))
if context.executing_eagerly():
+ try_to_pack = len(loop_vars) == 1
+ packed = False # whether the body result was packed into a 1-item tuple
+
while cond(*loop_vars):
loop_vars = body(*loop_vars)
+ if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
+ packed = True
+ loop_vars = (loop_vars,)
if maximum_iterations is not None:
return loop_vars[1]
else:
- return loop_vars
+ return loop_vars[0] if packed else loop_vars
if shape_invariants is not None:
if maximum_iterations is not None:
sess.run(output, feed_dict={x: 4})
+@test_util.with_c_api
+class WhileLoopTestCase(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testWhileLoopWithSingleVariable(self):
+ i = constant_op.constant(0)
+ c = lambda i: math_ops.less(i, 10)
+ b = lambda i: math_ops.add(i, 1)
+ r = control_flow_ops.while_loop(c, b, [i])
+
+ self.assertEqual(self.evaluate(r), 10)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self):
+ i = constant_op.constant(0)
+ c = lambda i: math_ops.less(i, 10)
+ b = lambda i: (math_ops.add(i, 1),)
+ r = control_flow_ops.while_loop(c, b, [i])
+
+ # Expect a tuple since that is what the body returns.
+ self.assertEqual(self.evaluate(r), (10,))
+
+
if __name__ == "__main__":
googletest.main()