Test use of tensor_list inside loops, as well as changing the contained type of a...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Mar 2018 13:51:58 +0000 (06:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 13:56:19 +0000 (06:56 -0700)
PiperOrigin-RevId: 189021593

tensorflow/python/kernel_tests/list_ops_test.py

index 8040ea3..8865e16 100644 (file)
@@ -30,7 +30,9 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 
@@ -133,6 +135,46 @@ class ListOpsTest(test_util.TensorFlowTestCase):
           list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
           [[1]])
 
+  def testGraphStackInLoop(self):
+    with context.graph_mode(), self.test_session():
+      t1 = list_ops.empty_tensor_list(
+          element_shape=constant_op.constant([], dtype=dtypes.int32),
+          element_dtype=dtypes.int32)
+      i = constant_op.constant(0, dtype=dtypes.int32)
+
+      def body(i, t1):
+        t1 = list_ops.tensor_list_push_back(t1, i)
+        i += 1
+        return i, t1
+
+      i, t1 = control_flow_ops.while_loop(lambda i, t1: math_ops.less(i, 4),
+                                          body, [i, t1])
+      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32).eval()
+      self.assertAllEqual(s1, [0, 1, 2, 3])
+
+  def testGraphStackInLoopSwitchDtype(self):
+    with context.graph_mode(), self.test_session():
+      t1 = list_ops.empty_tensor_list(
+          element_shape=constant_op.constant([], dtype=dtypes.int32),
+          element_dtype=dtypes.int32)
+      i = constant_op.constant(0, dtype=dtypes.float32)
+      m = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
+
+      def body(i, m, t1):
+        t1 = control_flow_ops.cond(
+            math_ops.equal(list_ops.tensor_list_length(t1), 0),
+            lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: t1)
+
+        t1 = list_ops.tensor_list_push_back(t1, m * i)
+        i += 1.0
+        return i, m, t1
+
+      i, m, t1 = control_flow_ops.while_loop(
+          lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1])
+      s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32).eval()
+      np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
+      self.assertAllEqual(s1, np_s1)
+
   def testSerialize(self):
     # pylint: disable=g-import-not-at-top
     try: