From 41982886efaa2ab9cc75d0d5ab6c27368468d061 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 26 Mar 2018 19:30:26 -0700 Subject: [PATCH] Fix inconsistency in run_cond. PiperOrigin-RevId: 190563114 --- tensorflow/contrib/autograph/converters/ifexp.py | 2 +- tensorflow/contrib/autograph/utils/multiple_dispatch.py | 11 +++++++++-- .../contrib/autograph/utils/multiple_dispatch_test.py | 17 ++++++++--------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/ifexp.py b/tensorflow/contrib/autograph/converters/ifexp.py index aff94d2..bb0c0a3 100644 --- a/tensorflow/contrib/autograph/converters/ifexp.py +++ b/tensorflow/contrib/autograph/converters/ifexp.py @@ -27,7 +27,7 @@ class IfExp(transformer.Base): def visit_IfExp(self, node): template = """ - autograph_utils.run_cond(test, lambda: body, lambda: orelse) + autograph_utils.run_cond(test, lambda: (body,), lambda: (orelse,)) """ desugared_ifexp = templates.replace_as_expression( template, test=node.test, body=node.body, orelse=node.orelse) diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch.py b/tensorflow/contrib/autograph/utils/multiple_dispatch.py index b756ccf..4704925 100644 --- a/tensorflow/contrib/autograph/utils/multiple_dispatch.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch.py @@ -55,10 +55,17 @@ def run_cond(condition, true_fn, false_fn): def py_cond(condition, true_fn, false_fn): + """Functional version of Python's conditional.""" if condition: - return true_fn() + results = true_fn() else: - return false_fn() + results = false_fn() + + # The contract for the branch functions is to return tuples, but they should + # be collapsed to a single element when there is only one output. + if len(results) == 1: + return results[0] + return results def run_while(cond_fn, body_fn, init_args): diff --git a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py index 8c7daa6..e6a41bb 100644 --- a/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py +++ b/tensorflow/contrib/autograph/utils/multiple_dispatch_test.py @@ -56,20 +56,19 @@ class MultipleDispatchTest(test.TestCase): self.assertFalse(should_be_false2) def test_run_cond_python(self): - true_fn = lambda: 2.0 - false_fn = lambda: 3.0 - self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2.0) - self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3.0) + true_fn = lambda: (2,) + false_fn = lambda: (3,) + self.assertEqual(multiple_dispatch.run_cond(True, true_fn, false_fn), 2) + self.assertEqual(multiple_dispatch.run_cond(False, true_fn, false_fn), 3) def test_run_cond_tf(self): - - true_fn = lambda: constant([2.0]) - false_fn = lambda: constant([3.0]) + true_fn = lambda: (constant(2),) + false_fn = lambda: (constant(3),) with Session() as sess: out = multiple_dispatch.run_cond(constant(True), true_fn, false_fn) - self.assertEqual(sess.run(out), 2.0) + self.assertEqual(sess.run(out), 2) out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn) - self.assertEqual(sess.run(out), 3.0) + self.assertEqual(sess.run(out), 3) def test_run_while_python(self): cond_fn = lambda x, t, s: x > t -- 2.7.4