Swap in the new implementation of while and for loops.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 6 Apr 2018 03:14:42 +0000 (20:14 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 03:16:56 +0000 (20:16 -0700)
PiperOrigin-RevId: 191838806

13 files changed:
tensorflow/contrib/autograph/converters/BUILD
tensorflow/contrib/autograph/converters/break_statements.py
tensorflow/contrib/autograph/converters/control_flow.py
tensorflow/contrib/autograph/converters/control_flow_test.py
tensorflow/contrib/autograph/converters/converter_test_base.py
tensorflow/contrib/autograph/converters/for_loops.py [deleted file]
tensorflow/contrib/autograph/converters/for_loops_test.py [deleted file]
tensorflow/contrib/autograph/impl/api_test.py
tensorflow/contrib/autograph/impl/conversion.py
tensorflow/contrib/autograph/utils/__init__.py
tensorflow/contrib/autograph/utils/builtins.py
tensorflow/contrib/autograph/utils/multiple_dispatch.py
tensorflow/contrib/autograph/utils/multiple_dispatch_test.py

index 92cca30..8f9bffa 100644 (file)
@@ -24,7 +24,6 @@ py_library(
         "continue_statements.py",
         "control_flow.py",
         "decorators.py",
-        "for_loops.py",
         "ifexp.py",
         "list_comprehension.py",
         "lists.py",
@@ -134,16 +133,6 @@ py_test(
 )
 
 py_test(
-    name = "for_loops_test",
-    srcs = ["for_loops_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":test_lib",
-        "//tensorflow/python:client_testlib",
-    ],
-)
-
-py_test(
     name = "name_scopes_test",
     srcs = ["name_scopes_test.py"],
     deps = [
index 48026bc..62115d4 100644 (file)
@@ -32,6 +32,7 @@ class BreakCanonicalizationTransformer(transformer.Base):
   def __init__(self, context):
     super(BreakCanonicalizationTransformer, self).__init__(context)
     # This is a stack structure, to correctly process nested loops.
+    # Each item is a list [break_used, break_variable_name]
     self.break_uses = []
 
   def _create_break_check(self):
@@ -99,9 +100,9 @@ class BreakCanonicalizationTransformer(transformer.Base):
     self.break_uses.append([False, break_var])
     node.body = self._manual_visit_list(node.body)
     if self.break_uses[-1][0]:
-      anno.setanno(node, 'extra_cond',
-                   gast.UnaryOp(gast.Not(),
-                                gast.Name(break_var, gast.Load(), None)))
+      extra_cond = templates.replace_as_expression(
+          'not var_name', var_name=break_var)
+      anno.setanno(node, 'extra_cond', extra_cond)
       final_nodes = [self._create_break_init(), node]
     else:
       final_nodes = node
index 49d9320..55a28e8 100644 (file)
@@ -22,6 +22,7 @@ import gast
 
 from tensorflow.contrib.autograph.pyct import anno
 from tensorflow.contrib.autograph.pyct import ast_util
+from tensorflow.contrib.autograph.pyct import parser
 from tensorflow.contrib.autograph.pyct import templates
 from tensorflow.contrib.autograph.pyct import transformer
 from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
@@ -49,11 +50,6 @@ class ControlFlowTransformer(transformer.Base):
   def __init__(self, context):
     super(ControlFlowTransformer, self).__init__(context)
 
-  # pylint:disable=invalid-name
-
-  def visit_For(self, node):
-    assert False, 'for statement should have been canonicalized at this point'
-
   def _create_cond_branch(self, body_name, aliased_orig_names,
                           aliased_new_names, body, returns):
     if aliased_orig_names:
@@ -170,6 +166,13 @@ class ControlFlowTransformer(transformer.Base):
     body_closure = body_scope.modified - body_scope.created
     all_referenced = body_scope.referenced
 
+    cond_scope = anno.getanno(node, NodeAnno.COND_SCOPE)
+    cond_closure = set()
+    for s in cond_scope.referenced:
+      for root in s.support_set:
+        if root not in body_scope.created:
+          cond_closure.add(root)
+
     state = list(body_closure)
     if not state:
       # TODO(mdan): Implement this properly.
@@ -204,7 +207,8 @@ class ControlFlowTransformer(transformer.Base):
       def body_name(state_ssf):
         body
         return state_ssf,
-      state_ast_tuple = autograph_utils.run_while(test_name, body_name, [state])
+      state_ast_tuple = __ops.while_loop(
+          test_name, body_name, (state,), (extra_deps,))
     """
     node = templates.replace(
         template,
@@ -216,11 +220,67 @@ class ControlFlowTransformer(transformer.Base):
         test=test,
         body_name=self.context.namer.new_symbol('loop_body',
                                                 body_scope.referenced),
-        body=node_body)
+        body=node_body,
+        extra_deps=tuple(s.ast() for s in cond_closure),
+    )
 
     return node
 
-  # pylint:enable=invalid-name
+  def visit_For(self, node):
+    self.generic_visit(node)
+
+    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
+    body_closure = body_scope.modified - body_scope.created
+    all_referenced = body_scope.referenced
+
+    state = list(body_closure)
+
+    state_ssf = [
+        self.context.namer.new_symbol(s.ssf(), all_referenced) for s in state
+    ]
+    ssf_map = {
+        name: ssf
+        for name, ssf in zip(state, state_ssf)
+        if str(name) != ssf
+    }
+
+    if len(state) == 1:
+      state = state[0]
+      state_ssf = state_ssf[0]
+      state_ast_tuple = state
+    else:
+      state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
+
+    node_body = ast_util.rename_symbols(node.body, ssf_map)
+    if anno.hasanno(node, 'extra_cond'):
+      extra_cond = anno.getanno(node, 'extra_cond')
+      extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
+    else:
+      extra_cond = parser.parse_expression('True')
+
+    template = """
+      def extra_cond_name(state_ssf):
+        return extra_cond_expr
+      def body_name(iterate, state_ssf):
+        body
+        return state_ssf,
+      state_ast_tuple = __ops.for_loop(
+          iterated, extra_cond_name, body_name, (state,))
+    """
+    node = templates.replace(
+        template,
+        state=state,
+        state_ssf=state_ssf,
+        state_ast_tuple=state_ast_tuple,
+        iterated=node.iter,
+        iterate=node.target,
+        extra_cond_name=self.context.namer.new_symbol('extra_cond',
+                                                      all_referenced),
+        extra_cond_expr=extra_cond,
+        body_name=self.context.namer.new_symbol('loop_body', all_referenced),
+        body=node_body)
+
+    return node
 
 
 def transform(node, context):
index 86fed51..c5610b1 100644 (file)
@@ -21,6 +21,7 @@ from __future__ import print_function
 from tensorflow.contrib.autograph.converters import control_flow
 from tensorflow.contrib.autograph.converters import converter_test_base
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.platform import test
 
@@ -94,6 +95,77 @@ class ControlFlowTest(converter_test_base.TestCase):
       with self.test_session() as sess:
         self.assertEqual(-1, sess.run(result.test_fn(constant_op.constant(1))))
 
+  def test_simple_for(self):
+
+    def test_fn(l):
+      s1 = 0
+      s2 = 0
+      for e in l:
+        s1 += e
+        s2 += e * e
+      return s1, s2
+
+    node = self.parse_and_analyze(test_fn, {})
+    node = control_flow.transform(node, self.ctx)
+
+    with self.compiled(node) as result:
+      with self.test_session() as sess:
+        l = [1, 2, 3]
+        self.assertEqual(
+            test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
+        l = []
+        self.assertEqual(
+            test_fn(l),
+            sess.run(
+                result.test_fn(
+                    constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
+
+  def test_for_single_var(self):
+
+    def test_fn(l):
+      s = 0
+      for e in l:
+        s += e
+      return s
+
+    node = self.parse_and_analyze(test_fn, {})
+    node = control_flow.transform(node, self.ctx)
+
+    with self.compiled(node) as result:
+      with self.test_session() as sess:
+        l = [1, 2, 3]
+        self.assertEqual(
+            test_fn(l), sess.run(result.test_fn(constant_op.constant(l))))
+        l = []
+        self.assertEqual(
+            test_fn(l),
+            sess.run(
+                result.test_fn(
+                    constant_op.constant(l, shape=(0,), dtype=dtypes.int32))))
+
+  def test_for_with_iterated_expression(self):
+
+    eval_count = [0]
+
+    def count_evals(x):
+      eval_count[0] += 1
+      return x
+
+    def test_fn(n):
+      s = 0
+      for e in count_evals(range(n)):
+        s += e
+      return s
+
+    node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
+    node = control_flow.transform(node, self.ctx)
+
+    with self.compiled(node) as result:
+      result.count_evals = count_evals
+      self.assertEqual(test_fn(5), result.test_fn(5))
+      # count_evals ran twice, once for test_fn and another for result.test_fn
+      self.assertEqual(eval_count[0], 2)
+
 
 if __name__ == '__main__':
   test.main()
index 984e72c..6f75e9a 100644 (file)
@@ -78,7 +78,7 @@ class TestCase(test.TestCase):
       result.tf = self.make_fake_mod('fake_tf', *symbols)
       result.autograph_utils = utils
       result.autograph_api = self.make_fake_mod('fake_api', converted_call)
-      result.__ops = operators  # pylint:disable=protected-access
+      result.__dict__['__ops'] = operators
       yield result
     except Exception:  # pylint:disable=broad-except
       if source is None:
diff --git a/tensorflow/contrib/autograph/converters/for_loops.py b/tensorflow/contrib/autograph/converters/for_loops.py
deleted file mode 100644 (file)
index 4999c47..0000000
+++ /dev/null
@@ -1,92 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Canonicalizes for loops into while loops.
-
-This canonicalizer uses the len function on its argument. That should be
-converted to a tf.shape separately.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.pyct import anno
-from tensorflow.contrib.autograph.pyct import templates
-from tensorflow.contrib.autograph.pyct import transformer
-from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
-
-
-class ForLoopCanonicalizationTransformer(transformer.Base):
-  """Canonicalizes for loops (e.g. into while loops)."""
-
-  def __init__(self, context):
-    super(ForLoopCanonicalizationTransformer, self).__init__(context)
-
-  def visit_For(self, node):
-    self.generic_visit(node)
-    body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
-    i_var = self.context.namer.new_symbol('i', body_scope.referenced)
-    smart_loop_iter_var = self.context.namer.new_symbol('smart_loop_iter',
-                                                        body_scope.referenced)
-    cont_var = self.context.namer.new_symbol('cont', body_scope.referenced)
-    # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
-    if anno.hasanno(node, 'extra_cond'):
-      template = """
-        i = 0
-        smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter)
-        cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
-        while cont and extra_cond:
-          body
-          i += 1
-          cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
-      """
-      return templates.replace(
-          template,
-          loop_iter=node.iter,
-          target=node.target,
-          body=node.body,
-          i=i_var,
-          smart_loop_iter=smart_loop_iter_var,
-          cont=cont_var,
-          extra_cond=anno.getanno(node, 'extra_cond'))
-    else:
-      template = """
-        i = 0
-        smart_loop_iter = autograph_utils.dynamic_dataset(loop_iter)
-        cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
-        while cont:
-          body
-          i += 1
-          cont, target = autograph_utils.dynamic_for_cond(i, smart_loop_iter)
-      """
-      repl = templates.replace(
-          template,
-          loop_iter=node.iter,
-          target=node.target,
-          body=node.body,
-          i=i_var,
-          smart_loop_iter=smart_loop_iter_var,
-          cont=cont_var)
-      return repl
-
-  def visit_Continue(self, node):
-    assert False, 'continue statement should be desugared at this point'
-
-  def visit_Break(self, node):
-    assert False, 'break statement should be desugared at this point'
-
-
-def transform(node, context):
-  return ForLoopCanonicalizationTransformer(context).visit(node)
diff --git a/tensorflow/contrib/autograph/converters/for_loops_test.py b/tensorflow/contrib/autograph/converters/for_loops_test.py
deleted file mode 100644 (file)
index 943f52d..0000000
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for for_loops module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.autograph.converters import converter_test_base
-from tensorflow.contrib.autograph.converters import for_loops
-from tensorflow.python.platform import test
-
-
-class ControlFlowTest(converter_test_base.TestCase):
-
-  def test_basic_for(self):
-
-    def test_fn(l):
-      s = 0
-      for e in l:
-        s += e
-      return s
-
-    node = self.parse_and_analyze(test_fn, {})
-    node = for_loops.transform(node, self.ctx)
-
-    with self.compiled(node) as result:
-      l = [1, 2, 3]
-      self.assertEqual(test_fn(l), result.test_fn(l))
-      l = []
-      self.assertEqual(test_fn(l), result.test_fn(l))
-
-  def test_for_with_iterated_expression(self):
-
-    eval_count = [0]
-
-    def count_evals(x):
-      eval_count[0] += 1
-      return x
-
-    def test_fn(n):
-      s = 0
-      for e in count_evals(range(n)):
-        s += e
-      return s
-
-    node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
-    node = for_loops.transform(node, self.ctx)
-
-    with self.compiled(node) as result:
-      result.count_evals = count_evals
-      self.assertEqual(test_fn(5), result.test_fn(5))
-      # count_evals ran twice, once for test_fn and another for result.test_fn
-      self.assertEqual(eval_count[0], 2)
-
-
-if __name__ == '__main__':
-  test.main()
index f156a87..f9db077 100644 (file)
@@ -201,8 +201,7 @@ class ApiTest(test.TestCase):
 
     compiled_code = api.to_code(test_fn)
 
-    # Just check for some key words and that it is parseable Python code.
-    self.assertRegexpMatches(compiled_code, 'autograph_utils\\.run_while')
+    # Just check that it is parseable Python code.
     self.assertIsNotNone(parser.parse_str(compiled_code))
 
 
index 62a49cd..3bacc94 100644 (file)
@@ -28,7 +28,6 @@ from tensorflow.contrib.autograph.converters import call_trees
 from tensorflow.contrib.autograph.converters import continue_statements
 from tensorflow.contrib.autograph.converters import control_flow
 from tensorflow.contrib.autograph.converters import decorators
-from tensorflow.contrib.autograph.converters import for_loops
 from tensorflow.contrib.autograph.converters import ifexp
 from tensorflow.contrib.autograph.converters import lists
 from tensorflow.contrib.autograph.converters import logical_expressions
@@ -324,8 +323,6 @@ def node_to_graph(node, ctx, nocompile_decorators):
 
   node = _static_analysis_pass(node, ctx)
   node = lists.transform(node, ctx)
-  node = for_loops.transform(node, ctx)
-  # for_loops may insert new global references.
   node = builtin_functions.transform(node, ctx)
 
   node = _static_analysis_pass(node, ctx)
index 22898b1..817d412 100644 (file)
@@ -19,8 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_dataset
-from tensorflow.contrib.autograph.utils.builtins import dynamic_for_cond
 from tensorflow.contrib.autograph.utils.builtins import dynamic_print
 from tensorflow.contrib.autograph.utils.builtins import dynamic_range
 from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
@@ -28,7 +26,6 @@ from tensorflow.contrib.autograph.utils.misc import alias_tensors
 from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
 from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is_not
 from tensorflow.contrib.autograph.utils.multiple_dispatch import run_cond
-from tensorflow.contrib.autograph.utils.multiple_dispatch import run_while
 from tensorflow.contrib.autograph.utils.py_func import wrap_py_func
 from tensorflow.contrib.autograph.utils.tensor_list import dynamic_list_append
 from tensorflow.contrib.autograph.utils.testing import fake_tf
index c6af0e4..7fbb7c0 100644 (file)
@@ -24,10 +24,8 @@ import six
 
 from tensorflow.contrib.autograph.utils import py_func
 from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.util import tf_inspect
@@ -106,69 +104,3 @@ def dynamic_print(*values):
 
   return py_func.wrap_py_func(
       flushed_print, None, values, use_dummy_return=True)
-
-
-def dynamic_dataset(iterated):
-  """Implementartion of smart tf.data.Dataset epoch wrapping.
-
-  The function checks if the input is a tf.data.Dataset and if so then wraps it
-  so that for each element it returns it also returns the current epoch the
-  dataset iteration is in, for two epochs.  If the input is not a
-  tf.data.Dataset then it just returns the input.
-
-  Args:
-    iterated: The iterable or tf.data.Dataset that is being iterated over.
-  Returns:
-    Either just the untouched input, or in the case of input being a
-    tf.data.Dataset then it returns a wrapped  tf.data.Dataset where for each
-    element it returns it also returns the current epoch the dataset iteration
-    is in.
-  """
-  if not isinstance(iterated, dataset_ops.Dataset):
-    return iterated
-
-  def epoch_dataset_number_helper(i):
-    return dataset_ops.Dataset.zip(
-        (dataset_ops.Dataset.from_tensors(i).repeat(), iterated))
-
-  epoch_numbers = dataset_ops.Dataset.range(2)
-  return epoch_numbers.flat_map(epoch_dataset_number_helper)
-
-
-def dynamic_for_cond(iteration, iterated):
-  """Implementartion of smart while-loop condition using dynamic dispatch.
-
-  The function checks if it is iterating over a tf.data.Dataset or not, and in
-  the case it is not then it simply returns if we are still in range of the
-  iterated and the next element.  If it is iterating over a dataset then it only
-  iterates for a single epoch.
-
-  Args:
-    iteration: The current iteration of the loop.
-    iterated: The iterable or tf.data.Dataset that is being iterated over.
-  Returns:
-    A tuple of a bool that indicates whether the loop should continue, and the
-    next element in iterated.
-  """
-  # TODO(znado): Clean up.
-  # TODO(znado): This won't work for unpacked iterates. Fix.
-  if isinstance(iterated, dataset_ops.Dataset):
-    curr_epoch, next_elem = iterated.make_one_shot_iterator().get_next()
-    return math_ops.less(curr_epoch, 1), next_elem
-  elif tensor_util.is_tensor(iterated):
-    if iterated.shape.ndims > 1:
-      elem_shape = array_ops.shape(iterated)[1:]
-    else:
-      elem_shape = ()
-    if iterated.shape.ndims == 0 or iterated.shape[0] == 0:
-      return False, array_ops.zeros(elem_shape, iterated.dtype)
-    return control_flow_ops.cond(
-        math_ops.less(iteration, dynamic_len(iterated)),
-        lambda: (True, iterated[iteration]),
-        lambda: (False, array_ops.zeros(elem_shape, iterated.dtype)))
-  elif hasattr(iterated, '__len__'):
-    if iteration < len(iterated):
-      return True, iterated[iteration]
-    return False, None
-  else:
-    raise NotImplementedError('Python iterators not yet supported.')
index 4704925..70eef56 100644 (file)
@@ -18,8 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import six
-
 from tensorflow.contrib.autograph.utils.type_check import is_tensor
 from tensorflow.python.ops import control_flow_ops
 
@@ -66,42 +64,3 @@ def py_cond(condition, true_fn, false_fn):
   if len(results) == 1:
     return results[0]
   return results
-
-
-def run_while(cond_fn, body_fn, init_args):
-  """Type-dependent functional while loop.
-
-  Args:
-    cond_fn: A Python callable implementing the stop conditions of the loop.
-    body_fn: A Python callable implementing the body of the loop.
-    init_args: The initial values of the arguments that will be passed to both
-      cond_fn and body_fn.
-
-  Returns:
-    result: A list of values with the same shape and type as init_args. If any
-    of the init_args, or any variables closed-over in cond_fn are Tensors,
-    tf.while_loop will be used, otherwise a Python while loop will be ran.
-
-  Raises:
-    ValueError: if init_args is not a tuple or list with one or more elements.
-  """
-  if not isinstance(init_args, (tuple, list)) or not init_args:
-    raise ValueError(
-        'init_args must be a non-empty list or tuple, found %s' % init_args)
-
-  # TODO(alexbw): statically determine all active variables in cond_fn,
-  # and pass them directly
-  closure_vars = tuple(
-      [c.cell_contents for c in six.get_function_closure(cond_fn) or []])
-  possibly_tensors = tuple(init_args) + closure_vars
-  if is_tensor(*possibly_tensors):
-    return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
-  else:
-    return py_while_loop(cond_fn, body_fn, init_args)
-
-
-def py_while_loop(cond_fn, body_fn, init_args):
-  state = init_args
-  while cond_fn(*state):
-    state = body_fn(*state)
-  return state
index e6a41bb..f72f8e9 100644 (file)
@@ -70,29 +70,6 @@ class MultipleDispatchTest(test.TestCase):
       out = multiple_dispatch.run_cond(constant(False), true_fn, false_fn)
       self.assertEqual(sess.run(out), 3)
 
-  def test_run_while_python(self):
-    cond_fn = lambda x, t, s: x > t
-    body_fn = lambda x, t, s: (x * s, t, s)
-
-    x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 1.0, 0.5])
-    self.assertEqual(x, 0.75)
-
-    x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn, [3.0, 4.0, 0.5])
-    self.assertEqual(x, 3.0)
-
-  def test_run_while_tf(self):
-    cond_fn = lambda x, t, s: x > t
-    body_fn = lambda x, t, s: (x * s, t, s)
-
-    with Session() as sess:
-      x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn,
-                                            [constant(3.0), 1.0, 0.5])
-      self.assertEqual(sess.run(x), 0.75)
-
-      x, _, _ = multiple_dispatch.run_while(cond_fn, body_fn,
-                                            [constant(3.0), 4.0, 0.5])
-      self.assertEqual(sess.run(x), 3.0)
-
 
 if __name__ == '__main__':
   test.main()