Add for and while loops to the list of operators. Do not use them yet.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Apr 2018 22:25:32 +0000 (15:25 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 5 Apr 2018 22:30:04 +0000 (15:30 -0700)
PiperOrigin-RevId: 191807973

tensorflow/contrib/autograph/operators/BUILD
tensorflow/contrib/autograph/operators/__init__.py
tensorflow/contrib/autograph/operators/control_flow.py [new file with mode: 0644]
tensorflow/contrib/autograph/operators/control_flow_test.py [new file with mode: 0644]

index 7856c25..4c62468 100644 (file)
@@ -2,6 +2,8 @@ licenses(["notice"])  # Apache 2.0
 
 exports_files(["LICENSE"])
 
+load("//tensorflow:tensorflow.bzl", "py_test")
+
 filegroup(
     name = "all_files",
     srcs = glob(
@@ -18,8 +20,21 @@ py_library(
     name = "operators",
     srcs = [
         "__init__.py",
+        "control_flow.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
-    deps = [],
+    deps = [
+        "//tensorflow/contrib/autograph/utils",
+    ],
+)
+
+py_test(
+    name = "control_flow_test",
+    srcs = ["control_flow_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":operators",
+        "//tensorflow/python:client_testlib",
+    ],
 )
index c3f4cab..04b4734 100644 (file)
@@ -22,3 +22,8 @@ closures for the body.
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+
+# TODO(mdan): Add a container for implementation-specific toggles (throughout).
+
+from tensorflow.contrib.autograph.operators.control_flow import for_loop
+from tensorflow.contrib.autograph.operators.control_flow import while_loop
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
new file mode 100644 (file)
index 0000000..5b8cb2d
--- /dev/null
@@ -0,0 +1,179 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow statements: loops, conditionals, etc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+
+
+def for_loop(iterated, extra_cond, loop_body, init_state):
+  """Functional form of a for statement.
+
+  The loop operates on a so-called state, which includes all symbols that are
+  variant across loop iterations, excluding the iterate. In what follows we
+  refer to state as either a tuple of entities that represent an actual state,
+  or a list of arguments of the corresponding types.
+
+  Args:
+    iterated: The entity being iterated over.
+    extra_cond: Callable with the state as arguments, and boolean return type.
+        An additionnal loop condition.
+    loop_body: Callable with the iterate and the state as arguments, and
+        state as return type. The actual loop body.
+    init_state: Tuple containing the initial state.
+
+  Returns:
+    Tuple containing the final state.
+  """
+  if tensor_util.is_tensor(iterated):
+    return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
+  elif isinstance(iterated, dataset_ops.Dataset):
+    return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
+  else:
+    return _py_for_loop(iterated, extra_cond, loop_body, init_state)
+
+
+def _py_for_loop(iterated, extra_cond, loop_body, init_state):
+  """Overload of for_loop that executes a Python for loop."""
+  state = init_state
+  for iterate in iterated:
+    if not extra_cond(*state):
+      break
+    state = loop_body(iterate, *state)
+
+  # TODO(mdan): Remove this special case.
+  if len(state) == 1:
+    return state[0]
+  return state
+
+
+def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
+  """Overload of for_loop that iterates over objects that define a length."""
+  n = builtins.dynamic_len(iterated)
+
+  def while_body(iterate_index, *state):
+    iterate = iterated[iterate_index]
+    new_state = loop_body(iterate, *state)
+    return (iterate_index + 1,) + new_state
+
+  def while_cond(iterate_index, *state):
+    return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state))
+
+  results = while_loop(
+      while_cond,
+      while_body,
+      init_state=(0,) + init_state,
+      extra_deps=(iterated,))
+  # Dropping the iteration index because it's not syntactically visible.
+  results = results[1:]
+
+  # TODO(mdan): Remove this special case.
+  if len(results) == 1:
+    return results[0]
+  return results
+
+
+def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
+  """Overload of for_loop that iterates over TF Datasets."""
+  # Because Datsets only expose get_next, in the style of Python iterators,
+  # we are forced to unpack the loop as:
+  #
+  # epoch_number, iterate = ds.get_next()
+  # while epoch_number < 2:
+  #   <body>
+  #   epoch_number, iterate = ds.get_next()
+  epoch_numbers = dataset_ops.Dataset.range(2)
+  def tag_with(ds, tag):
+    return dataset_ops.Dataset.zip(
+        (dataset_ops.Dataset.from_tensors(tag).repeat(), ds))
+  ds_with_epoch = epoch_numbers.flat_map(lambda i: tag_with(ds, i))
+
+  iterator = ds_with_epoch.make_initializable_iterator()
+  with ops.control_dependencies((iterator.initializer,)):
+    epoch_number, iterate = iterator.get_next()
+
+    def while_body(epoch_number, iterate, *state):
+      new_state = loop_body(iterate, *state)
+      epoch_number, iterate = iterator.get_next()
+      return (epoch_number, iterate) + new_state
+
+    def while_cond(epoch_number, iterate, *state):
+      del iterate
+      return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state))
+
+    results = while_loop(
+        while_cond,
+        while_body,
+        init_state=(epoch_number, iterate) + init_state,
+        extra_deps=())
+  # Dropping the epoch number and iterate because they are not not syntactically
+  # visible.
+  results = results[2:]
+
+  # TODO(mdan): Remove this special case.
+  if len(results) == 1:
+    return results[0]
+  return results
+
+
+def while_loop(loop_cond, loop_body, init_state, extra_deps):
+  """Functional form of a while statement.
+
+  The loop operates on a so-called state, which includes all symbols that are
+  variant across loop iterations. In what follows we refer to state as either
+  a tuple of entities that represent an actual state, or a list of arguments
+  of the corresponding types.
+
+  Args:
+    loop_cond: Callable with the state as arguments, and boolean return type.
+        The loop condition.
+    loop_body: Callable with the state as arguments, and state as return type.
+        The actual loop body.
+    init_state: Tuple containing the initial state.
+    extra_deps: Tuple containing additional entities on which the loop may
+        depend, such as loop invariants referenced by loop_cond. Used
+        exclusively for dispatch control.
+
+  Returns:
+    Tuple containing the final state.
+  """
+  # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch.
+  # That could be somethins as simple as a collection of dispatch rules, with
+  # some prioritization.
+  if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
+    return _tf_while_loop(loop_cond, loop_body, init_state)
+  else:
+    return _py_while_loop(loop_cond, loop_body, init_state)
+
+
+def _tf_while_loop(loop_cond, loop_body, init_state):
+  """Overload of while_loop that stages a TF while_loop."""
+  return control_flow_ops.while_loop(loop_cond, loop_body, init_state)
+
+
+def _py_while_loop(loop_cond, loop_body, init_state):
+  """Overload of while_loop that executes a Python while loop."""
+  state = init_state
+  while loop_cond(*state):
+    state = loop_body(*state)
+  return state
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
new file mode 100644 (file)
index 0000000..9112b16
--- /dev/null
@@ -0,0 +1,82 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph import operators
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ForLoopTest(test.TestCase):
+
+  def test_tensor(self):
+    s = operators.for_loop(
+        constant_op.constant([1, 2, 3, 4]),
+        extra_cond=lambda s: True,
+        loop_body=lambda i, s: (s + i,),
+        init_state=(0,))
+    with self.test_session() as sess:
+      self.assertEqual((10,), sess.run(s))
+
+  def test_python(self):
+    s = operators.for_loop(
+        range(5),
+        extra_cond=lambda s: True,
+        loop_body=lambda i, s: (s + i,),
+        init_state=(0,))
+    self.assertEqual(10, s)
+
+  def test_dataset(self):
+    to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
+    s = operators.for_loop(
+        dataset_ops.Dataset.range(5).map(to_int32),
+        extra_cond=lambda s: True,
+        loop_body=lambda i, s: (s + i,),
+        init_state=(0,))
+    with self.test_session() as sess:
+      self.assertEqual((10,), sess.run(s))
+
+
+class WhileLoopTest(test.TestCase):
+
+  def test_tensor(self):
+    n = constant_op.constant(5)
+    results = operators.while_loop(
+        loop_cond=lambda i, s: i < n,
+        loop_body=lambda i, s: (i + 1, s + i,),
+        init_state=(0, 0),
+        extra_deps=(n,))
+    with self.test_session() as sess:
+      self.assertEqual((5, 10), sess.run(results))
+
+  def test_python(self):
+    n = 5
+    results = operators.while_loop(
+        loop_cond=lambda i, s: i < n,
+        loop_body=lambda i, s: (i + 1, s + i),
+        init_state=(0, 0),
+        extra_deps=(n,))
+    self.assertEqual((5, 10), results)
+
+
+if __name__ == '__main__':
+  test.main()