Initial implementation of a few of the list-specific operators.
authorDan Moldovan <mdan@google.com>
Thu, 31 May 2018 17:33:53 +0000 (10:33 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 17:36:34 +0000 (10:36 -0700)
This introduces an abstraction for a dispatch context, which allows passing local information through to the specialized operators.

PiperOrigin-RevId: 198742074

tensorflow/contrib/autograph/operators/BUILD
tensorflow/contrib/autograph/operators/__init__.py
tensorflow/contrib/autograph/operators/data_structures.py
tensorflow/contrib/autograph/operators/data_structures_test.py
tensorflow/contrib/autograph/operators/slices.py [new file with mode: 0644]
tensorflow/contrib/autograph/operators/slices_test.py [new file with mode: 0644]

index 18bfec5..0c6ab65 100644 (file)
@@ -22,7 +22,7 @@ py_library(
         "__init__.py",
         "control_flow.py",
         "data_structures.py",
-        "dispatch_context.py",
+        "slices.py",
     ],
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:__subpackages__"],
@@ -52,3 +52,13 @@ py_test(
         "//tensorflow/python:client_testlib",
     ],
 )
+
+py_test(
+    name = "slices_test",
+    srcs = ["slices_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":operators",
+        "//tensorflow/python:client_testlib",
+    ],
+)
index 38b761d..c900fd6 100644 (file)
@@ -28,6 +28,10 @@ closures for the body.
 #    - the names used in the Python docs, if the operator is a function (e.g.
 #      list_ and x for append, see
 #      https://docs.python.org/3.7/tutorial/datastructures.html)
+#
+# All operators may accept a final argument named "opts", of a type that
+# subclasses namedtuple and contains any arguments that are only required
+# for some specializations of the operator.
 
 from __future__ import absolute_import
 from __future__ import division
@@ -35,3 +39,12 @@ from __future__ import print_function
 
 from tensorflow.contrib.autograph.operators.control_flow import for_stmt
 from tensorflow.contrib.autograph.operators.control_flow import while_stmt
+from tensorflow.contrib.autograph.operators.data_structures import list_append
+from tensorflow.contrib.autograph.operators.data_structures import list_pop
+from tensorflow.contrib.autograph.operators.data_structures import list_stack
+from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
+from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
+from tensorflow.contrib.autograph.operators.data_structures import new_list
+from tensorflow.contrib.autograph.operators.slices import get_item
+from tensorflow.contrib.autograph.operators.slices import GetItemOpts
+from tensorflow.contrib.autograph.operators.slices import set_item
index c862306..06d8727 100644 (file)
@@ -18,39 +18,250 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
+
+
+# TODO(mdan): Once control flow supports objects, repackage as a class.
+
+
+def new_list(iterable=None):
+  """The list constructor.
+
+  Args:
+    iterable: Optional elements to fill the list with.
+
+  Returns:
+    A list-like object. The exact return value depends on the initial elements.
+  """
+  if iterable:
+    elements = tuple(iterable)
+  else:
+    elements = ()
+
+  # TODO(mdan): Extend these criteria.
+  if any(isinstance(el, variables.Variable) for el in elements):
+    return _py_list_new(elements)
+  return _tf_tensor_list_new(elements)
 
-# TODO(mdan): Add support for TensorList once functional.
-# TODO(mdan): Add primitives for empty list, list with elements.
 
+def _tf_tensor_list_new(elements):
+  """Overload of new_list that stages a Tensor list creation."""
+  elements = tuple(ops.convert_to_tensor(el) for el in elements)
+  all_dtypes = set(el.dtype for el in elements)
+  if len(all_dtypes) == 1:
+    element_dtype = tuple(all_dtypes)[0]
+  else:
+    # Heterogeneous lists are ok.
+    element_dtype = dtypes.variant
+
+  # TODO(mdan): This may fail for elements of variable shapes.
+  all_shapes = set(tuple(el.shape.as_list()) for el in elements)
+  if len(all_shapes) == 1:
+    element_shape = array_ops.shape(elements[0])
+  else:
+    # Heterogeneous lists are ok.
+    element_shape = constant_op.constant(-1)  # unknown shape, by convention
+
+  l = list_ops.empty_tensor_list(
+      element_shape=element_shape, element_dtype=element_dtype)
+  for el in elements:
+    l = list_ops.tensor_list_push_back(l, el)
+  return l
 
-def append(target, element):
+
+def _py_list_new(elements):
+  """Overload of new_list that creates a Python list."""
+  return list(elements)
+
+
+def list_append(list_, x):
   """The list append function.
 
-  Note: it is unspecified where target will be mutated or not. If target is
-  a TensorFlow entity, it will not be typically mutated. If target is a plain
-  list, it will be. In general, if the target is mutated then the return value
+  Note: it is unspecified where list_ will be mutated or not. If list_ is
+  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
+  list, it will be. In general, if the list is mutated then the return value
   should point to the original entity.
 
   Args:
-    target: An entity that supports append semantics.
-    element: The element to append.
+    list_: An entity that supports append semantics.
+    x: The element to append.
 
   Returns:
-    Same as target, after the append was performed.
+    Same as list_, after the append was performed.
+
+  Raises:
+    ValueError: if list_ is not of a known list-like type.
   """
-  if isinstance(target, tensor_array_ops.TensorArray):
-    return _tf_tensorarray_append(target, element)
+  if isinstance(list_, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_append(list_, x)
+  elif tensor_util.is_tensor(list_):
+    if list_.dtype == dtypes.variant:
+      return _tf_tensor_list_append(list_, x)
+    else:
+      raise ValueError(
+          'tensor lists are expected to be Tensors with dtype=tf.variant,'
+          ' instead found %s' % list_)
   else:
-    return _py_append(target, element)
+    return _py_list_append(list_, x)
+
+
+def _tf_tensor_list_append(list_, x):
+  """Overload of list_append that stages a Tensor list write."""
+  def empty_list_of_elements_like_x():
+    tensor_x = ops.convert_to_tensor(x)
+    return list_ops.empty_tensor_list(
+        element_shape=array_ops.shape(tensor_x),
+        element_dtype=tensor_x.dtype)
+
+  list_ = control_flow_ops.cond(
+      list_ops.tensor_list_length(list_) > 0,
+      lambda: list_,
+      empty_list_of_elements_like_x,
+  )
+  return list_ops.tensor_list_push_back(list_, x)
+
+
+def _tf_tensorarray_append(list_, x):
+  """Overload of list_append that stages a TensorArray write."""
+  return list_.write(list_.size(), x)
+
+
+def _py_list_append(list_, x):
+  """Overload of list_append that executes a Python list append."""
+  # Revert to the original call.
+  list_.append(x)
+  return list_
+
+
+class ListPopOpts(
+    collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))):
+  pass
+
+
+def list_pop(list_, i, opts):
+  """The list pop function.
+
+  Note: it is unspecified where list_ will be mutated or not. If list_ is
+  a TensorFlow entity, it will not be typically mutated. If list_ is a plain
+  list, it will be. In general, if the list is mutated then the return value
+  should point to the original entity.
+
+  Args:
+    list_: An entity that supports pop semantics.
+    i: Optional index to pop from. May be None.
+    opts: A ListPopOpts.
+
+  Returns:
+    Tuple (x, out_list_):
+      out_list_: same as list_, after the removal was performed.
+      x: the removed element value.
+
+  Raises:
+    ValueError: if list_ is not of a known list-like type or the operation is
+    not supported for that type.
+  """
+  assert isinstance(opts, ListPopOpts)
+
+  if isinstance(list_, tensor_array_ops.TensorArray):
+    raise ValueError('TensorArray does not support item removal')
+  elif tensor_util.is_tensor(list_):
+    if list_.dtype == dtypes.variant:
+      return _tf_tensor_list_pop(list_, i, opts)
+    else:
+      raise ValueError(
+          'tensor lists are expected to be Tensors with dtype=tf.variant,'
+          ' instead found %s' % list_)
+  else:
+    return _py_list_pop(list_, i)
+
+
+def _tf_tensor_list_pop(list_, i, opts):
+  """Overload of list_pop that stages a Tensor list pop."""
+  if i is not None:
+    raise NotImplementedError('tensor lists only support removing from the end')
+
+  if opts.element_dtype is None:
+    raise ValueError('cannot pop from a list without knowing its element '
+                     'type; use set_element_type to annotate it')
+  if opts.element_shape is None:
+    raise ValueError('cannot pop from a list without knowing its element '
+                     'shape; use set_element_type to annotate it')
+  list_out, x = list_ops.tensor_list_pop_back(
+      list_, element_dtype=opts.element_dtype)
+  x.set_shape(opts.element_shape)
+  return list_out, x
+
+
+def _py_list_pop(list_, i):
+  """Overload of list_pop that executes a Python list append."""
+  if i is None:
+    x = list_.pop()
+  else:
+    x = list_.pop(i)
+  return list_, x
+
+
+# TODO(mdan): Look into reducing duplication between all these containers.
+class ListStackOpts(
+    collections.namedtuple('ListStackOpts',
+                           ('element_dtype', 'original_call'))):
+  pass
+
+
+def list_stack(list_, opts):
+  """The list stack function.
+
+  This does not have a direct correspondent in Python. The closest idiom to
+  this is tf.append or np.stack. It's different from those in the sense that it
+  accepts a Tensor list, rather than a list of tensors. It can also accept
+  TensorArray. When the target is anything else, the dispatcher will rely on
+  ctx.original_call for fallback.
+
+  Args:
+    list_: An entity that supports append semantics.
+    opts: A ListStackOpts object.
+
+  Returns:
+    The output of the stack operation, typically a Tensor.
+  """
+  assert isinstance(opts, ListStackOpts)
+
+  if isinstance(list_, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_stack(list_)
+  elif tensor_util.is_tensor(list_):
+    if list_.dtype == dtypes.variant:
+      return _tf_tensor_list_stack(list_, opts)
+    else:
+      # No-op for primitive Tensor arguments.
+      return list_
+  else:
+    return _py_list_stack(list_, opts)
+
+
+def _tf_tensorarray_stack(list_):
+  """Overload of list_stack that stages a TensorArray stack."""
+  return list_.stack()
 
 
-def _tf_tensorarray_append(target, element):
-  """Overload of append that stages a TensorArray write at the last position."""
-  return target.write(target.size(), element)
+def _tf_tensor_list_stack(list_, opts):
+  """Overload of list_stack that stages a Tensor list write."""
+  if opts.element_dtype is None:
+    raise ValueError('cannot stack a list without knowing its element type;'
+                     ' use set_element_type to annotate it')
+  return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype)
 
 
-def _py_append(target, element):
-  """Overload of append that executes a Python list append."""
-  target.append(element)
-  return target
+def _py_list_stack(list_, opts):
+  """Overload of list_stack that executes a Python list append."""
+  # Revert to the original call.
+  return opts.original_call(list_)
index 577d28c..8bbb52d 100644 (file)
@@ -19,25 +19,98 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import list_ops
 from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.platform import test
 
 
-class AppendTest(test.TestCase):
+class ListTest(test.TestCase):
 
-  def test_tf_tensorarray(self):
+  def test_new_list_empty(self):
+    l = data_structures.new_list()
+    # Can't evaluate an empty list.
+    # TODO(mdan): sess.run should allow tf.variant maybe?
+    self.assertTrue(isinstance(l, ops.Tensor))
+
+  def test_new_list_tensor(self):
+    l = data_structures.new_list([3, 4, 5])
+    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+    with self.test_session() as sess:
+      self.assertAllEqual(sess.run(t), [3, 4, 5])
+
+  def test_append_tensor_list(self):
+    l = data_structures.new_list()
+    x = constant_op.constant([1, 2, 3])
+    l = data_structures.list_append(l, x)
+
+    t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
+    with self.test_session() as sess:
+      self.assertAllEqual(sess.run(t), [[1, 2, 3]])
+
+  def test_append_tensorarray(self):
     l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
-    l1 = data_structures.append(l, 1)
-    l2 = data_structures.append(l1, 2)
+    l1 = data_structures.list_append(l, 1)
+    l2 = data_structures.list_append(l1, 2)
     with self.test_session() as sess:
       self.assertAllEqual(sess.run(l1.stack()), [1])
       self.assertAllEqual(sess.run(l2.stack()), [1, 2])
 
-  def test_python(self):
+  def test_append_python(self):
     l = []
-    self.assertAllEqual(data_structures.append(l, 1), [1])
-    self.assertAllEqual(data_structures.append(l, 2), [1, 2])
+    self.assertAllEqual(data_structures.list_append(l, 1), [1])
+    self.assertAllEqual(data_structures.list_append(l, 2), [1, 2])
+
+  def test_pop_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+    opts = data_structures.ListPopOpts(
+        element_dtype=initial_list.dtype,
+        element_shape=(2,))
+
+    with self.assertRaises(NotImplementedError):
+      data_structures.list_pop(l, 0, opts)
+
+    with self.test_session() as sess:
+      l, x = data_structures.list_pop(l, None, opts)
+      self.assertAllEqual(sess.run(x), [3, 4])
+
+      t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+      self.assertAllEqual(sess.run(t), [[1, 2]])
+
+  def test_pop_python(self):
+    l = [1, 2, 3]
+    opts = data_structures.ListPopOpts(element_dtype=None, element_shape=())
+    self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3))
+    self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2))
+
+  def test_stack_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+    opts = data_structures.ListStackOpts(
+        element_dtype=initial_list.dtype, original_call=None)
+
+    with self.test_session() as sess:
+      t = data_structures.list_stack(l, opts)
+      self.assertAllEqual(sess.run(t), sess.run(initial_list))
+
+  def test_stack_fallback(self):
+
+    def dummy_function(l):
+      # Lazy person's mock: just transform the argument in a way in which we
+      # can check that this function was indeed called.
+      return [x * 2 for x in l]
+
+    opts = data_structures.ListStackOpts(
+        element_dtype=None, original_call=dummy_function)
+
+    self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4])
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py
new file mode 100644 (file)
index 0000000..04fbeb2
--- /dev/null
@@ -0,0 +1,133 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators specific to slicing operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
+# TODO(mdan): Support extended slices.
+
+
+class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))):
+  pass
+
+
+def get_item(target, i, opts):
+  """The slice read operator (i.e. __getitem__).
+
+  Note: it is unspecified whether target will be mutated or not. In general,
+  if target is mutable (like Python lists), it will be mutated.
+
+  Args:
+    target: An entity that supports getitem semantics.
+    i: Index to read from.
+    opts: A GetItemOpts object.
+
+  Returns:
+    The read element.
+
+  Raises:
+    ValueError: if target is not of a supported type.
+  """
+  assert isinstance(opts, GetItemOpts)
+
+  if isinstance(target, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_get_item(target, i)
+  elif tensor_util.is_tensor(target):
+    if target.dtype == dtypes.variant:
+      return _tf_tensor_list_get_item(target, i, opts)
+    else:
+      return _tf_tensor_get_item(target, i)
+  else:
+    return _py_get_item(target, i)
+
+
+def _tf_tensorarray_get_item(target, i):
+  """Overload of get_item that stages a TensorArray read."""
+  return target.read(i)
+
+
+def _tf_tensor_list_get_item(target, i, opts):
+  """Overload of get_item that stages a Tensor list read."""
+  if opts.element_dtype is None:
+    raise ValueError('cannot retrieve from a list without knowing its '
+                     'element type; use set_element_type to annotate it')
+  x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype)
+  return x
+
+
+def _tf_tensor_get_item(target, i):
+  """Overload of get_item that stages a Tensor (not Tensor list) read."""
+  return target[i]
+
+
+def _py_get_item(target, i):
+  """Overload of get_item that executes a Python list modification."""
+  return target[i]
+
+
+def set_item(target, i, x):
+  """The slice write operator (i.e. __setitem__).
+
+  Note: it is unspecified whether target will be mutated or not. In general,
+  if target is mutable (like Python lists), it will be mutated.
+
+  Args:
+    target: An entity that supports setitem semantics.
+    i: Index to modify.
+    x: The new element value.
+
+  Returns:
+    Same as target, after the update was performed.
+
+  Raises:
+    ValueError: if target is not of a supported type.
+  """
+  if isinstance(target, tensor_array_ops.TensorArray):
+    return _tf_tensorarray_set_item(target, i, x)
+  elif tensor_util.is_tensor(target):
+    if target.dtype == dtypes.variant:
+      return _tf_tensor_list_set_item(target, i, x)
+    else:
+      raise ValueError(
+          'tensor lists are expected to be Tensors with dtype=tf.variant,'
+          ' instead found %s' % target)
+  else:
+    return _py_set_item(target, i, x)
+
+
+def _tf_tensorarray_set_item(target, i, x):
+  """Overload of set_item that stages a TensorArray write."""
+  return target.write(i, x)
+
+
+def _tf_tensor_list_set_item(target, i, x):
+  """Overload of set_item that stages a Tensor list update."""
+  return list_ops.tensor_list_set_item(target, i, x)
+
+
+def _py_set_item(target, i, x):
+  """Overload of set_item that executes a Python list modification."""
+  target[i] = x
+  return target
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py
new file mode 100644 (file)
index 0000000..d4aacb9
--- /dev/null
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slices module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.operators import slices
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SlicesTest(test.TestCase):
+
+  def test_set_item_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+    l = slices.set_item(l, 0, [5, 6])
+
+    with self.test_session() as sess:
+      t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+      self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
+
+  def test_get_item_tensor_list(self):
+    initial_list = constant_op.constant([[1, 2], [3, 4]])
+    elem_shape = constant_op.constant([2])
+    l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+    t = slices.get_item(
+        l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
+
+    with self.test_session() as sess:
+      self.assertAllEqual(sess.run(t), [3, 4])
+
+
+if __name__ == '__main__':
+  test.main()