From ea887d2d13a686990145b65e11701deae676b28b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 5 Apr 2018 15:25:32 -0700 Subject: [PATCH] Add for and while loops to the list of operators. Do not use them yet. PiperOrigin-RevId: 191807973 --- tensorflow/contrib/autograph/operators/BUILD | 17 +- tensorflow/contrib/autograph/operators/__init__.py | 5 + .../contrib/autograph/operators/control_flow.py | 179 +++++++++++++++++++++ .../autograph/operators/control_flow_test.py | 82 ++++++++++ 4 files changed, 282 insertions(+), 1 deletion(-) create mode 100644 tensorflow/contrib/autograph/operators/control_flow.py create mode 100644 tensorflow/contrib/autograph/operators/control_flow_test.py diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD index 7856c25..4c62468 100644 --- a/tensorflow/contrib/autograph/operators/BUILD +++ b/tensorflow/contrib/autograph/operators/BUILD @@ -2,6 +2,8 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load("//tensorflow:tensorflow.bzl", "py_test") + filegroup( name = "all_files", srcs = glob( @@ -18,8 +20,21 @@ py_library( name = "operators", srcs = [ "__init__.py", + "control_flow.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], - deps = [], + deps = [ + "//tensorflow/contrib/autograph/utils", + ], +) + +py_test( + name = "control_flow_test", + srcs = ["control_flow_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":operators", + "//tensorflow/python:client_testlib", + ], ) diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py index c3f4cab..04b4734 100644 --- a/tensorflow/contrib/autograph/operators/__init__.py +++ b/tensorflow/contrib/autograph/operators/__init__.py @@ -22,3 +22,8 @@ closures for the body. from __future__ import absolute_import from __future__ import division from __future__ import print_function + +# TODO(mdan): Add a container for implementation-specific toggles (throughout). + +from tensorflow.contrib.autograph.operators.control_flow import for_loop +from tensorflow.contrib.autograph.operators.control_flow import while_loop diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py new file mode 100644 index 0000000..5b8cb2d --- /dev/null +++ b/tensorflow/contrib/autograph/operators/control_flow.py @@ -0,0 +1,179 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Control flow statements: loops, conditionals, etc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph.utils import builtins +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_math_ops + + +def for_loop(iterated, extra_cond, loop_body, init_state): + """Functional form of a for statement. + + The loop operates on a so-called state, which includes all symbols that are + variant across loop iterations, excluding the iterate. In what follows we + refer to state as either a tuple of entities that represent an actual state, + or a list of arguments of the corresponding types. + + Args: + iterated: The entity being iterated over. + extra_cond: Callable with the state as arguments, and boolean return type. + An additionnal loop condition. + loop_body: Callable with the iterate and the state as arguments, and + state as return type. The actual loop body. + init_state: Tuple containing the initial state. + + Returns: + Tuple containing the final state. + """ + if tensor_util.is_tensor(iterated): + return _known_len_for_loop(iterated, extra_cond, loop_body, init_state) + elif isinstance(iterated, dataset_ops.Dataset): + return _dataset_for_loop(iterated, extra_cond, loop_body, init_state) + else: + return _py_for_loop(iterated, extra_cond, loop_body, init_state) + + +def _py_for_loop(iterated, extra_cond, loop_body, init_state): + """Overload of for_loop that executes a Python for loop.""" + state = init_state + for iterate in iterated: + if not extra_cond(*state): + break + state = loop_body(iterate, *state) + + # TODO(mdan): Remove this special case. + if len(state) == 1: + return state[0] + return state + + +def _known_len_for_loop(iterated, extra_cond, loop_body, init_state): + """Overload of for_loop that iterates over objects that define a length.""" + n = builtins.dynamic_len(iterated) + + def while_body(iterate_index, *state): + iterate = iterated[iterate_index] + new_state = loop_body(iterate, *state) + return (iterate_index + 1,) + new_state + + def while_cond(iterate_index, *state): + return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state)) + + results = while_loop( + while_cond, + while_body, + init_state=(0,) + init_state, + extra_deps=(iterated,)) + # Dropping the iteration index because it's not syntactically visible. + results = results[1:] + + # TODO(mdan): Remove this special case. + if len(results) == 1: + return results[0] + return results + + +def _dataset_for_loop(ds, extra_cond, loop_body, init_state): + """Overload of for_loop that iterates over TF Datasets.""" + # Because Datsets only expose get_next, in the style of Python iterators, + # we are forced to unpack the loop as: + # + # epoch_number, iterate = ds.get_next() + # while epoch_number < 2: + # + # epoch_number, iterate = ds.get_next() + epoch_numbers = dataset_ops.Dataset.range(2) + def tag_with(ds, tag): + return dataset_ops.Dataset.zip( + (dataset_ops.Dataset.from_tensors(tag).repeat(), ds)) + ds_with_epoch = epoch_numbers.flat_map(lambda i: tag_with(ds, i)) + + iterator = ds_with_epoch.make_initializable_iterator() + with ops.control_dependencies((iterator.initializer,)): + epoch_number, iterate = iterator.get_next() + + def while_body(epoch_number, iterate, *state): + new_state = loop_body(iterate, *state) + epoch_number, iterate = iterator.get_next() + return (epoch_number, iterate) + new_state + + def while_cond(epoch_number, iterate, *state): + del iterate + return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state)) + + results = while_loop( + while_cond, + while_body, + init_state=(epoch_number, iterate) + init_state, + extra_deps=()) + # Dropping the epoch number and iterate because they are not not syntactically + # visible. + results = results[2:] + + # TODO(mdan): Remove this special case. + if len(results) == 1: + return results[0] + return results + + +def while_loop(loop_cond, loop_body, init_state, extra_deps): + """Functional form of a while statement. + + The loop operates on a so-called state, which includes all symbols that are + variant across loop iterations. In what follows we refer to state as either + a tuple of entities that represent an actual state, or a list of arguments + of the corresponding types. + + Args: + loop_cond: Callable with the state as arguments, and boolean return type. + The loop condition. + loop_body: Callable with the state as arguments, and state as return type. + The actual loop body. + init_state: Tuple containing the initial state. + extra_deps: Tuple containing additional entities on which the loop may + depend, such as loop invariants referenced by loop_cond. Used + exclusively for dispatch control. + + Returns: + Tuple containing the final state. + """ + # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch. + # That could be somethins as simple as a collection of dispatch rules, with + # some prioritization. + if any(tensor_util.is_tensor(v) for v in init_state + extra_deps): + return _tf_while_loop(loop_cond, loop_body, init_state) + else: + return _py_while_loop(loop_cond, loop_body, init_state) + + +def _tf_while_loop(loop_cond, loop_body, init_state): + """Overload of while_loop that stages a TF while_loop.""" + return control_flow_ops.while_loop(loop_cond, loop_body, init_state) + + +def _py_while_loop(loop_cond, loop_body, init_state): + """Overload of while_loop that executes a Python while loop.""" + state = init_state + while loop_cond(*state): + state = loop_body(*state) + return state diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py new file mode 100644 index 0000000..9112b16 --- /dev/null +++ b/tensorflow/contrib/autograph/operators/control_flow_test.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for control_flow module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.autograph import operators +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ForLoopTest(test.TestCase): + + def test_tensor(self): + s = operators.for_loop( + constant_op.constant([1, 2, 3, 4]), + extra_cond=lambda s: True, + loop_body=lambda i, s: (s + i,), + init_state=(0,)) + with self.test_session() as sess: + self.assertEqual((10,), sess.run(s)) + + def test_python(self): + s = operators.for_loop( + range(5), + extra_cond=lambda s: True, + loop_body=lambda i, s: (s + i,), + init_state=(0,)) + self.assertEqual(10, s) + + def test_dataset(self): + to_int32 = lambda i: math_ops.cast(i, dtypes.int32) + s = operators.for_loop( + dataset_ops.Dataset.range(5).map(to_int32), + extra_cond=lambda s: True, + loop_body=lambda i, s: (s + i,), + init_state=(0,)) + with self.test_session() as sess: + self.assertEqual((10,), sess.run(s)) + + +class WhileLoopTest(test.TestCase): + + def test_tensor(self): + n = constant_op.constant(5) + results = operators.while_loop( + loop_cond=lambda i, s: i < n, + loop_body=lambda i, s: (i + 1, s + i,), + init_state=(0, 0), + extra_deps=(n,)) + with self.test_session() as sess: + self.assertEqual((5, 10), sess.run(results)) + + def test_python(self): + n = 5 + results = operators.while_loop( + loop_cond=lambda i, s: i < n, + loop_body=lambda i, s: (i + 1, s + i), + init_state=(0, 0), + extra_deps=(n,)) + self.assertEqual((5, 10), results) + + +if __name__ == '__main__': + test.main() -- 2.7.4