-# Custom Data Readers
+# Reading custom file and record formats
PREREQUISITES:
* Some familiarity with C++.
* The ability to build TensorFlow op libraries, as described in
  @{$adding_an_op$the adding an op how-to}.
We divide the task of supporting a file format into two pieces:
-* File formats: We use a *Reader* Op to read a *record* (which can be any
- string) from a file.
-* Record formats: We use decoder or parsing Ops to turn a string record
+* File formats: We use a reader `tf.data.Dataset` to read raw *records* (which
+ are typically represented by scalar string tensors, but can have more
+ structure) from a file.
+* Record formats: We use decoder or parsing ops to turn a string record
into tensors usable by TensorFlow.
For example, to read a
[CSV file](https://en.wikipedia.org/wiki/Comma-separated_values), we use
-@{tf.TextLineReader$a Reader for text files}
-followed by
-@{tf.decode_csv$an Op that parses CSV data from a line of text}.
+@{tf.data.TextLineDataset$a dataset for reading text files line-by-line}
+and then @{tf.data.Dataset.map$map} an
+@{tf.decode_csv$op} that parses CSV data from each line of text in the dataset.
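+For instance, a minimal sketch of that CSV pipeline (the file name `data.csv`
+and its three-column layout are assumptions made for illustration) might look
+like this:
+
+```python
+import tensorflow as tf
+
+# Hypothetical file: each line is "label,feature1,feature2".
+dataset = tf.data.TextLineDataset("data.csv")  # File format: one record per line.
+
+def parse_line(line):
+  # Record format: turn the CSV text record into tensors.
+  label, f1, f2 = tf.decode_csv(line, record_defaults=[[0], [0.0], [0.0]])
+  return {"features": tf.stack([f1, f2]), "label": label}
+
+dataset = dataset.map(parse_line)
+```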
[TOC]
-## Writing a Reader for a file format
+## Writing a `Dataset` for a file format
-A `Reader` is something that reads records from a file. There are some examples
-of Reader Ops already built into TensorFlow:
+A @{tf.data.Dataset} represents a sequence of *elements*, which can be the
+individual records in a file. There are several examples of "reader" datasets
+that are already built into TensorFlow:
-* @{tf.TFRecordReader}
- ([source in `kernels/tf_record_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/tf_record_reader_op.cc))
-* @{tf.FixedLengthRecordReader}
- ([source in `kernels/fixed_length_record_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/fixed_length_record_reader_op.cc))
-* @{tf.TextLineReader}
- ([source in `kernels/text_line_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/text_line_reader_op.cc))
+* @{tf.data.TFRecordDataset}
+ ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
+* @{tf.data.FixedLengthRecordDataset}
+ ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
+* @{tf.data.TextLineDataset}
+ ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
-You can see these all expose the same interface, the only differences
-are in their constructors. The most important method is `read`.
-It takes a queue argument, which is where it gets filenames to
-read from whenever it needs one (e.g. when the `read` op first runs, or
-the previous `read` reads the last record from a file). It produces
-two scalar tensors: a string key and a string value.
+Each of these implementations comprises three related classes:
-To create a new reader called `SomeReader`, you will need to:
+* A `tensorflow::DatasetOpKernel` subclass (e.g. `TextLineDatasetOp`), which
+ tells TensorFlow how to construct a dataset object from the inputs to and
+ attrs of an op, in its `MakeDataset()` method.
-1. In C++, define a subclass of
- [`tensorflow::ReaderBase`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_base.h)
- called `SomeReader`.
-2. In C++, register a new reader op and kernel with the name `"SomeReader"`.
-3. In Python, define a subclass of @{tf.ReaderBase} called `SomeReader`.
+* A `tensorflow::GraphDatasetBase` subclass (e.g. `TextLineDatasetOp::Dataset`),
+ which represents the *immutable* definition of the dataset itself, and tells
+ TensorFlow how to construct an iterator object over that dataset, in its
+ `MakeIterator()` method.
-You can put all the C++ code in a file in
-`tensorflow/core/user_ops/some_reader_op.cc`. The code to read a file will live
-in a descendant of the C++ `ReaderBase` class, which is defined in
-[`tensorflow/core/kernels/reader_base.h`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_base.h).
-You will need to implement the following methods:
+* A `tensorflow::DatasetIterator<Dataset>` subclass (e.g.
+ `TextLineDatasetOp::Dataset::Iterator`), which represents the *mutable* state
+ of an iterator over a particular dataset, and tells TensorFlow how to get the
+ next element from the iterator, in its `GetNextInternal()` method.
-* `OnWorkStartedLocked`: open the next file
-* `ReadLocked`: read a record or report EOF/error
-* `OnWorkFinishedLocked`: close the current file, and
-* `ResetLocked`: get a clean slate after, e.g., an error
+The most important method is the `GetNextInternal()` method, since it defines
+how to actually read records from the file and represent them as one or more
+`Tensor` objects.
-These methods have names ending in "Locked" since `ReaderBase` makes sure
-to acquire a mutex before calling any of these methods, so you generally don't
-have to worry about thread safety (though that only protects the members of the
-class, not global state).
+To create a new reader dataset called (for example) `MyReaderDataset`, you will
+need to:
-For `OnWorkStartedLocked`, the name of the file to open is the value returned by
-the `current_work()` method. `ReadLocked` has this signature:
+1. In C++, define subclasses of `tensorflow::DatasetOpKernel`,
+ `tensorflow::GraphDatasetBase`, and `tensorflow::DatasetIterator<Dataset>`
+ that implement the reading logic.
+2. In C++, register a new reader op and kernel with the name
+ `"MyReaderDataset"`.
+3. In Python, define a subclass of @{tf.data.Dataset} called `MyReaderDataset`.
-```c++
-Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
-```
-
-If `ReadLocked` successfully reads a record from the file, it should fill in:
-
-* `*key`: with an identifier for the record, that a human could use to find
- this record again. You can include the filename from `current_work()`,
- and append a record number or whatever.
-* `*value`: with the contents of the record.
-* `*produced`: set to `true`.
-
-If you hit the end of a file (EOF), set `*at_end` to `true`. In either case,
-return `Status::OK()`. If there is an error, simply return it using one of the
-helper functions from
-[`tensorflow/core/lib/core/errors.h`](https://www.tensorflow.org/code/tensorflow/core/lib/core/errors.h)
-without modifying any arguments.
-
-Next you will create the actual Reader op. It will help if you are familiar
-with @{$adding_an_op$the adding an op how-to}. The main steps
-are:
-
-* Registering the op.
-* Define and register an `OpKernel`.
-
-To register the op, you will use a `REGISTER_OP` call defined in
-[`tensorflow/core/framework/op.h`](https://www.tensorflow.org/code/tensorflow/core/framework/op.h).
-Reader ops never take any input and always have a single output with type
-`resource`. They should have string `container` and `shared_name` attrs.
-You may optionally define additional attrs
-for configuration or include documentation in a `Doc`. For examples, see
-[`tensorflow/core/ops/io_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/ops/io_ops.cc),
-e.g.:
+You can put all the C++ code in a single file, such as
+`my_reader_dataset_op.cc`. It will help if you are
+familiar with @{$adding_an_op$the adding an op how-to}. The following skeleton
+can be used as a starting point for your implementation:
```c++
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
-REGISTER_OP("TextLineReader")
- .Output("reader_handle: resource")
- .Attr("skip_header_lines: int = 0")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .SetIsStateful()
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-A Reader that outputs the lines of a file delimited by '\n'.
-)doc");
-```
-
-To define an `OpKernel`, Readers can use the shortcut of descending from
-`ReaderOpKernel`, defined in
-[`tensorflow/core/framework/reader_op_kernel.h`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_op_kernel.h),
-and implement a constructor that calls `SetReaderFactory`. After defining
-your class, you will need to register it using `REGISTER_KERNEL_BUILDER(...)`.
-An example with no attrs:
+namespace tensorflow {
+namespace {
-```c++
-#include "tensorflow/core/framework/reader_op_kernel.h"
-
-class TFRecordReaderOp : public ReaderOpKernel {
+class MyReaderDatasetOp : public DatasetOpKernel {
public:
- explicit TFRecordReaderOp(OpKernelConstruction* context)
- : ReaderOpKernel(context) {
- Env* env = context->env();
- SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
- }
-};
-REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
- TFRecordReaderOp);
-```
+ MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+ // Parse and validate any attrs that define the dataset using
+ // `ctx->GetAttr()`, and store them in member variables.
+ }
-An example with attrs:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    // Parse and validate any input tensors that define the dataset using
+ // `ctx->input()` or the utility function
+ // `ParseScalarArgument<T>(ctx, &arg)`.
-```c++
-#include "tensorflow/core/framework/reader_op_kernel.h"
-
-class TextLineReaderOp : public ReaderOpKernel {
- public:
- explicit TextLineReaderOp(OpKernelConstruction* context)
- : ReaderOpKernel(context) {
- int skip_header_lines = -1;
- OP_REQUIRES_OK(context,
- context->GetAttr("skip_header_lines", &skip_header_lines));
- OP_REQUIRES(context, skip_header_lines >= 0,
- errors::InvalidArgument("skip_header_lines must be >= 0 not ",
- skip_header_lines));
- Env* env = context->env();
- SetReaderFactory([this, skip_header_lines, env]() {
- return new TextLineReader(name(), skip_header_lines, env);
- });
+ // Create the dataset object, passing any (already-validated) arguments from
+ // attrs or input tensors.
+ *output = new Dataset(ctx);
}
-};
-REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
- TextLineReaderOp);
-```
-
-The last step is to add the Python wrapper. You can either do this by
-@{$adding_an_op#build_the_op_library$compiling a dynamic library}
-or, if you are building TensorFlow from source, adding to `user_ops.py`.
-For the latter, you will import `tensorflow.python.ops.io_ops` in
-[`tensorflow/python/user_ops/user_ops.py`](https://www.tensorflow.org/code/tensorflow/python/user_ops/user_ops.py)
-and add a descendant of [`io_ops.ReaderBase`](https://www.tensorflow.org/code/tensorflow/python/ops/io_ops.py).
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::MyReader")}));
+ }
+
+ // Record structure: Each record is represented by a scalar string tensor.
+ //
+ // Dataset elements can have a fixed number of components of different
+ // types and shapes; replace the following two methods to customize this
+ // aspect of the dataset.
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+ return *dtypes;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() override { return "MyReaderDatasetOp::Dataset"; }
+
+ protected:
+ // Optional: Implementation of `GraphDef` serialization for this dataset.
+ //
+ // Implement this method if you want to be able to save and restore
+ // instances of this dataset (and any iterators over it).
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ // Construct nodes to represent any of the input tensors from this
+ // object's member variables using `b->AddScalar()` and `b->AddVector()`.
+ std::vector<Node*> input_tensors;
+ TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ // Implementation of the reading logic.
+ //
+ // The example implementation in this file yields the string "MyReader!"
+ // ten times. In general there are three cases:
+ //
+ // 1. If an element is successfully read, store it as one or more tensors
+ // in `*out_tensors`, set `*end_of_sequence = false` and return
+ // `Status::OK()`.
+ // 2. If the end of input is reached, set `*end_of_sequence = true` and
+ // return `Status::OK()`.
+ // 3. If an error occurs, return an error status using one of the helper
+ // functions from "tensorflow/core/lib/core/errors.h".
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ // NOTE: `GetNextInternal()` may be called concurrently, so it is
+ // recommended that you protect the iterator state with a mutex.
+ mutex_lock l(mu_);
+ if (i_ < 10) {
+ // Create a scalar string tensor and add it to the output.
+ Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
+ record_tensor.scalar<string>()() = "MyReader!";
+ out_tensors->emplace_back(std::move(record_tensor));
+ ++i_;
+ *end_of_sequence = false;
+ } else {
+ *end_of_sequence = true;
+ }
+ return Status::OK();
+ }
+
+ protected:
+ // Optional: Implementation of iterator state serialization for this
+ // iterator.
+ //
+ // Implement these two methods if you want to be able to save and restore
+ // instances of this iterator.
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+ return Status::OK();
+ }
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ int64 i_ GUARDED_BY(mu_);
+ };
+ };
+};
-```python
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import common_shapes
-from tensorflow.python.ops import io_ops
+// Register the op definition for MyReaderDataset.
+//
+// Dataset ops always have a single output, of type `variant`, which represents
+// the constructed `Dataset` object.
+//
+// Add any attrs and input tensors that define the dataset here.
+REGISTER_OP("MyReaderDataset")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape);
-class SomeReader(io_ops.ReaderBase):
+// Register the kernel implementation for MyReaderDataset.
+REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU),
+ MyReaderDatasetOp);
- def __init__(self, name=None):
- rr = gen_user_ops.some_reader(name=name)
- super(SomeReader, self).__init__(rr)
+} // namespace
+} // namespace tensorflow
+```
+The last step is to build the C++ code and add a Python wrapper. The easiest way
+to do this is by @{$adding_an_op#build_the_op_library$compiling a dynamic
+library} (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class
+that subclasses @{tf.data.Dataset} to wrap it. An example Python program is
+given here:
-ops.NotDifferentiable("SomeReader")
+```python
+import tensorflow as tf
+
+# Assumes the file is in the current working directory.
+my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")
+
+class MyReaderDataset(tf.data.Dataset):
+
+ def __init__(self):
+ super(MyReaderDataset, self).__init__()
+ # Create any input attrs or tensors as members of this class.
+
+ def _as_variant_tensor(self):
+ # Actually construct the graph node for the dataset op.
+ #
+ # This method will be invoked when you create an iterator on this dataset
+ # or a dataset derived from it.
+ return my_reader_dataset_module.my_reader_dataset()
+
+ # The following properties define the structure of each element: a scalar
+ # `tf.string` tensor. Change these properties to match the `output_dtypes()`
+  # and `output_shapes()` methods of `MyReaderDatasetOp::Dataset` if you modify
+ # the structure of each element.
+ @property
+ def output_types(self):
+ return tf.string
+
+ @property
+ def output_shapes(self):
+ return tf.TensorShape([])
+
+ @property
+ def output_classes(self):
+ return tf.Tensor
+
+if __name__ == "__main__":
+ # Create a MyReaderDataset and print its elements.
+ with tf.Session() as sess:
+ iterator = MyReaderDataset().make_one_shot_iterator()
+ next_element = iterator.get_next()
+ try:
+ while True:
+ print(sess.run(next_element)) # Prints "MyReader!" ten times.
+ except tf.errors.OutOfRangeError:
+ pass
```
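+Because `MyReaderDataset` subclasses @{tf.data.Dataset}, it composes with the
+standard transformations. A brief sketch, reusing the class defined above:
+
+```python
+# Derive new datasets from MyReaderDataset using ordinary tf.data transformations.
+dataset = (MyReaderDataset()
+           .map(lambda record: tf.string_join([record, tf.constant(" (processed)")]))
+           .repeat(2)
+           .batch(5))
+```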
-You can see some examples in
-[`tensorflow/python/ops/io_ops.py`](https://www.tensorflow.org/code/tensorflow/python/ops/io_ops.py).
+You can see some examples of `Dataset` wrapper classes in
+[`tensorflow/python/data/ops/dataset_ops.py`](https://www.tensorflow.org/code/tensorflow/python/data/ops/dataset_ops.py).
## Writing an Op for a record format
Examples of Ops useful for decoding records:
-* @{tf.parse_single_example}
- (and
- @{tf.parse_example})
+* @{tf.parse_single_example} (and @{tf.parse_example})
* @{tf.decode_csv}
* @{tf.decode_raw}
Note that it can be useful to use multiple Ops to decode a particular record
format. For example, you may have an image saved as a string in
[a `tf.train.Example` protocol buffer](https://www.tensorflow.org/code/tensorflow/core/example/example.proto).
Depending on the format of that image, you might take the corresponding output
-from a
-@{tf.parse_single_example}
-op and call @{tf.image.decode_jpeg},
-@{tf.image.decode_png}, or
-@{tf.decode_raw}. It is common to
-take the output of `tf.decode_raw` and use
-@{tf.slice} and
-@{tf.reshape} to extract pieces.
+from a @{tf.parse_single_example} op and call @{tf.image.decode_jpeg},
+@{tf.image.decode_png}, or @{tf.decode_raw}. It is common to take the output
+of `tf.decode_raw` and use @{tf.slice} and @{tf.reshape} to extract pieces.
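+A sketch of such a decoding pipeline (the `TFRecord` file name, the feature key
+`"image_raw"`, and the 28x28 image shape are assumptions made for
+illustration):
+
+```python
+import tensorflow as tf
+
+def parse_record(serialized):
+  # Parse a tf.train.Example record, then decode and reshape the raw image bytes.
+  features = tf.parse_single_example(
+      serialized, features={"image_raw": tf.FixedLenFeature([], tf.string)})
+  image = tf.decode_raw(features["image_raw"], tf.uint8)
+  return tf.reshape(image, [28, 28, 1])
+
+dataset = tf.data.TFRecordDataset("images.tfrecords").map(parse_record)
+```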