[tf.data] Add native implementation for `tf.contrib.data.unbatch()`.
author     Derek Murray <mrry@google.com>
           Thu, 19 Apr 2018 05:59:01 +0000 (22:59 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Thu, 19 Apr 2018 06:01:53 +0000 (23:01 -0700)
The native implementation has two main improvements over the previous
function-based one:
1. It avoids a relatively expensive (~15us) function invocation for each
   incoming batch.
2. It uses std::move() where possible to avoid copying strings/variants into
   the unbatched elements.

PiperOrigin-RevId: 193467856
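
For context, a minimal sketch of the user-facing transformation whose
implementation this change replaces (TF 1.x contrib API; illustrative only,
not part of this commit):

    import tensorflow as tf
    from tensorflow.contrib.data.python.ops import batching

    # {[0, 1], [2, 3], [4, 5], ...} -> {0, 1, 2, 3, 4, 5, ...}
    dataset = tf.data.Dataset.range(10).batch(2)
    dataset = dataset.apply(batching.unbatch())

    next_element = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      for i in range(10):
        assert sess.run(next_element) == i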

tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
tensorflow/contrib/data/python/ops/batching.py
tensorflow/contrib/tpu/python/tpu/datasets.py
tensorflow/core/api_def/base_api/api_def_UnbatchDataset.pbtxt [new file with mode: 0644]
tensorflow/core/api_def/python_api/api_def_UnbatchDataset.pbtxt [new file with mode: 0644]
tensorflow/core/framework/tensor.h
tensorflow/core/kernels/batch_util.cc
tensorflow/core/kernels/batch_util.h
tensorflow/core/kernels/data/BUILD
tensorflow/core/kernels/data/unbatch_dataset_op.cc [new file with mode: 0644]
tensorflow/core/ops/dataset_ops.cc

index 413d873..e1ec60d 100644 (file)
@@ -18,15 +18,18 @@ from __future__ import division
 from __future__ import print_function
 
 import math
+import time
 
 import numpy as np
 
 from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
 from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.client import session
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
@@ -34,6 +37,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import script_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.platform import test
+from tensorflow.python.util import compat
 
 
 class BatchDatasetTest(test.TestCase):
@@ -151,6 +155,69 @@ class BatchDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(op)
 
+  def testUnbatchDatasetWithStrings(self):
+    data = tuple([math_ops.range(10) for _ in range(3)])
+    data = dataset_ops.Dataset.from_tensor_slices(data)
+    data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
+    expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
+    data = data.batch(2)
+    self.assertEqual(expected_types, data.output_types)
+    data = data.apply(batching.unbatch())
+    self.assertEqual(expected_types, data.output_types)
+
+    iterator = data.make_one_shot_iterator()
+    op = iterator.get_next()
+
+    with self.test_session() as sess:
+      for i in range(10):
+        self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
+
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(op)
+
+  def testUnbatchDatasetWithSparseTensor(self):
+    st = sparse_tensor.SparseTensorValue(
+        indices=[[i, i] for i in range(10)],
+        values=list(range(10)),
+        dense_shape=[10, 10])
+    data = dataset_ops.Dataset.from_tensors(st)
+    data = data.apply(batching.unbatch())
+    data = data.batch(5)
+    data = data.apply(batching.unbatch())
+    iterator = data.make_one_shot_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      for i in range(10):
+        st_row = sess.run(next_element)
+        self.assertEqual([i], st_row.indices)
+        self.assertEqual([i], st_row.values)
+        self.assertEqual([10], st_row.dense_shape)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
+  def testUnbatchDatasetWithDenseAndSparseTensor(self):
+    st = sparse_tensor.SparseTensorValue(
+        indices=[[i, i] for i in range(10)],
+        values=list(range(10)),
+        dense_shape=[10, 10])
+    data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
+    data = data.apply(batching.unbatch())
+    data = data.batch(5)
+    data = data.apply(batching.unbatch())
+    iterator = data.make_one_shot_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      for i in range(10):
+        dense_elem, st_row = sess.run(next_element)
+        self.assertEqual(i, dense_elem)
+        self.assertEqual([i], st_row.indices)
+        self.assertEqual([i], st_row.values)
+        self.assertEqual([10], st_row.dense_shape)
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
   def testUnbatchSingleElementTupleDataset(self):
     data = tuple([(math_ops.range(10),) for _ in range(3)])
     data = dataset_ops.Dataset.from_tensor_slices(data)
@@ -191,6 +258,53 @@ class BatchDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(op)
 
+  def testUnbatchEmpty(self):
+    data = dataset_ops.Dataset.from_tensors(
+        (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
+         constant_op.constant([], shape=[0, 4, 0])))
+    data = data.apply(batching.unbatch())
+    iterator = data.make_one_shot_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(next_element)
+
+  def testUnbatchStaticShapeMismatch(self):
+    data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
+                                             np.arange(9)))
+    with self.assertRaises(ValueError):
+      data.apply(batching.unbatch())
+
+  def testUnbatchDynamicShapeMismatch(self):
+    ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
+    ph2 = array_ops.placeholder(dtypes.int32, shape=None)
+    data = dataset_ops.Dataset.from_tensors((ph1, ph2))
+    data = data.apply(batching.unbatch())
+    iterator = data.make_initializable_iterator()
+    next_element = iterator.get_next()
+
+    with self.test_session() as sess:
+      # Mismatch in the 0th dimension.
+      sess.run(
+          iterator.initializer,
+          feed_dict={
+              ph1: np.arange(7).astype(np.int32),
+              ph2: np.arange(8).astype(np.int32)
+          })
+      with self.assertRaises(errors.InvalidArgumentError):
+        print(sess.run(next_element))
+
+      # No 0th dimension (i.e. scalar value) for one component.
+      sess.run(
+          iterator.initializer,
+          feed_dict={
+              ph1: np.arange(7).astype(np.int32),
+              ph2: 7
+          })
+      with self.assertRaises(errors.InvalidArgumentError):
+        print(sess.run(next_element))
+
   def testBatchAndDropRemainder(self):
     components = (np.arange(7),
                   np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
@@ -545,6 +659,28 @@ class BatchDatasetSerializationTest(
     self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
 
 
+class UnbatchDatasetSerializationTest(
+    dataset_serialization_test_base.DatasetSerializationTestBase):
+
+  def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+    components = (
+        np.arange(tensor_slice_len),
+        np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+        np.array(multiplier) * np.arange(tensor_slice_len))
+
+    return dataset_ops.Dataset.from_tensor_slices(components).batch(
+        batch_size).apply(batching.unbatch())
+
+  def testCore(self):
+    tensor_slice_len = 8
+    batch_size = 2
+    num_outputs = tensor_slice_len
+    self.run_core_tests(
+        lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+        lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+        num_outputs)
+
+
 class PaddedBatchDatasetSerializationTest(
     dataset_serialization_test_base.DatasetSerializationTestBase):
 
@@ -586,10 +722,12 @@ class RestructuredDatasetTest(test.TestCase):
   def test_assert_element_shape(self):
 
     def create_unknown_shape_dataset(x):
-      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
-                                           np.zeros((3, 4), dtype=np.int32)),
-                                [x],
-                                [dtypes.float32, dtypes.int32])
+      return script_ops.py_func(
+          lambda _: (  # pylint: disable=g-long-lambda
+              np.ones(2, dtype=np.float32),
+              np.zeros((3, 4), dtype=np.int32)),
+          [x],
+          [dtypes.float32, dtypes.int32])
 
     dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
     unknown_shapes = (tensor_shape.TensorShape(None),
@@ -626,10 +764,12 @@ class RestructuredDatasetTest(test.TestCase):
   def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
 
     def create_unknown_shape_dataset(x):
-      return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
-                                           np.zeros((3, 4), dtype=np.int32)),
-                                [x],
-                                [dtypes.float32, dtypes.int32])
+      return script_ops.py_func(
+          lambda _: (  # pylint: disable=g-long-lambda
+              np.ones(2, dtype=np.float32),
+              np.zeros((3, 4), dtype=np.int32)),
+          [x],
+          [dtypes.float32, dtypes.int32])
 
     dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
     unknown_shapes = (tensor_shape.TensorShape(None),
@@ -649,5 +789,77 @@ class RestructuredDatasetTest(test.TestCase):
         sess.run(get_next)
 
 
+class UnbatchDatasetBenchmark(test.Benchmark):
+
+  def benchmarkNativeUnbatch(self):
+    batch_sizes = [1, 2, 5, 10, 20, 50]
+    elems_per_trial = 10000
+    with ops.Graph().as_default():
+      dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+      batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+      dataset = dataset.batch(batch_size_placeholder)
+      dataset = dataset.apply(batching.unbatch())
+      dataset = dataset.skip(elems_per_trial)
+      iterator = dataset.make_initializable_iterator()
+      next_element = iterator.get_next()
+
+      with session.Session() as sess:
+        for batch_size in batch_sizes:
+          deltas = []
+          for _ in range(5):
+            sess.run(
+                iterator.initializer,
+                feed_dict={batch_size_placeholder: batch_size})
+            start = time.time()
+            sess.run(next_element.op)
+            end = time.time()
+            deltas.append((end - start) / elems_per_trial)
+
+          median_wall_time = np.median(deltas)
+          print("Unbatch (native) batch size: %d Median wall time per element:"
+                " %f microseconds" % (batch_size, median_wall_time * 1e6))
+          self.report_benchmark(
+              iters=10000,
+              wall_time=median_wall_time,
+              name="benchmark_unbatch_dataset_native_batch_size_%d" %
+              batch_size)
+
+  # Include a benchmark of the previous `unbatch()` implementation that uses
+  # a composition of more primitive ops. Eventually we'd hope to generate code
+  # that is as good in both cases.
+  def benchmarkOldUnbatchImplementation(self):
+    batch_sizes = [1, 2, 5, 10, 20, 50]
+    elems_per_trial = 10000
+    with ops.Graph().as_default():
+      dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+      batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+      dataset = dataset.batch(batch_size_placeholder)
+      dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
+      dataset = dataset.skip(elems_per_trial)
+      iterator = dataset.make_initializable_iterator()
+      next_element = iterator.get_next()
+
+      with session.Session() as sess:
+        for batch_size in batch_sizes:
+          deltas = []
+          for _ in range(5):
+            sess.run(
+                iterator.initializer,
+                feed_dict={batch_size_placeholder: batch_size})
+            start = time.time()
+            sess.run(next_element.op)
+            end = time.time()
+            deltas.append((end - start) / elems_per_trial)
+
+          median_wall_time = np.median(deltas)
+          print("Unbatch (unfused) batch size: %d Median wall time per element:"
+                " %f microseconds" % (batch_size, median_wall_time * 1e6))
+          self.report_benchmark(
+              iters=10000,
+              wall_time=median_wall_time,
+              name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
+              batch_size)
+
+
 if __name__ == "__main__":
   test.main()
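
The two benchmarks above compare equivalent pipelines: the fused
UnbatchDataset kernel added by this change versus the old composition of
primitive ops. A condensed sketch (a fixed batch size is assumed here for
brevity):

    from tensorflow.contrib.data.python.ops import batching
    from tensorflow.python.data.ops import dataset_ops

    batch_size = 10
    dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
    dataset = dataset.batch(batch_size)

    # Fused native kernel (this change): a single UnbatchDataset op.
    native = dataset.apply(batching.unbatch())

    # Previous implementation: flat_map invokes a dataset function
    # (~15us overhead) once per incoming batch.
    unfused = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)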
index 28db949..2152bcd 100644 (file)
@@ -80,28 +80,98 @@ def dense_to_sparse_batch(batch_size, row_shape):
   return _apply_fn
 
 
+class UnbatchDataset(dataset_ops.Dataset):
+  """A dataset that splits the elements of its input into multiple elements."""
+
+  def __init__(self, input_dataset):
+    """See `unbatch()` for more details."""
+    super(UnbatchDataset, self).__init__()
+    flat_shapes = nest.flatten(input_dataset.output_shapes)
+    if any(s.ndims == 0 for s in flat_shapes):
+      raise ValueError("Cannot unbatch an input with scalar components.")
+    known_batch_dim = tensor_shape.Dimension(None)
+    for s in flat_shapes:
+      try:
+        known_batch_dim = known_batch_dim.merge_with(s[0])
+      except ValueError:
+        raise ValueError("Cannot unbatch an input whose components have "
+                         "different batch sizes.")
+    self._input_dataset = input_dataset
+
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.unbatch_dataset(
+        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
+        output_shapes=nest.flatten(
+            sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+        output_types=nest.flatten(
+            sparse.as_dense_types(self.output_types, self.output_classes)))
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
+
+  @property
+  def output_shapes(self):
+    return nest.map_structure(lambda s: s[1:],
+                              self._input_dataset.output_shapes)
+
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
+
+
 def unbatch():
-  """A Transformation which splits the elements of a dataset.
+  """Splits elements of a dataset into multiple elements on the batch dimension.
 
   For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
-  where `B` may vary from element to element, then for each element in
-  the dataset, the unbatched dataset will contain `B` consecutive elements
+  where `B` may vary for each input element, then for each element in the
+  dataset, the unbatched dataset will contain `B` consecutive elements
   of shape `[a0, a1, ...]`.
 
+  ```python
+  # NOTE: The following example uses `{ ... }` to represent the contents
+  # of a dataset.
+  a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+  a.apply(tf.contrib.data.unbatch()) == {
+      'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+  ```
+
   Returns:
     A `Dataset` transformation function, which can be passed to
     @{tf.data.Dataset.apply}.
   """
 
   def _apply_fn(dataset):
-
-    def unbatch_map(arg, *rest):
+    """Function from `Dataset` to `Dataset` that applies the transformation."""
+    if not sparse.any_sparse(dataset.output_classes):
+      return UnbatchDataset(dataset)
+
+    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+    # are normalized to the rank-1 dense representation, so that the
+    # sparse-oblivious unbatching logic will slice them
+    # appropriately. This leads to a somewhat inefficient re-encoding step
+    # for all SparseTensor components.
+    # TODO(mrry): Consider optimizing this in the future
+    # if it turns out to be a bottleneck.
+    def normalize(arg, *rest):
       if rest:
-        return dataset_ops.Dataset.from_tensor_slices((arg,) + rest)
+        return sparse.serialize_many_sparse_tensors((arg,) + rest)
       else:
-        return dataset_ops.Dataset.from_tensor_slices(arg)
+        return sparse.serialize_many_sparse_tensors(arg)
+
+    normalized_dataset = dataset.map(normalize)
 
-    return dataset.flat_map(map_func=unbatch_map)
+    # NOTE(mrry): Our `map()` has lost information about the sparseness
+    # of any SparseTensor components, so re-apply the structure of the
+    # original dataset.
+    restructured_dataset = _RestructuredDataset(
+        normalized_dataset,
+        dataset.output_types,
+        dataset.output_shapes,
+        dataset.output_classes,
+        allow_unsafe_cast=True)
+    return UnbatchDataset(restructured_dataset)
 
   return _apply_fn
 
@@ -265,7 +335,8 @@ class _RestructuredDataset(dataset_ops.Dataset):
                dataset,
                output_types,
                output_shapes=None,
-               output_classes=None):
+               output_classes=None,
+               allow_unsafe_cast=False):
     """Creates a new dataset with the given output types and shapes.
 
     The given `dataset` must have a structure that is convertible:
@@ -283,6 +354,10 @@ class _RestructuredDataset(dataset_ops.Dataset):
         If omitted, the shapes will be inherited from `dataset`.
       output_classes: (Optional.) A nested structure of class types.
         If omitted, the class types will be inherited from `dataset`.
+      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+        reported output types and shapes of the restructured dataset, e.g. to
+        switch a sparse tensor represented as `tf.variant` to its user-visible
+        type and shape.
 
     Raises:
       ValueError: If either `output_types` or `output_shapes` is not compatible
@@ -291,14 +366,15 @@ class _RestructuredDataset(dataset_ops.Dataset):
     super(_RestructuredDataset, self).__init__()
     self._dataset = dataset
 
-    # Validate that the types are compatible.
-    output_types = nest.map_structure(dtypes.as_dtype, output_types)
-    flat_original_types = nest.flatten(dataset.output_types)
-    flat_new_types = nest.flatten(output_types)
-    if flat_original_types != flat_new_types:
-      raise ValueError(
-          "Dataset with output types %r cannot be restructured to have output "
-          "types %r" % (dataset.output_types, output_types))
+    if not allow_unsafe_cast:
+      # Validate that the types are compatible.
+      output_types = nest.map_structure(dtypes.as_dtype, output_types)
+      flat_original_types = nest.flatten(dataset.output_types)
+      flat_new_types = nest.flatten(output_types)
+      if flat_original_types != flat_new_types:
+        raise ValueError(
+            "Dataset with output types %r cannot be restructured to have "
+            "output types %r" % (dataset.output_types, output_types))
 
     self._output_types = output_types
 
@@ -308,18 +384,19 @@ class _RestructuredDataset(dataset_ops.Dataset):
                                                   nest.flatten(
                                                       dataset.output_shapes))
     else:
-      # Validate that the shapes are compatible.
-      nest.assert_same_structure(output_types, output_shapes)
-      flat_original_shapes = nest.flatten(dataset.output_shapes)
-      flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
-      for original_shape, new_shape in zip(flat_original_shapes,
-                                           flat_new_shapes):
-        if not original_shape.is_compatible_with(new_shape):
-          raise ValueError(
-              "Dataset with output shapes %r cannot be restructured to have "
-              "incompatible output shapes %r" % (dataset.output_shapes,
-                                                 output_shapes))
+      if not allow_unsafe_cast:
+        # Validate that the shapes are compatible.
+        nest.assert_same_structure(output_types, output_shapes)
+        flat_original_shapes = nest.flatten(dataset.output_shapes)
+        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+        for original_shape, new_shape in zip(flat_original_shapes,
+                                             flat_new_shapes):
+          if not original_shape.is_compatible_with(new_shape):
+            raise ValueError(
+                "Dataset with output shapes %r cannot be restructured to have "
+                "incompatible output shapes %r" % (dataset.output_shapes,
+                                                   output_shapes))
       self._output_shapes = nest.map_structure_up_to(
           output_types, tensor_shape.as_shape, output_shapes)
     if output_classes is None:
index 465c668..2e472a2 100644 (file)
@@ -170,7 +170,7 @@ def StreamingFilesDataset(files,
         args=[source_handle],
         Tout=[dtypes.string],
         f=LoadingFunc,
-        target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
+        target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0]
 
   with ops.device('/job:%s' % worker_job):
     output_dataset = dataset_ops.Dataset.range(2).repeat().map(
diff --git a/tensorflow/core/api_def/base_api/api_def_UnbatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnbatchDataset.pbtxt
new file mode 100644 (file)
index 0000000..324fada
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "UnbatchDataset"
+  summary: "A dataset that splits the elements of its input into multiple elements."
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnbatchDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnbatchDataset.pbtxt
new file mode 100644 (file)
index 0000000..1e54157
--- /dev/null
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "UnbatchDataset"
+  visibility: HIDDEN
+}
index 4d10f7e..58fbced 100644 (file)
@@ -44,6 +44,7 @@ class TensorProto;
 class VariantTensorData;
 namespace batch_util {
 Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
+Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
 }  // namespace batch_util
 
 /// @ingroup core
@@ -493,6 +494,10 @@ class Tensor {
   friend Status batch_util::CopyElementToSlice(
       Tensor element, Tensor* parent,
       int64 index);                // For access to RefCountIsOne().
+  friend Status batch_util::MaybeMoveSliceToElement(
+      Tensor* parent, Tensor* element,
+      int64 index);  // For access to RefCountIsOne().
+
   friend class NumpyTensorBuffer;  // For access to the private constructor
                                    // taking the buffer.
 
index 1a45212..52be1ab 100644 (file)
@@ -78,14 +78,44 @@ Status HandleElementToSlice<Variant>(Tensor element, Tensor* parent,
   return Status::OK();
 }
 
-// TODO(jsimsa): Add HandleElementToSlice<variant> specialization that moves
-// the data when possible.
-
+// TODO(b/78245576): Consider removing this overload.
 template <typename T>
-static Status HandleSliceToElement(const Tensor& parent, Tensor* element,
-                                   int64 index) {
+void HandleSliceToElement(const Tensor& parent, Tensor* element, int64 index) {
   element->flat<T>() = parent.flat_outer_dims<T>().chip(index, 0);
-  return Status::OK();
+}
+
+template <typename T>
+void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index,
+                          bool can_move) {
+  element->flat<T>() = parent->flat_outer_dims<T>().chip(index, 0);
+}
+
+template <>
+void HandleSliceToElement<string>(Tensor* parent, Tensor* element, int64 index,
+                                  bool can_move) {
+  auto parent_as_matrix = parent->flat_outer_dims<string>();
+  auto element_flat = element->flat<string>();
+  if (can_move) {
+    for (int64 i = 0; i < element->NumElements(); ++i) {
+      element_flat(i) = std::move(parent_as_matrix(index, i));
+    }
+  } else {
+    element_flat = parent_as_matrix.chip(index, 0);
+  }
+}
+
+template <>
+void HandleSliceToElement<Variant>(Tensor* parent, Tensor* element, int64 index,
+                                   bool can_move) {
+  auto parent_as_matrix = parent->flat_outer_dims<Variant>();
+  auto element_flat = element->flat<Variant>();
+  if (can_move) {
+    for (int64 i = 0; i < element->NumElements(); ++i) {
+      element_flat(i) = std::move(parent_as_matrix(index, i));
+    }
+  } else {
+    element_flat = parent_as_matrix.chip(index, 0);
+  }
 }
 
 }  // namespace
@@ -115,9 +145,10 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index) {
 Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
   TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
 
-#define HANDLE_TYPE(T)                                      \
-  case DataTypeToEnum<T>::value: {                          \
-    return HandleSliceToElement<T>(parent, element, index); \
+#define HANDLE_TYPE(T)                               \
+  case DataTypeToEnum<T>::value: {                   \
+    HandleSliceToElement<T>(parent, element, index); \
+    return Status::OK();                             \
   }
 
   switch (parent.dtype()) {
@@ -130,6 +161,30 @@ Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
   }
 }
 
+// Copies the index^th slice of parent (in the 0th dimension) into element.
+//
+// NOTE(mrry): The implementation may be able to optimize the copy to a move.
+// This is particularly important for DT_STRING tensors.
+Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
+  TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
+  bool can_move = parent->RefCountIsOne();
+
+#define HANDLE_TYPE(T)                                         \
+  case DataTypeToEnum<T>::value: {                             \
+    HandleSliceToElement<T>(parent, element, index, can_move); \
+    return Status::OK();                                       \
+  }
+
+  switch (parent->dtype()) {
+    TF_CALL_ALL_TYPES(HANDLE_TYPE);
+    TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+    default:
+      return errors::Unimplemented(
+          "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
+  }
+}
+
 // The following five functions are copied from padding_fifo_queue.cc.
 // TODO(mrry): Reconcile these functions with the similar methods in the
 // queue implementation.
index a47bf19..69098fb 100644 (file)
@@ -32,6 +32,12 @@ Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
 // Copies the index^th slice of parent (in the 0th dimension) into element.
 Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index);
 
+// Copies the index^th slice of parent (in the 0th dimension) into element.
+//
+// NOTE(mrry): The implementation may be able to optimize the copy to a move.
+// This is particularly important for DT_STRING tensors.
+Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
+
 // Zero-initializes the tensor `element` using the scalar stored in `padding`.
 // Both `element` and `padding` must have matching `dtype`.
 Status SetElementZero(Tensor* element, const Tensor& padding);
index 221724e..1e96eb6 100644 (file)
@@ -447,6 +447,19 @@ tf_kernel_library(
 )
 
 tf_kernel_library(
+    name = "unbatch_dataset_op",
+    srcs = ["unbatch_dataset_op.cc"],
+    deps = [
+        ":dataset",
+        "//tensorflow/core:dataset_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/kernels:batch_util",
+    ],
+)
+
+tf_kernel_library(
     name = "zip_dataset_op",
     srcs = ["zip_dataset_op.cc"],
     deps = [
@@ -562,6 +575,7 @@ tf_kernel_library(
         ":tensor_dataset_op",
         ":tensor_queue_dataset_op",
         ":tensor_slice_dataset_op",
+        ":unbatch_dataset_op",
         ":zip_dataset_op",
     ],
 )
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
new file mode 100644 (file)
index 0000000..241b615
--- /dev/null
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/batch_util.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class UnbatchDatasetOp : public UnaryDatasetOpKernel {
+ public:
+  explicit UnbatchDatasetOp(OpKernelConstruction* ctx)
+      : UnaryDatasetOpKernel(ctx) {}
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+                   DatasetBase** output) override {
+    *output = new Dataset(ctx, input);
+  }
+
+ private:
+  class Dataset : public GraphDatasetBase {
+   public:
+    explicit Dataset(OpKernelContext* ctx, DatasetBase* input)
+        : GraphDatasetBase(ctx), input_(input) {
+      input_->Ref();
+      for (const PartialTensorShape& shape : input->output_shapes()) {
+        gtl::InlinedVector<int64, 4> partial_dim_sizes;
+        for (int i = 1; i < shape.dims(); ++i) {
+          partial_dim_sizes.push_back(shape.dim_size(i));
+        }
+        shapes_.emplace_back(std::move(partial_dim_sizes));
+      }
+    }
+
+    std::unique_ptr<IteratorBase> MakeIterator(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::Unbatch")}));
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return input_->output_dtypes();
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return shapes_;
+    }
+
+    string DebugString() override { return "UnbatchDatasetOp::Dataset"; }
+
+   protected:
+    Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+      TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params),
+            current_index_(0),
+            current_batch_size_(0),
+            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
+            shapes_(params.dataset->output_shapes().size()) {}
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        if (!input_impl_) {
+          *end_of_sequence = true;
+          return Status::OK();
+        }
+        *end_of_sequence = false;
+        while (!*end_of_sequence) {
+          if (current_index_ < current_batch_size_) {
+            out_tensors->clear();
+            out_tensors->reserve(tensors_.size());
+            for (int i = 0; i < tensors_.size(); ++i) {
+              out_tensors->emplace_back(ctx->allocator({}), tensors_[i].dtype(),
+                                        shapes_[i]);
+              TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement(
+                  &tensors_[i], &out_tensors->back(), current_index_));
+            }
+            ++current_index_;
+            *end_of_sequence = false;
+            return Status::OK();
+          }
+          current_index_ = 0;
+          current_batch_size_ = 0;
+          tensors_.clear();
+          TF_RETURN_IF_ERROR(
+              input_impl_->GetNext(ctx, &tensors_, end_of_sequence));
+          if (!*end_of_sequence) {
+            for (size_t i = 0; i < tensors_.size(); ++i) {
+              if (tensors_[i].dims() == 0) {
+                return errors::InvalidArgument(
+                    "Input element must have a non-scalar value in each "
+                    "component.");
+              }
+              if (tensors_[i].dim_size(0) != tensors_[0].dim_size(0)) {
+                return errors::InvalidArgument(
+                    "Input element must have the same batch size in each "
+                    "component. Component 0 had size ",
+                    tensors_[0].dim_size(0), " but component ", i,
+                    " had size, ", tensors_[i].dim_size(0), ".");
+              }
+              shapes_[i] = tensors_[i].shape();
+              shapes_[i].RemoveDim(0);
+            }
+            current_batch_size_ = tensors_[0].dim_size(0);
+          }
+        }
+        input_impl_.reset();
+        return Status::OK();
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        if (input_impl_) {
+          TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+        } else {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impl_empty"), ""));
+        }
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("current_index"), current_index_));
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("n"), current_batch_size_));
+        if (current_index_ < current_batch_size_) {
+          for (size_t i = 0; i < tensors_.size(); ++i) {
+            TF_RETURN_IF_ERROR(writer->WriteTensor(
+                full_name(strings::StrCat("tensors[", i, "]")), tensors_[i]));
+          }
+        }
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        if (!reader->Contains(full_name("input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+        } else {
+          input_impl_.reset();
+        }
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name("current_index"), &current_index_));
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name("n"), &current_batch_size_));
+        tensors_.clear();
+        tensors_.resize(dataset()->output_dtypes().size());
+        if (current_index_ < current_batch_size_) {
+          for (size_t i = 0; i < tensors_.size(); ++i) {
+            TF_RETURN_IF_ERROR(reader->ReadTensor(
+                full_name(strings::StrCat("tensors[", i, "]")), &tensors_[i]));
+            shapes_[i] = tensors_[i].shape();
+            shapes_[i].RemoveDim(0);
+          }
+        }
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      int64 current_index_ GUARDED_BY(mu_);
+      int64 current_batch_size_ GUARDED_BY(mu_);
+      std::vector<Tensor> tensors_ GUARDED_BY(mu_);
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      std::vector<TensorShape> shapes_ GUARDED_BY(mu_);
+    };
+
+    const DatasetBase* const input_;
+    std::vector<PartialTensorShape> shapes_;
+  };
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU),
+                        UnbatchDatasetOp);
+
+}  // namespace
+
+}  // namespace tensorflow
index 57f871a..8be569b 100644 (file)
@@ -83,6 +83,13 @@ REGISTER_OP("GeneratorDataset")
                       // stateful to inhibit constant folding.
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("UnbatchDataset")
+    .Input("input_dataset: variant")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape);
+
 REGISTER_OP("ZipDataset")
     .Input("input_datasets: N * variant")
     .Output("handle: variant")