From d3fb437da12fc326d8229bdb955580c63eaccb5f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 16 Apr 2018 12:09:33 -0700 Subject: [PATCH] Copy the if statement handlers over to the operators module. They will enabled in a follow-up CL. PiperOrigin-RevId: 193078348 --- .../contrib/autograph/operators/control_flow.py | 32 ++++++++++++++++++++++ .../autograph/operators/control_flow_test.py | 29 ++++++++++++++++---- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py index 81ae64f..d9d8b0d 100644 --- a/tensorflow/contrib/autograph/operators/control_flow.py +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -25,6 +25,9 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_math_ops +# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature. +# TODO(mdan): Rename arguments to match the AST names. + def for_loop(iterated, extra_cond, loop_body, init_state): """Functional form of a for statement. @@ -182,3 +185,32 @@ def _py_while_loop(loop_cond, loop_body, init_state, opts): while loop_cond(*state): state = loop_body(*state) return state + + +def if_stmt(cond, body, orelse): + """Functional form of an if statement. + + Args: + cond: Boolean. + body: Callable with no arguments, and outputs of the positive (if) branch + as return type. + orelse: Callable with no arguments, and outputs of the negative (else) + branch as return type. + + Returns: + Tuple containing the statement outputs. + """ + if tensor_util.is_tensor(cond): + return _tf_if_stmt(cond, body, orelse) + else: + return _py_if_stmt(cond, body, orelse) + + +def _tf_if_stmt(cond, body, orelse): + """Overload of if_stmt that stages a TF cond.""" + return control_flow_ops.cond(cond, body, orelse) + + +def _py_if_stmt(cond, body, orelse): + """Overload of if_stmt that executes a Python if statement.""" + return body() if cond else orelse() diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py index 9112b16..a0cd0bf 100644 --- a/tensorflow/contrib/autograph/operators/control_flow_test.py +++ b/tensorflow/contrib/autograph/operators/control_flow_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.autograph import operators +from tensorflow.contrib.autograph.operators import control_flow from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -29,7 +29,7 @@ from tensorflow.python.platform import test class ForLoopTest(test.TestCase): def test_tensor(self): - s = operators.for_loop( + s = control_flow.for_loop( constant_op.constant([1, 2, 3, 4]), extra_cond=lambda s: True, loop_body=lambda i, s: (s + i,), @@ -38,7 +38,7 @@ class ForLoopTest(test.TestCase): self.assertEqual((10,), sess.run(s)) def test_python(self): - s = operators.for_loop( + s = control_flow.for_loop( range(5), extra_cond=lambda s: True, loop_body=lambda i, s: (s + i,), @@ -47,7 +47,7 @@ class ForLoopTest(test.TestCase): def test_dataset(self): to_int32 = lambda i: math_ops.cast(i, dtypes.int32) - s = operators.for_loop( + s = control_flow.for_loop( dataset_ops.Dataset.range(5).map(to_int32), extra_cond=lambda s: True, loop_body=lambda i, s: (s + i,), @@ -60,7 +60,7 @@ class WhileLoopTest(test.TestCase): def test_tensor(self): n = constant_op.constant(5) - results = operators.while_loop( + results = control_flow.while_loop( loop_cond=lambda i, s: i < n, loop_body=lambda i, s: (i + 1, s + i,), init_state=(0, 0), @@ -70,7 +70,7 @@ class WhileLoopTest(test.TestCase): def test_python(self): n = 5 - results = operators.while_loop( + results = control_flow.while_loop( loop_cond=lambda i, s: i < n, loop_body=lambda i, s: (i + 1, s + i), init_state=(0, 0), @@ -78,5 +78,22 @@ class WhileLoopTest(test.TestCase): self.assertEqual((5, 10), results) +class IfStmtTest(test.TestCase): + + def test_tensor(self): + def test_if_stmt(cond): + return control_flow.if_stmt( + cond=cond, + body=lambda: 1, + orelse=lambda: -1) + with self.test_session() as sess: + self.assertEqual(1, sess.run(test_if_stmt(constant_op.constant(True)))) + self.assertEqual(-1, sess.run(test_if_stmt(constant_op.constant(False)))) + + def test_python(self): + self.assertEqual(1, control_flow.if_stmt(True, lambda: 1, lambda: -1)) + self.assertEqual(-1, control_flow.if_stmt(False, lambda: 1, lambda: -1)) + + if __name__ == '__main__': test.main() -- 2.7.4