Extend tensor_list with basic support for appending to TensorArrays. This allows...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 6 Mar 2018 23:50:13 +0000 (15:50 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Mar 2018 23:54:15 +0000 (15:54 -0800)
PiperOrigin-RevId: 188094077

tensorflow/contrib/py2tf/utils/tensor_list.py
tensorflow/contrib/py2tf/utils/tensor_list_test.py

index b6ff49e..2556f41 100644 (file)
@@ -18,7 +18,26 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import ops
 from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
+def dynamic_list_append(target, element):
+  """Converts a list append call inline."""
+  if isinstance(target, tensor_array_ops.TensorArray):
+    return target.write(target.size(), element)
+  # TODO(mdan): What's the right way to check this?
+  # TODO(mdan): We may not need this branch.
+  # It may be possible to use TensorList alone if the loop body will not
+  # require wrapping it, although we'd have to think about an autoboxing
+  # mechanism for lists received as parameter.
+  if isinstance(target, ops.Tensor):
+    return list_ops.tensor_list_push_back(target, element)
+
+  # Python targets (including TensorList): fallback to their original append.
+  target.append(element)
+  return target
 
 
 class TensorList(object):
index b5e554a..110e4d1 100644 (file)
@@ -21,13 +21,41 @@ from __future__ import print_function
 from tensorflow.contrib.py2tf.utils import tensor_list as tl
 from tensorflow.python.client.session import Session
 from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework.constant_op import constant
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.platform import test
 
 
 class TensorListTest(test.TestCase):
 
+  def _shape(self, shape_tuple):
+    return constant(shape_tuple, dtypes.int32)
+
+  def test_dynamic_list_append(self):
+    l = []
+    l = tl.dynamic_list_append(l, 1)
+    self.assertListEqual(l, [1])
+
+    l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
+    l = tl.dynamic_list_append(l, 1)
+    s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+    with self.test_session() as sess:
+      self.assertAllEqual(sess.run(s), [1])
+
+    l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
+    l = tl.dynamic_list_append(l, 1)
+    s = l.stack()
+    with self.test_session() as sess:
+      self.assertAllEqual(sess.run(s), [1])
+
+    l = tl.TensorList(self._shape(()), dtypes.int32)
+    l = tl.dynamic_list_append(l, 1)
+    with self.test_session() as sess:
+      self.assertAllEqual(sess.run(l[0]), 1)
+
   def test_list_append_python(self):
     with context.eager_mode():
       a = constant(3.0)