from __future__ import print_function
import math
+import time
import numpy as np
from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
+from tensorflow.python.util import compat
class BatchDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
+ def testUnbatchDatasetWithStrings(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
+ expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.test_session() as sess:
+ for i in range(10):
+ self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors(st)
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ for i in range(10):
+ st_row = sess.run(next_element)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchDatasetWithDenseAndSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ for i in range(10):
+ dense_elem, st_row = sess.run(next_element)
+ self.assertEqual(i, dense_elem)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
def testUnbatchSingleElementTupleDataset(self):
data = tuple([(math_ops.range(10),) for _ in range(3)])
data = dataset_ops.Dataset.from_tensor_slices(data)
with self.assertRaises(errors.OutOfRangeError):
sess.run(op)
+ def testUnbatchEmpty(self):
+ data = dataset_ops.Dataset.from_tensors(
+ (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
+ constant_op.constant([], shape=[0, 4, 0])))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchStaticShapeMismatch(self):
+ data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
+ np.arange(9)))
+ with self.assertRaises(ValueError):
+ data.apply(batching.unbatch())
+
+ def testUnbatchDynamicShapeMismatch(self):
+ ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
+ ph2 = array_ops.placeholder(dtypes.int32, shape=None)
+ data = dataset_ops.Dataset.from_tensors((ph1, ph2))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.test_session() as sess:
+ # Mismatch in the 0th dimension.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: np.arange(8).astype(np.int32)
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ print(sess.run(next_element))
+
+ # No 0th dimension (i.e. scalar value) for one component.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: 7
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ print(sess.run(next_element))
+
def testBatchAndDropRemainder(self):
components = (np.arange(7),
np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
+class UnbatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+ components = (
+ np.arange(tensor_slice_len),
+ np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(tensor_slice_len))
+
+ return dataset_ops.Dataset.from_tensor_slices(components).batch(
+ batch_size).apply(batching.unbatch())
+
+ def testCore(self):
+ tensor_slice_len = 8
+ batch_size = 2
+ num_outputs = tensor_slice_len
+ self.run_core_tests(
+ lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+ lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+ num_outputs)
+
+
class PaddedBatchDatasetSerializationTest(
dataset_serialization_test_base.DatasetSerializationTestBase):
def test_assert_element_shape(self):
def create_unknown_shape_dataset(x):
- return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
unknown_shapes = (tensor_shape.TensorShape(None),
def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
def create_unknown_shape_dataset(x):
- return script_ops.py_func(lambda _: (np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
unknown_shapes = (tensor_shape.TensorShape(None),
sess.run(get_next)
+class UnbatchDatasetBenchmark(test.Benchmark):
+
+ def benchmarkNativeUnbatch(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.apply(batching.unbatch())
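+ # Skipping `elems_per_trial` elements means a single `get_next()` call
+ # drives that many elements through the pipeline, amortizing the
+ # per-element cost over one Session.run().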
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (native) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_native_batch_size_%d" %
+ batch_size)
+
+ # Include a benchmark of the previous `unbatch()` implementation that uses
+ # a composition of more primitive ops. Eventually we'd hope to generate code
+ # that is as good in both cases.
+ def benchmarkOldUnbatchImplementation(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (unfused) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
+ batch_size)
+
+
if __name__ == "__main__":
test.main()
return _apply_fn
+class UnbatchDataset(dataset_ops.Dataset):
+ """A dataset that splits the elements of its input into multiple elements."""
+
+ def __init__(self, input_dataset):
+ """See `unbatch()` for more details."""
+ super(UnbatchDataset, self).__init__()
+ flat_shapes = nest.flatten(input_dataset.output_shapes)
+ if any(s.ndims == 0 for s in flat_shapes):
+ raise ValueError("Cannot unbatch an input with scalar components.")
+ known_batch_dim = tensor_shape.Dimension(None)
+ for s in flat_shapes:
+ try:
+ known_batch_dim = known_batch_dim.merge_with(s[0])
+ except ValueError:
+ raise ValueError("Cannot unbatch an input whose components have "
+ "different batch sizes.")
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.unbatch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda s: s[1:],
+ self._input_dataset.output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
def unbatch():
- """A Transformation which splits the elements of a dataset.
+ """Splits elements of a dataset into multiple elements on the batch dimension.
For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
- where `B` may vary from element to element, then for each element in
- the dataset, the unbatched dataset will contain `B` consecutive elements
+ where `B` may vary for each input element, then for each element in the
+ dataset, the unbatched dataset will contain `B` consecutive elements
of shape `[a0, a1, ...]`.
+ ```python
+ # NOTE: The following example uses `{ ... }` to represent the contents
+ # of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.contrib.data.unbatch()) == {
+ 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+ ```
+
Returns:
A `Dataset` transformation function, which can be passed to
@{tf.data.Dataset.apply}.
"""
def _apply_fn(dataset):
-
- def unbatch_map(arg, *rest):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ if not sparse.any_sparse(dataset.output_classes):
+ return UnbatchDataset(dataset)
+
+ # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+ # are normalized to the rank-1 dense representation, so that the
+ # sparse-oblivious unbatching logic will slice them
+ # appropriately. This leads to a somewhat inefficient re-encoding step
+ # for all SparseTensor components.
+ # TODO(mrry): Consider optimizing this in future
+ # if it turns out to be a bottleneck.
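+ # For illustration: `serialize_many_sparse_tensors` replaces each
+ # SparseTensor component with its serialized dense representation,
+ # whose 0th dimension is the batch size (e.g. shape `[B, 3]`), so the
+ # sparse-oblivious unbatching logic can slice it along dimension 0.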
+ def normalize(arg, *rest):
if rest:
- return dataset_ops.Dataset.from_tensor_slices((arg,) + rest)
+ return sparse.serialize_many_sparse_tensors((arg,) + rest)
else:
- return dataset_ops.Dataset.from_tensor_slices(arg)
+ return sparse.serialize_many_sparse_tensors(arg)
+
+ normalized_dataset = dataset.map(normalize)
- return dataset.flat_map(map_func=unbatch_map)
+ # NOTE(mrry): Our `map()` has lost information about the sparseness
+ # of any SparseTensor components, so re-apply the structure of the
+ # original dataset.
+ restructured_dataset = _RestructuredDataset(
+ normalized_dataset,
+ dataset.output_types,
+ dataset.output_shapes,
+ dataset.output_classes,
+ allow_unsafe_cast=True)
+ return UnbatchDataset(restructured_dataset)
return _apply_fn
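# A minimal usage sketch (illustrative only, assuming the contrib API
# defined above): batching and then unbatching round-trips the original
# elements, and the all-dense input below takes the native
# `UnbatchDataset` fast path.
#
#   ds = dataset_ops.Dataset.range(10).batch(3)  # [0, 1, 2], ..., [9]
#   ds = ds.apply(unbatch())                     # 0, 1, ..., 9
#   next_elem = ds.make_one_shot_iterator().get_next()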
dataset,
output_types,
output_shapes=None,
- output_classes=None):
+ output_classes=None,
+ allow_unsafe_cast=False):
"""Creates a new dataset with the given output types and shapes.
The given `dataset` must have a structure that is convertible:
If omitted, the shapes will be inherited from `dataset`.
output_classes: (Optional.) A nested structure of class types.
If omitted, the class types will be inherited from `dataset`.
+ allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+ reported output types and shapes of the restructured dataset, e.g. to
+ switch a sparse tensor represented as `tf.variant` to its user-visible
+ type and shape.
Raises:
ValueError: If either `output_types` or `output_shapes` is not compatible
super(_RestructuredDataset, self).__init__()
self._dataset = dataset
- # Validate that the types are compatible.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- flat_original_types = nest.flatten(dataset.output_types)
- flat_new_types = nest.flatten(output_types)
- if flat_original_types != flat_new_types:
- raise ValueError(
- "Dataset with output types %r cannot be restructured to have output "
- "types %r" % (dataset.output_types, output_types))
+ if not allow_unsafe_cast:
+ # Validate that the types are compatible.
+ output_types = nest.map_structure(dtypes.as_dtype, output_types)
+ flat_original_types = nest.flatten(dataset.output_types)
+ flat_new_types = nest.flatten(output_types)
+ if flat_original_types != flat_new_types:
+ raise ValueError(
+ "Dataset with output types %r cannot be restructured to have "
+ "output types %r" % (dataset.output_types, output_types))
self._output_types = output_types
nest.flatten(
dataset.output_shapes))
else:
- # Validate that the shapes are compatible.
- nest.assert_same_structure(output_types, output_shapes)
- flat_original_shapes = nest.flatten(dataset.output_shapes)
- flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
- for original_shape, new_shape in zip(flat_original_shapes,
- flat_new_shapes):
- if not original_shape.is_compatible_with(new_shape):
- raise ValueError(
- "Dataset with output shapes %r cannot be restructured to have "
- "incompatible output shapes %r" % (dataset.output_shapes,
- output_shapes))
+ if not allow_unsafe_cast:
+ # Validate that the shapes are compatible.
+ nest.assert_same_structure(output_types, output_shapes)
+ flat_original_shapes = nest.flatten(dataset.output_shapes)
+ flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes):
+ if not original_shape.is_compatible_with(new_shape):
+ raise ValueError(
+ "Dataset with output shapes %r cannot be restructured to have "
+ "incompatible output shapes %r" % (dataset.output_shapes,
+ output_shapes))
self._output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
if output_classes is None:
args=[source_handle],
Tout=[dtypes.string],
f=LoadingFunc,
- target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)
+ target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job)[0]
with ops.device('/job:%s' % worker_job):
output_dataset = dataset_ops.Dataset.range(2).repeat().map(
--- /dev/null
+op {
+ graph_op_name: "UnbatchDataset"
+ summary: "A dataset that splits the elements of its input into multiple elements."
+}
--- /dev/null
+op {
+ graph_op_name: "UnbatchDataset"
+ visibility: HIDDEN
+}
class VariantTensorData;
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
+Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
} // namespace batch_util
/// @ingroup core
friend Status batch_util::CopyElementToSlice(
Tensor element, Tensor* parent,
int64 index); // For access to RefCountIsOne().
+ friend Status batch_util::MaybeMoveSliceToElement(
+ Tensor* parent, Tensor* element,
+ int64 index); // For access to RefCountIsOne().
+
friend class NumpyTensorBuffer; // For access to the private constructor
// taking the buffer.
return Status::OK();
}
-// TODO(jsimsa): Add HandleElementToSlice<variant> specialization that moves
-// the data when possible.
-
+// TODO(b/78245576): Consider removing this overload.
template <typename T>
-static Status HandleSliceToElement(const Tensor& parent, Tensor* element,
- int64 index) {
+void HandleSliceToElement(const Tensor& parent, Tensor* element, int64 index) {
element->flat<T>() = parent.flat_outer_dims<T>().chip(index, 0);
- return Status::OK();
+}
+
+template <typename T>
+void HandleSliceToElement(Tensor* parent, Tensor* element, int64 index,
+ bool can_move) {
+ element->flat<T>() = parent->flat_outer_dims<T>().chip(index, 0);
+}
+
+template <>
+void HandleSliceToElement<string>(Tensor* parent, Tensor* element, int64 index,
+ bool can_move) {
+ auto parent_as_matrix = parent->flat_outer_dims<string>();
+ auto element_flat = element->flat<string>();
+ if (can_move) {
+ for (int64 i = 0; i < element->NumElements(); ++i) {
+ element_flat(i) = std::move(parent_as_matrix(index, i));
+ }
+ } else {
+ element_flat = parent_as_matrix.chip(index, 0);
+ }
+}
+
+template <>
+void HandleSliceToElement<Variant>(Tensor* parent, Tensor* element, int64 index,
+ bool can_move) {
+ auto parent_as_matrix = parent->flat_outer_dims<Variant>();
+ auto element_flat = element->flat<Variant>();
+ if (can_move) {
+ for (int64 i = 0; i < element->NumElements(); ++i) {
+ element_flat(i) = std::move(parent_as_matrix(index, i));
+ }
+ } else {
+ element_flat = parent_as_matrix.chip(index, 0);
+ }
}
} // namespace
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index) {
TF_RETURN_IF_ERROR(ValidateInput(parent, *element, index));
-#define HANDLE_TYPE(T) \
- case DataTypeToEnum<T>::value: { \
- return HandleSliceToElement<T>(parent, element, index); \
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ HandleSliceToElement<T>(parent, element, index); \
+ return Status::OK(); \
}
switch (parent.dtype()) {
}
}
+// Copies the index^th slice of parent (in the 0th dimension) into element.
+//
+// NOTE(mrry): The implementation may be able to optimize the copy to a move.
+// This is particularly important for DT_STRING tensors.
+Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index) {
+ TF_RETURN_IF_ERROR(ValidateInput(*parent, *element, index));
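+ // Moving out of `parent` is only safe when we hold the sole reference
+ // to its buffer; any other holder would observe the moved-from values.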
+ bool can_move = parent->RefCountIsOne();
+
+#define HANDLE_TYPE(T) \
+ case DataTypeToEnum<T>::value: { \
+ HandleSliceToElement<T>(parent, element, index, can_move); \
+ return Status::OK(); \
+ }
+
+ switch (parent->dtype()) {
+ TF_CALL_ALL_TYPES(HANDLE_TYPE);
+ TF_CALL_QUANTIZED_TYPES(HANDLE_TYPE);
+#undef HANDLE_TYPE
+ default:
+ return errors::Unimplemented(
+ "MaybeMoveSliceToElement Unhandled data type: ", element->dtype());
+ }
+}
+
// The following five functions are copied from padding_fifo_queue.cc.
// TODO(mrry): Reconcile these functions with the similar methods in the
// queue implementation.
// Copies the index^th slice of parent (in the 0th dimension) into element.
Status CopySliceToElement(const Tensor& parent, Tensor* element, int64 index);
+// Copies the index^th slice of parent (in the 0th dimension) into element.
+//
+// NOTE(mrry): The implementation may be able to optimize the copy to a move.
+// This is particularly important for DT_STRING tensors.
+Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
+
// Zero-initializes the tensor `element` using the scalar stored in `padding`.
// Both `element` and `padding` must have matching `dtype`.
Status SetElementZero(Tensor* element, const Tensor& padding);
)
tf_kernel_library(
+ name = "unbatch_dataset_op",
+ srcs = ["unbatch_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:batch_util",
+ ],
+)
+
+tf_kernel_library(
name = "zip_dataset_op",
srcs = ["zip_dataset_op.cc"],
deps = [
":tensor_dataset_op",
":tensor_queue_dataset_op",
":tensor_slice_dataset_op",
+ ":unbatch_dataset_op",
":zip_dataset_op",
],
)
--- /dev/null
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/batch_util.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class UnbatchDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit UnbatchDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, DatasetBase* input)
+ : GraphDatasetBase(ctx), input_(input) {
+ input_->Ref();
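+ // The output shapes are the input shapes with the 0th (batch)
+ // dimension removed.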
+ for (const PartialTensorShape& shape : input->output_shapes()) {
+ gtl::InlinedVector<int64, 4> partial_dim_sizes;
+ for (int i = 1; i < shape.dims(); ++i) {
+ partial_dim_sizes.push_back(shape.dim_size(i));
+ }
+ shapes_.emplace_back(std::move(partial_dim_sizes));
+ }
+ }
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Unbatch")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return shapes_;
+ }
+
+ string DebugString() override { return "UnbatchDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ current_index_(0),
+ current_batch_size_(0),
+ input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
+ shapes_(params.dataset->output_shapes().size()) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ *end_of_sequence = false;
+ while (!*end_of_sequence) {
+ if (current_index_ < current_batch_size_) {
+ out_tensors->clear();
+ out_tensors->reserve(tensors_.size());
+ for (int i = 0; i < tensors_.size(); ++i) {
+ out_tensors->emplace_back(ctx->allocator({}), tensors_[i].dtype(),
+ shapes_[i]);
+ TF_RETURN_IF_ERROR(batch_util::MaybeMoveSliceToElement(
+ &tensors_[i], &out_tensors->back(), current_index_));
+ }
+ ++current_index_;
+ *end_of_sequence = false;
+ return Status::OK();
+ }
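+ // The current batch is exhausted (or this is the first call), so
+ // fetch the next batch from the input iterator.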
+ current_index_ = 0;
+ current_batch_size_ = 0;
+ tensors_.clear();
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &tensors_, end_of_sequence));
+ if (!*end_of_sequence) {
+ for (size_t i = 0; i < tensors_.size(); ++i) {
+ if (tensors_[i].dims() == 0) {
+ return errors::InvalidArgument(
+ "Input element must have a non-scalar value in each "
+ "component.");
+ }
+ if (tensors_[i].dim_size(0) != tensors_[0].dim_size(0)) {
+ return errors::InvalidArgument(
+ "Input element must have the same batch size in each "
+ "component. Component 0 had size ",
+ tensors_[0].dim_size(0), " but component ", i,
+ " had size, ", tensors_[i].dim_size(0), ".");
+ }
+ shapes_[i] = tensors_[i].shape();
+ shapes_[i].RemoveDim(0);
+ }
+ current_batch_size_ = tensors_[0].dim_size(0);
+ }
+ }
+ input_impl_.reset();
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("current_index"), current_index_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("n"), current_batch_size_));
+ if (current_index_ < current_batch_size_) {
+ for (size_t i = 0; i < tensors_.size(); ++i) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("tensors[", i, "]")), tensors_[i]));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("current_index"), ¤t_index_));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("n"), ¤t_batch_size_));
+ tensors_.clear();
+ tensors_.resize(dataset()->output_dtypes().size());
+ if (current_index_ < current_batch_size_) {
+ for (size_t i = 0; i < tensors_.size(); ++i) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("tensors[", i, "]")), &tensors_[i]));
+ shapes_[i] = tensors_[i].shape();
+ shapes_[i].RemoveDim(0);
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ int64 current_index_ GUARDED_BY(mu_);
+ int64 current_batch_size_ GUARDED_BY(mu_);
+ std::vector<Tensor> tensors_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::vector<TensorShape> shapes_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ std::vector<PartialTensorShape> shapes_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU),
+ UnbatchDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("UnbatchDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("ZipDataset")
.Input("input_datasets: N * variant")
.Output("handle: variant")