Create tuple if body doesn't return one.
authorAkshay Modi <nareshmodi@google.com>
Fri, 6 Apr 2018 22:06:05 +0000 (15:06 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 22:08:43 +0000 (15:08 -0700)
Fixes #18257.

PiperOrigin-RevId: 191946459

tensorflow/python/ops/control_flow_ops.py
tensorflow/python/ops/control_flow_ops_test.py

index 1278768..e56ab93 100644 (file)
@@ -3181,12 +3181,18 @@ def while_loop(cond,
         body = lambda i, lv: (i + 1, orig_body(*lv))
 
     if context.executing_eagerly():
+      try_to_pack = len(loop_vars) == 1
+      packed = False  # whether the body result was packed into a 1-item tuple
+
       while cond(*loop_vars):
         loop_vars = body(*loop_vars)
+        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
+          packed = True
+          loop_vars = (loop_vars,)
       if maximum_iterations is not None:
         return loop_vars[1]
       else:
-        return loop_vars
+        return loop_vars[0] if packed else loop_vars
 
     if shape_invariants is not None:
       if maximum_iterations is not None:
index f22f305..289df6f 100644 (file)
@@ -947,5 +947,28 @@ class CaseTest(test_util.TensorFlowTestCase):
         sess.run(output, feed_dict={x: 4})
 
 
+@test_util.with_c_api
+class WhileLoopTestCase(test_util.TensorFlowTestCase):
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testWhileLoopWithSingleVariable(self):
+    i = constant_op.constant(0)
+    c = lambda i: math_ops.less(i, 10)
+    b = lambda i: math_ops.add(i, 1)
+    r = control_flow_ops.while_loop(c, b, [i])
+
+    self.assertEqual(self.evaluate(r), 10)
+
+  @test_util.run_in_graph_and_eager_modes()
+  def testEagerWhileLoopWithSingleVariable_bodyReturnsTuple(self):
+    i = constant_op.constant(0)
+    c = lambda i: math_ops.less(i, 10)
+    b = lambda i: (math_ops.add(i, 1),)
+    r = control_flow_ops.while_loop(c, b, [i])
+
+    # Expect a tuple since that is what the body returns.
+    self.assertEqual(self.evaluate(r), (10,))
+
+
 if __name__ == "__main__":
   googletest.main()