[tf.data] Replace the Reader-oriented documentation for supporting new datasets with...
authorDerek Murray <mrry@google.com>
Fri, 6 Apr 2018 22:38:05 +0000 (15:38 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 22:40:36 +0000 (15:40 -0700)
PiperOrigin-RevId: 191950831

tensorflow/docs_src/extend/new_data_formats.md

index 10e717c..2c33a6b 100644 (file)
@@ -1,4 +1,4 @@
-# Custom Data Readers
+# Reading custom file and record formats
 
 PREREQUISITES:
 
@@ -9,187 +9,273 @@ PREREQUISITES:
 
 We divide the task of supporting a file format into two pieces:
 
-*   File formats: We use a *Reader* Op to read a *record* (which can be any
-    string) from a file.
-*   Record formats: We use decoder or parsing Ops to turn a string record
+*   File formats: We use a reader `tf.data.Dataset` to read raw *records* (which
+    are typically represented by scalar string tensors, but can have more
+    structure) from a file.
+*   Record formats: We use decoder or parsing ops to turn a string record
     into tensors usable by TensorFlow.
 
 For example, to read a
 [CSV file](https://en.wikipedia.org/wiki/Comma-separated_values), we use
-@{tf.TextLineReader$a Reader for text files}
-followed by
-@{tf.decode_csv$an Op that parses CSV data from a line of text}.
+@{tf.data.TextLineDataset$a dataset for reading text files line-by-line}
+and then @{tf.data.Dataset.map$map} an
+@{tf.decode_csv$op} that parses CSV data from each line of text in the dataset.
 
 [TOC]
 
-## Writing a Reader for a file format
+## Writing a `Dataset` for a file format
 
-A `Reader` is something that reads records from a file.  There are some examples
-of Reader Ops already built into TensorFlow:
+A @{tf.data.Dataset} represents a sequence of *elements*, which can be the
+individual records in a file. There are several examples of "reader" datasets
+that are already built into TensorFlow:
 
-*   @{tf.TFRecordReader}
-    ([source in `kernels/tf_record_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/tf_record_reader_op.cc))
-*   @{tf.FixedLengthRecordReader}
-    ([source in `kernels/fixed_length_record_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/fixed_length_record_reader_op.cc))
-*   @{tf.TextLineReader}
-    ([source in `kernels/text_line_reader_op.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/text_line_reader_op.cc))
+*   @{tf.data.TFRecordDataset}
+    ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
+*   @{tf.data.FixedLengthRecordDataset}
+    ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
+*   @{tf.data.TextLineDataset}
+    ([source in `kernels/data/reader_dataset_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/kernels/data/reader_dataset_ops.cc))
 
-You can see these all expose the same interface, the only differences
-are in their constructors.  The most important method is `read`.
-It takes a queue argument, which is where it gets filenames to
-read from whenever it needs one (e.g. when the `read` op first runs, or
-the previous `read` reads the last record from a file).  It produces
-two scalar tensors: a string key and a string value.
+Each of these implementations comprises three related classes:
 
-To create a new reader called `SomeReader`, you will need to:
+* A `tensorflow::DatasetOpKernel` subclass (e.g. `TextLineDatasetOp`), which
+  tells TensorFlow how to construct a dataset object from the inputs to and
+  attrs of an op, in its `MakeDataset()` method.
 
-1.  In C++, define a subclass of
-    [`tensorflow::ReaderBase`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_base.h)
-    called `SomeReader`.
-2.  In C++, register a new reader op and kernel with the name `"SomeReader"`.
-3.  In Python, define a subclass of @{tf.ReaderBase} called `SomeReader`.
+* A `tensorflow::GraphDatasetBase` subclass (e.g. `TextLineDatasetOp::Dataset`),
+  which represents the *immutable* definition of the dataset itself, and tells
+  TensorFlow how to construct an iterator object over that dataset, in its
+  `MakeIterator()` method.
 
-You can put all the C++ code in a file in
-`tensorflow/core/user_ops/some_reader_op.cc`. The code to read a file will live
-in a descendant of the C++ `ReaderBase` class, which is defined in
-[`tensorflow/core/kernels/reader_base.h`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_base.h).
-You will need to implement the following methods:
+* A `tensorflow::DatasetIterator<Dataset>` subclass (e.g.
+  `TextLineDatasetOp::Dataset::Iterator`), which represents the *mutable* state
+  of an iterator over a particular dataset, and tells TensorFlow how to get the
+  next element from the iterator, in its `GetNextInternal()` method.
 
-*   `OnWorkStartedLocked`: open the next file
-*   `ReadLocked`: read a record or report EOF/error
-*   `OnWorkFinishedLocked`: close the current file, and
-*   `ResetLocked`: get a clean slate after, e.g., an error
+The most important method is the `GetNextInternal()` method, since it defines
+how to actually read records from the file and represent them as one or more
+`Tensor` objects.
 
-These methods have names ending in "Locked" since `ReaderBase` makes sure
-to acquire a mutex before calling any of these methods, so you generally don't
-have to worry about thread safety (though that only protects the members of the
-class, not global state).
+To create a new reader dataset called (for example) `MyReaderDataset`, you will
+need to:
 
-For `OnWorkStartedLocked`, the name of the file to open is the value returned by
-the `current_work()` method.  `ReadLocked` has this signature:
+1. In C++, define subclasses of `tensorflow::DatasetOpKernel`,
+   `tensorflow::GraphDatasetBase`, and `tensorflow::DatasetIterator<Dataset>`
+   that implement the reading logic.
+2. In C++, register a new reader op and kernel with the name
+   `"MyReaderDataset"`.
+3. In Python, define a subclass of @{tf.data.Dataset} called `MyReaderDataset`.
 
-```c++
-Status ReadLocked(string* key, string* value, bool* produced, bool* at_end)
-```
-
-If `ReadLocked` successfully reads a record from the file, it should fill in:
-
-*   `*key`: with an identifier for the record, that a human could use to find
-    this record again.  You can include the filename from `current_work()`,
-    and append a record number or whatever.
-*   `*value`: with the contents of the record.
-*   `*produced`: set to `true`.
-
-If you hit the end of a file (EOF), set `*at_end` to `true`.  In either case,
-return `Status::OK()`.  If there is an error, simply return it using one of the
-helper functions from
-[`tensorflow/core/lib/core/errors.h`](https://www.tensorflow.org/code/tensorflow/core/lib/core/errors.h)
-without modifying any arguments.
-
-Next you will create the actual Reader op.  It will help if you are familiar
-with @{$adding_an_op$the adding an op how-to}.  The main steps
-are:
-
-*   Registering the op.
-*   Define and register an `OpKernel`.
-
-To register the op, you will use a `REGISTER_OP` call defined in
-[`tensorflow/core/framework/op.h`](https://www.tensorflow.org/code/tensorflow/core/framework/op.h).
-Reader ops never take any input and always have a single output with type
-`resource`.  They should have string `container` and `shared_name` attrs.
-You may optionally define additional attrs
-for configuration or include documentation in a `Doc`.  For examples, see
-[`tensorflow/core/ops/io_ops.cc`](https://www.tensorflow.org/code/tensorflow/core/ops/io_ops.cc),
-e.g.:
+You can put all the C++ code in a single file, such as
+`my_reader_dataset_op.cc`. It will help if you are
+familiar with @{$adding_an_op$the adding an op how-to}. The following skeleton
+can be used as a starting point for your implementation:
 
 ```c++
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
 
-REGISTER_OP("TextLineReader")
-    .Output("reader_handle: resource")
-    .Attr("skip_header_lines: int = 0")
-    .Attr("container: string = ''")
-    .Attr("shared_name: string = ''")
-    .SetIsStateful()
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-A Reader that outputs the lines of a file delimited by '\n'.
-)doc");
-```
-
-To define an `OpKernel`, Readers can use the shortcut of descending from
-`ReaderOpKernel`, defined in
-[`tensorflow/core/framework/reader_op_kernel.h`](https://www.tensorflow.org/code/tensorflow/core/framework/reader_op_kernel.h),
-and implement a constructor that calls `SetReaderFactory`.  After defining
-your class, you will need to register it using `REGISTER_KERNEL_BUILDER(...)`.
-An example with no attrs:
+namespace tensorflow {
+namespace {
 
-```c++
-#include "tensorflow/core/framework/reader_op_kernel.h"
-
-class TFRecordReaderOp : public ReaderOpKernel {
+class MyReaderDatasetOp : public DatasetOpKernel {
  public:
-  explicit TFRecordReaderOp(OpKernelConstruction* context)
-      : ReaderOpKernel(context) {
-    Env* env = context->env();
-    SetReaderFactory([this, env]() { return new TFRecordReader(name(), env); });
-  }
-};
 
-REGISTER_KERNEL_BUILDER(Name("TFRecordReader").Device(DEVICE_CPU),
-                        TFRecordReaderOp);
-```
+  MyReaderDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+    // Parse and validate any attrs that define the dataset using
+    // `ctx->GetAttr()`, and store them in member variables.
+  }
 
-An example with attrs:
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    // Parse and validate any input tensors 0that define the dataset using
+    // `ctx->input()` or the utility function
+    // `ParseScalarArgument<T>(ctx, &arg)`.
 
-```c++
-#include "tensorflow/core/framework/reader_op_kernel.h"
-
-class TextLineReaderOp : public ReaderOpKernel {
- public:
-  explicit TextLineReaderOp(OpKernelConstruction* context)
-      : ReaderOpKernel(context) {
-    int skip_header_lines = -1;
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("skip_header_lines", &skip_header_lines));
-    OP_REQUIRES(context, skip_header_lines >= 0,
-                errors::InvalidArgument("skip_header_lines must be >= 0 not ",
-                                        skip_header_lines));
-    Env* env = context->env();
-    SetReaderFactory([this, skip_header_lines, env]() {
-      return new TextLineReader(name(), skip_header_lines, env);
-    });
+    // Create the dataset object, passing any (already-validated) arguments from
+    // attrs or input tensors.
+    *output = new Dataset(ctx);
   }
-};
 
-REGISTER_KERNEL_BUILDER(Name("TextLineReader").Device(DEVICE_CPU),
-                        TextLineReaderOp);
-```
-
-The last step is to add the Python wrapper.  You can either do this by
-@{$adding_an_op#build_the_op_library$compiling a dynamic library}
-or, if you are building TensorFlow from source, adding to `user_ops.py`.
-For the latter, you will import `tensorflow.python.ops.io_ops` in
-[`tensorflow/python/user_ops/user_ops.py`](https://www.tensorflow.org/code/tensorflow/python/user_ops/user_ops.py)
-and add a descendant of [`io_ops.ReaderBase`](https://www.tensorflow.org/code/tensorflow/python/ops/io_ops.py).
+ private:
+  class Dataset : public GraphDatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx) : GraphDatasetBase(ctx) {}
+
+    std::unique_ptr<IteratorBase> MakeIterator(
+        const string& prefix) const override {
+      return std::unique_ptr<IteratorBase>(
+          new Iterator({this, strings::StrCat(prefix, "::MyReader")}));
+    }
+
+    // Record structure: Each record is represented by a scalar string tensor.
+    //
+    // Dataset elements can have a fixed number of components of different
+    // types and shapes; replace the following two methods to customize this
+    // aspect of the dataset.
+    const DataTypeVector& output_dtypes() const override {
+      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
+      return *dtypes;
+    }
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      static std::vector<PartialTensorShape>* shapes =
+          new std::vector<PartialTensorShape>({{}});
+      return *shapes;
+    }
+
+    string DebugString() override { return "MyReaderDatasetOp::Dataset"; }
+
+   protected:
+    // Optional: Implementation of `GraphDef` serialization for this dataset.
+    //
+    // Implement this method if you want to be able to save and restore
+    // instances of this dataset (and any iterators over it).
+    Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      // Construct nodes to represent any of the input tensors from this
+      // object's member variables using `b->AddScalar()` and `b->AddVector()`.
+      std::vector<Node*> input_tensors;
+      TF_RETURN_IF_ERROR(b->AddDataset(this, input_tensors, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params), i_(0) {}
+
+      // Implementation of the reading logic.
+      //
+      // The example implementation in this file yields the string "MyReader!"
+      // ten times. In general there are three cases:
+      //
+      // 1. If an element is successfully read, store it as one or more tensors
+      //    in `*out_tensors`, set `*end_of_sequence = false` and return
+      //    `Status::OK()`.
+      // 2. If the end of input is reached, set `*end_of_sequence = true` and
+      //    return `Status::OK()`.
+      // 3. If an error occurs, return an error status using one of the helper
+      //    functions from "tensorflow/core/lib/core/errors.h".
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        // NOTE: `GetNextInternal()` may be called concurrently, so it is
+        // recommended that you protect the iterator state with a mutex.
+        mutex_lock l(mu_);
+        if (i_ < 10) {
+          // Create a scalar string tensor and add it to the output.
+          Tensor record_tensor(ctx->allocator({}), DT_STRING, {});
+          record_tensor.scalar<string>()() = "MyReader!";
+          out_tensors->emplace_back(std::move(record_tensor));
+          ++i_;
+          *end_of_sequence = false;
+        } else {
+          *end_of_sequence = true;
+        }
+        return Status::OK();
+      }
+
+     protected:
+      // Optional: Implementation of iterator state serialization for this
+      // iterator.
+      //
+      // Implement these two methods if you want to be able to save and restore
+      // instances of this iterator.
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("i"), i_));
+        return Status::OK();
+      }
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("i"), &i_));
+        return Status::OK();
+      }
+
+     private:
+      mutex mu_;
+      int64 i_ GUARDED_BY(mu_);
+    };
+  };
+};
 
-```python
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import common_shapes
-from tensorflow.python.ops import io_ops
+// Register the op definition for MyReaderDataset.
+//
+// Dataset ops always have a single output, of type `variant`, which represents
+// the constructed `Dataset` object.
+//
+// Add any attrs and input tensors that define the dataset here.
+REGISTER_OP("MyReaderDataset")
+    .Output("handle: variant")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape);
 
-class SomeReader(io_ops.ReaderBase):
+// Register the kernel implementation for MyReaderDataset.
+REGISTER_KERNEL_BUILDER(Name("MyReaderDataset").Device(DEVICE_CPU),
+                        MyReaderDatasetOp);
 
-    def __init__(self, name=None):
-        rr = gen_user_ops.some_reader(name=name)
-        super(SomeReader, self).__init__(rr)
+}  // namespace
+}  // namespace tensorflow
+```
 
+The last step is to build the C++ code and add a Python wrapper. The easiest way
+to do this is by @{$adding_an_op#build_the_op_library$compiling a dynamic
+library} (e.g. called `"my_reader_dataset_op.so"`), and adding a Python class
+that subclasses @{tf.data.Dataset} to wrap it. An example Python program is
+given here:
 
-ops.NotDifferentiable("SomeReader")
+```python
+import tensorflow as tf
+
+# Assumes the file is in the current working directory.
+my_reader_dataset_module = tf.load_op_library("./my_reader_dataset_op.so")
+
+class MyReaderDataset(tf.data.Dataset):
+
+  def __init__(self):
+    super(MyReaderDataset, self).__init__()
+    # Create any input attrs or tensors as members of this class.
+
+  def _as_variant_tensor(self):
+    # Actually construct the graph node for the dataset op.
+    #
+    # This method will be invoked when you create an iterator on this dataset
+    # or a dataset derived from it.
+    return my_reader_dataset_module.my_reader_dataset()
+
+  # The following properties define the structure of each element: a scalar
+  # `tf.string` tensor. Change these properties to match the `output_dtypes()`
+  # and `output_shapes()` methods of `MyReaderDataset::Dataset` if you modify
+  # the structure of each element.
+  @property
+  def output_types(self):
+    return tf.string
+
+  @property
+  def output_shapes(self):
+    return tf.TensorShape([])
+
+  @property
+  def output_classes(self):
+    return tf.Tensor
+
+if __name__ == "__main__":
+  # Create a MyReaderDataset and print its elements.
+  with tf.Session() as sess:
+    iterator = MyReaderDataset().make_one_shot_iterator()
+    next_element = iterator.get_next()
+    try:
+      while True:
+        print(sess.run(next_element))  # Prints "MyReader!" ten times.
+    except tf.errors.OutOfRangeError:
+      pass
 ```
 
-You can see some examples in
-[`tensorflow/python/ops/io_ops.py`](https://www.tensorflow.org/code/tensorflow/python/ops/io_ops.py).
+You can see some examples of `Dataset` wrapper classes in
+[`tensorflow/python/data/ops/dataset_ops.py`](https://www.tensorflow.org/code/tensorflow/python/data/ops/dataset_ops.py).
 
 ## Writing an Op for a record format
 
@@ -201,9 +287,7 @@ track down where the bad data came from.
 
 Examples of Ops useful for decoding records:
 
-*   @{tf.parse_single_example}
-    (and
-    @{tf.parse_example})
+*   @{tf.parse_single_example} (and @{tf.parse_example})
 *   @{tf.decode_csv}
 *   @{tf.decode_raw}
 
@@ -211,11 +295,6 @@ Note that it can be useful to use multiple Ops to decode a particular record
 format.  For example, you may have an image saved as a string in
 [a `tf.train.Example` protocol buffer](https://www.tensorflow.org/code/tensorflow/core/example/example.proto).
 Depending on the format of that image, you might take the corresponding output
-from a
-@{tf.parse_single_example}
-op and call @{tf.image.decode_jpeg},
-@{tf.image.decode_png}, or
-@{tf.decode_raw}.  It is common to
-take the output of `tf.decode_raw` and use
-@{tf.slice} and
-@{tf.reshape} to extract pieces.
+from a @{tf.parse_single_example} op and call @{tf.image.decode_jpeg},
+@{tf.image.decode_png}, or @{tf.decode_raw}.  It is common to take the output
+of `tf.decode_raw` and use @{tf.slice} and @{tf.reshape} to extract pieces.