from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
+def dynamic_list_append(target, element):
+ """Converts a list append call inline."""
+ if isinstance(target, tensor_array_ops.TensorArray):
+ return target.write(target.size(), element)
+ # TODO(mdan): What's the right way to check this?
+ # TODO(mdan): We may not need this branch.
+ # It may be possible to use TensorList alone if the loop body will not
+ # require wrapping it, although we'd have to think about an autoboxing
+ # mechanism for lists received as parameter.
+ if isinstance(target, ops.Tensor):
+ return list_ops.tensor_list_push_back(target, element)
+
+ # Python targets (including TensorList): fallback to their original append.
+ target.append(element)
+ return target
class TensorList(object):
from tensorflow.contrib.py2tf.utils import tensor_list as tl
from tensorflow.python.client.session import Session
from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework.constant_op import constant
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
class TensorListTest(test.TestCase):
+ def _shape(self, shape_tuple):
+ return constant(shape_tuple, dtypes.int32)
+
+ def test_dynamic_list_append(self):
+ l = []
+ l = tl.dynamic_list_append(l, 1)
+ self.assertListEqual(l, [1])
+
+ l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
+ l = tl.dynamic_list_append(l, 1)
+ s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(s), [1])
+
+ l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
+ l = tl.dynamic_list_append(l, 1)
+ s = l.stack()
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(s), [1])
+
+ l = tl.TensorList(self._shape(()), dtypes.int32)
+ l = tl.dynamic_list_append(l, 1)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(l[0]), 1)
+
def test_list_append_python(self):
with context.eager_mode():
a = constant(3.0)