from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
[[1]])
+ def testGraphStackInLoop(self):
+ with context.graph_mode(), self.test_session():
+ t1 = list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([], dtype=dtypes.int32),
+ element_dtype=dtypes.int32)
+ i = constant_op.constant(0, dtype=dtypes.int32)
+
+ def body(i, t1):
+ t1 = list_ops.tensor_list_push_back(t1, i)
+ i += 1
+ return i, t1
+
+ i, t1 = control_flow_ops.while_loop(lambda i, t1: math_ops.less(i, 4),
+ body, [i, t1])
+ s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32).eval()
+ self.assertAllEqual(s1, [0, 1, 2, 3])
+
+ def testGraphStackInLoopSwitchDtype(self):
+ with context.graph_mode(), self.test_session():
+ t1 = list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([], dtype=dtypes.int32),
+ element_dtype=dtypes.int32)
+ i = constant_op.constant(0, dtype=dtypes.float32)
+ m = constant_op.constant([1, 2, 3], dtype=dtypes.float32)
+
+ def body(i, m, t1):
+ t1 = control_flow_ops.cond(
+ math_ops.equal(list_ops.tensor_list_length(t1), 0),
+ lambda: list_ops.empty_tensor_list(m.shape, m.dtype), lambda: t1)
+
+ t1 = list_ops.tensor_list_push_back(t1, m * i)
+ i += 1.0
+ return i, m, t1
+
+ i, m, t1 = control_flow_ops.while_loop(
+ lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1])
+ s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32).eval()
+ np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
+ self.assertAllEqual(s1, np_s1)
+
def testSerialize(self):
# pylint: disable=g-import-not-at-top
try: