[tf.data] Changing signature of `MakeIterator` to enable propagating error status.
authorJiri Simsa <jsimsa@google.com>
Thu, 31 May 2018 20:43:43 +0000 (13:43 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 31 May 2018 20:45:55 +0000 (13:45 -0700)
PiperOrigin-RevId: 198772254

45 files changed:
tensorflow/contrib/data/kernels/csv_dataset_op.cc
tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
tensorflow/contrib/data/kernels/unique_dataset_op.cc
tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc
tensorflow/core/framework/dataset.h
tensorflow/core/kernels/data/batch_dataset_op.cc
tensorflow/core/kernels/data/cache_dataset_ops.cc
tensorflow/core/kernels/data/concatenate_dataset_op.cc
tensorflow/core/kernels/data/dataset_utils.cc
tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
tensorflow/core/kernels/data/filter_dataset_op.cc
tensorflow/core/kernels/data/flat_map_dataset_op.cc
tensorflow/core/kernels/data/generator_dataset_op.cc
tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
tensorflow/core/kernels/data/group_by_window_dataset_op.cc
tensorflow/core/kernels/data/interleave_dataset_op.cc
tensorflow/core/kernels/data/iterator_ops.cc
tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
tensorflow/core/kernels/data/map_dataset_op.cc
tensorflow/core/kernels/data/padded_batch_dataset_op.cc
tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
tensorflow/core/kernels/data/parallel_map_dataset_op.cc
tensorflow/core/kernels/data/prefetch_dataset_op.cc
tensorflow/core/kernels/data/random_dataset_op.cc
tensorflow/core/kernels/data/range_dataset_op.cc
tensorflow/core/kernels/data/reader_dataset_ops.cc
tensorflow/core/kernels/data/repeat_dataset_op.cc
tensorflow/core/kernels/data/scan_dataset_op.cc
tensorflow/core/kernels/data/shuffle_dataset_op.cc
tensorflow/core/kernels/data/skip_dataset_op.cc
tensorflow/core/kernels/data/slide_dataset_op.cc
tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
tensorflow/core/kernels/data/sql_dataset_ops.cc
tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
tensorflow/core/kernels/data/stats_dataset_ops.cc
tensorflow/core/kernels/data/take_dataset_op.cc
tensorflow/core/kernels/data/tensor_dataset_op.cc
tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
tensorflow/core/kernels/data/unbatch_dataset_op.cc
tensorflow/core/kernels/data/window_dataset.cc
tensorflow/core/kernels/data/writer_ops.cc
tensorflow/core/kernels/data/zip_dataset_op.cc

index 76e54a2..b16e662 100644 (file)
@@ -133,7 +133,7 @@ class CSVDatasetOp : public DatasetOpKernel {
           delim_(delim),
           na_value_(std::move(na_value)) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::CSV")}));
index 48d3734..bdff379 100644 (file)
@@ -91,7 +91,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
       }
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new Iterator(
           {this, strings::StrCat(prefix, "::DirectedInterleave")}));
@@ -130,15 +130,21 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
-            selector_input_impl_(params.dataset->selector_input_->MakeIterator(
-                params.prefix + ".selector")),
-            num_active_inputs_(params.dataset->data_inputs_.size()) {
-        data_input_impls_.reserve(params.dataset->data_inputs_.size());
-        for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) {
-          const DatasetBase* data_input = params.dataset->data_inputs_[i];
-          data_input_impls_.push_back(data_input->MakeIterator(
-              strings::StrCat(params.prefix, "[", i, "]")));
+            num_active_inputs_(params.dataset->data_inputs_.size()) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        mutex_lock l(mu_);
+        TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
+            ctx, strings::StrCat(prefix(), ".selector"),
+            &selector_input_impl_));
+        data_input_impls_.resize(dataset()->data_inputs_.size());
+        for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+          const DatasetBase* data_input = dataset()->data_inputs_[i];
+          TF_RETURN_IF_ERROR(data_input->MakeIterator(
+              ctx, strings::StrCat(prefix(), "[", i, "]"),
+              &data_input_impls_[i]));
         }
+        return Status::OK();
       }
 
       Status GetNextInternal(IteratorContext* ctx,
index bb29df6..c3759b6 100644 (file)
@@ -44,7 +44,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
@@ -72,8 +72,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 63e19ae..7cf01f6 100644 (file)
@@ -127,7 +127,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
       threadpool_->Unref();
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
@@ -154,8 +154,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 69fbb0f..652913d 100644 (file)
@@ -56,7 +56,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Unique")}));
@@ -87,8 +87,11 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const typename Iterator::Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index a4cd4a2..7b08cfa 100644 (file)
@@ -64,7 +64,7 @@ class KafkaDatasetOp : public DatasetOpKernel {
           eof_(eof),
           timeout_(timeout) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Kafka")}));
index 8624af9..0f352ea 100644 (file)
@@ -351,6 +351,10 @@ class IteratorBase {
   // in the outputs of this iterator.
   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
 
+  // Performs initialization that needs to happen outside of a constructor to
+  // properly propagate errors.
+  virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
+
   // Saves the state of this iterator.
   virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
     return SaveInternal(writer);
@@ -402,12 +406,13 @@ class DatasetBase : public core::RefCounted {
   // iterator will traverse all elements in this dataset from the
   // start.
   //
-  // Ownership of the created iterator will be transferred to the caller.
-  //
   // The prefix identifies the sequence of iterators leading up to the newly
   // created iterator.
-  virtual std::unique_ptr<IteratorBase> MakeIterator(
-      const string& prefix) const = 0;
+  Status MakeIterator(IteratorContext* ctx, const string& prefix,
+                      std::unique_ptr<IteratorBase>* iterator) const {
+    *iterator = MakeIteratorInternal(prefix);
+    return (*iterator)->Initialize(ctx);
+  }
 
   // Returns a vector of DataType values, representing the respective
   // element types of each tuple component in the outputs of this
@@ -451,6 +456,9 @@ class DatasetBase : public core::RefCounted {
                                     Node** node) const {
     return errors::Unimplemented("AsGraphDefInternal");
   }
+
+  virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
+      const string& prefix) const = 0;
 };
 
 // Base-class for datasets that are built by ops.
index 3618c75..9c0a6b0 100644 (file)
@@ -61,7 +61,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new Iterator(
           Iterator::Params{this, strings::StrCat(prefix, "::Batch")}));
@@ -95,8 +95,11 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 4b4728d..5f7db9e 100644 (file)
@@ -64,7 +64,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
 
     ~FileDataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       if (env_->FileExists(strings::StrCat(filename_, ".index")).ok()) {
         return std::unique_ptr<IteratorBase>(new FileReaderIterator(
@@ -106,12 +106,15 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
       explicit FileWriterIterator(const Params& params)
           : DatasetIterator<FileDataset>(params),
             cur_index_(0),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             writer_(params.dataset->env_, params.dataset->filename_),
             lockfile_(strings::StrCat(params.dataset->filename_, ".lockfile")),
             lockfile_created_(false),
             iteration_completed_(false) {}
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
@@ -268,7 +271,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
 
     ~MemoryDataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       mutex_lock l(mu_);
       if (cache_) {
@@ -305,7 +308,6 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit MemoryWriterIterator(const Params& params)
           : DatasetIterator<MemoryDataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             cache_(new std::vector<std::vector<Tensor>>) {}
 
       ~MemoryWriterIterator() override {
@@ -323,6 +325,10 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
         }
       }
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
index f11abc6..7c9dd12 100644 (file)
@@ -61,7 +61,7 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
       to_concatenate_->Unref();
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Concatenate")}));
@@ -94,10 +94,12 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            i_(0),
-            input_impl_(params.dataset->input_->MakeIterator(
-                strings::StrCat(params.prefix, "[0]"))) {}
+          : DatasetIterator<Dataset>(params), i_(0) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(
+            ctx, strings::StrCat(prefix(), "[0]"), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -114,8 +116,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
             return Status::OK();
           }
           if (++i_ < 2) {
-            input_impl_ = dataset()->to_concatenate_->MakeIterator(
-                strings::StrCat(prefix(), "[1]"));
+            TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+                ctx, strings::StrCat(prefix(), "[1]"), &input_impl_));
           }
         }
         *end_of_sequence = true;
@@ -147,8 +149,8 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
         if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
           return errors::InvalidArgument("i_ must be in range [0, 2].");
         if (i_ == 1) {
-          input_impl_ = dataset()->to_concatenate_->MakeIterator(
-              strings::StrCat(prefix(), "[1]"));
+          TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+              ctx, strings::StrCat(prefix(), "[1]"), &input_impl_));
         } else if (i_ == 2) {
           input_impl_.reset();
         }
index c608f9e..d85ef1c 100644 (file)
@@ -41,9 +41,8 @@ Status MakeIteratorFromInputElement(
       GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
 
   // Create an iterator for the dataset that was returned by `f`.
-  *out_iterator = returned_dataset->MakeIterator(
-      strings::StrCat(prefix, "[", thread_index, "]"));
-  return Status::OK();
+  return returned_dataset->MakeIterator(
+      ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
 }
 
 }  // namespace dataset
index 132808a..28fa77c 100644 (file)
@@ -94,7 +94,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new Iterator(
           {this, strings::StrCat(prefix, "::DenseToSparseBatch")}));
@@ -137,8 +137,12 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset<T>> {
      public:
       explicit Iterator(const typename Iterator::Params& params)
-          : DatasetIterator<Dataset<T>>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset<T>>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return DatasetIterator<Dataset<T>>::dataset()->input_->MakeIterator(
+            ctx, DatasetIterator<Dataset<T>>::prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 186b1e1..5760e55 100644 (file)
@@ -93,7 +93,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
 
     ~FilterDatasetBase() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Filter")}));
@@ -145,8 +145,11 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<FilterDatasetBase> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<FilterDatasetBase>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<FilterDatasetBase>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 77a48a2..e2edda0 100644 (file)
@@ -74,7 +74,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::FlatMap")}));
@@ -125,8 +125,11 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -202,7 +205,8 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
         current_element_iterator_.reset();
         captured_func_inputs_.clear();
         if (!reader->Contains(full_name("exhausted"))) {
-          input_impl_ = dataset()->input_->MakeIterator(prefix());
+          TF_RETURN_IF_ERROR(
+              dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
           {
             int64 temp;
index 3f1e441..d298389 100644 (file)
@@ -99,7 +99,7 @@ class GeneratorDatasetOp : public DatasetOpKernel {
           output_types_(output_types),
           output_shapes_(output_shapes) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Generator")}));
index c8aeaab..7bbadff 100644 (file)
@@ -88,7 +88,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")}));
@@ -183,8 +183,11 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 03f847c..f9cc5d2 100644 (file)
@@ -118,7 +118,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::GroupByWindow")}));
@@ -198,8 +198,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -484,8 +487,8 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
             GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
 
         // Create an iterator for the dataset that was returned by `f`.
-        current_group_iterator_ = returned_dataset->MakeIterator(prefix());
-        return Status::OK();
+        return returned_dataset->MakeIterator(ctx, prefix(),
+                                              &current_group_iterator_);
       }
 
       mutex mu_;
index bce3f28..723648b 100644 (file)
@@ -96,7 +96,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Interleave")}));
@@ -149,10 +149,13 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             current_elements_(params.dataset->cycle_length_),
             args_list_(params.dataset->cycle_length_) {}
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
         block_index_ = 0;
         cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
@@ -294,7 +297,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
       }
 
       mutex mu_;
-      const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
       std::vector<std::unique_ptr<IteratorBase>> current_elements_
           GUARDED_BY(mu_);
       std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
index 87bc8eb..9d9e74a 100644 (file)
@@ -158,7 +158,10 @@ class IteratorResource : public ResourceBase {
         graph_runner.Run(&graph, lib, {}, {output_node}, &outputs));
     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
 
-    TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
+    IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+    std::unique_ptr<IteratorBase> iterator;
+    TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+    TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
     std::shared_ptr<IteratorBase> captured_iterator(iterator_);
 
     if (captured_iterator) {
@@ -657,8 +660,12 @@ class MakeIteratorOp : public OpKernel {
     OP_REQUIRES_OK(
         ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
     core::ScopedUnref unref(iterator_resource);
-    OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(
-                            dataset->MakeIterator("Iterator")));
+
+    IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+    std::unique_ptr<IteratorBase> iterator;
+    OP_REQUIRES_OK(ctx,
+                   dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+    OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
   }
 };
 
@@ -680,9 +687,12 @@ class ToSingleElementOp : public AsyncOpKernel {
       DatasetBase* dataset;
       OP_REQUIRES_OK_ASYNC(
           ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
-      auto iterator = dataset->MakeIterator("SingleElementIterator");
-
       IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+      std::unique_ptr<IteratorBase> iterator;
+      OP_REQUIRES_OK_ASYNC(
+          ctx,
+          dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
+          done);
       std::vector<Tensor> components;
       components.reserve(dataset->output_dtypes().size());
       bool end_of_sequence;
@@ -866,8 +876,10 @@ class OneShotIteratorOp : public AsyncOpKernel {
     // factory function.
     DatasetBase* dataset;
     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
-    TF_RETURN_IF_ERROR(
-        (*iterator)->set_iterator(dataset->MakeIterator("Iterator")));
+    IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+    std::unique_ptr<IteratorBase> iter;
+    TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
+    TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
 
     (*iterator)->Ref();
     return Status::OK();
index f41a810..f55a665 100644 (file)
@@ -125,7 +125,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
@@ -188,7 +188,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             batch_results_((params.dataset->num_parallel_calls_ +
                             params.dataset->batch_size_ - 1) /
                            params.dataset->batch_size_) {
@@ -208,6 +207,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
         }
       }
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
@@ -647,7 +650,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
       int64 num_calls_ GUARDED_BY(mu_) = 0;
       // Counts the total number of calls.
       int64 call_counter_ GUARDED_BY(mu_) = 0;
-      const std::unique_ptr<IteratorBase> input_impl_;
+      std::unique_ptr<IteratorBase> input_impl_;
       // Identifies the next batch to be read by the caller.
       int64 input_batch_ GUARDED_BY(mu_) = 0;
       // Identifies the next batch to create.
index 89360d1..40063c8 100644 (file)
@@ -73,7 +73,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Map")}));
@@ -123,8 +123,11 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -167,7 +170,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
       }
 
      private:
-      const std::unique_ptr<IteratorBase> input_impl_;
+      std::unique_ptr<IteratorBase> input_impl_;
     };
 
     const DatasetBase* const input_;
index e41800a..f60b547 100644 (file)
@@ -119,7 +119,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::PaddedBatch")}));
@@ -186,8 +186,11 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -325,7 +328,8 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
         if (reader->Contains(full_name("exhausted"))) {
           input_impl_.reset();
         } else {
-          input_impl_ = dataset()->input_->MakeIterator(prefix());
+          TF_RETURN_IF_ERROR(
+              dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
         }
         return Status::OK();
index fa33867..8da6b33 100644 (file)
@@ -116,7 +116,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new Iterator(
           {this, strings::StrCat(prefix, "::ParallelInterleave")}));
@@ -236,7 +236,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             workers_(dataset()->num_threads()),
             worker_thread_states_(dataset()->num_threads()) {}
 
@@ -249,6 +248,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
         }
       }
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       // It is implemented so that it matches the deterministic interleave
       // unless getting the next element would block and we are allowed to be
       // sloppy.
index 7e373f2..cf55067 100644 (file)
@@ -85,7 +85,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::ParallelMap")}));
@@ -150,7 +150,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             invocation_results_(params.dataset->num_parallel_calls_) {}
 
       ~Iterator() override {
@@ -169,6 +168,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
         }
       }
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
index 536de81..1409838 100644 (file)
@@ -55,7 +55,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Prefetch")}));
@@ -87,7 +87,6 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             auto_tuner_(params.dataset->buffer_size_) {}
 
       ~Iterator() override {
@@ -106,6 +105,10 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
         }
       }
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
@@ -327,7 +330,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
       // accessing the parent iterator. We keep this separate from `mu_` to
       // allow prefetching to run in parallel with GetNext calls.
       mutex parent_mu_ ACQUIRED_BEFORE(mu_);
-      const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
       condition_variable cond_var_;
       PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
       std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
index 210b9ad..40bd95e 100644 (file)
@@ -54,7 +54,7 @@ class RandomDatasetOp : public DatasetOpKernel {
     Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
         : GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Random")}));
index b57518e..b18263b 100644 (file)
@@ -48,7 +48,7 @@ class RangeDatasetOp : public DatasetOpKernel {
     Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step)
         : GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Range")}));
index 34d7d9f..28d38d4 100644 (file)
@@ -89,7 +89,7 @@ class TextLineDatasetOp : public DatasetOpKernel {
           use_compression_(!compression_type.empty()),
           options_(options) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::TextLine")}));
@@ -323,7 +323,7 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
           footer_bytes_(footer_bytes),
           buffer_size_(buffer_size) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::FixedLengthRecord")}));
@@ -543,7 +543,7 @@ class TFRecordDatasetOp : public DatasetOpKernel {
       }
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::TFRecord")}));
index d370865..fcd9820 100644 (file)
@@ -48,7 +48,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       if (count_ < 0) {
         return std::unique_ptr<IteratorBase>(new ForeverIterator(
@@ -108,9 +108,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
     class FiniteIterator : public DatasetIterator<Dataset> {
      public:
       explicit FiniteIterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            i_(0),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params), i_(0) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -127,7 +129,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
             return Status::OK();
           }
           ++i_;
-          input_impl_ = dataset()->input_->MakeIterator(prefix());
+          TF_RETURN_IF_ERROR(
+              dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         }
         *end_of_sequence = true;
         input_impl_.reset();
@@ -178,7 +181,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
           bool first_call = false;
           if (!input_impl_) {
             first_call = true;
-            input_impl_ = dataset()->input_->MakeIterator(prefix());
+            TF_RETURN_IF_ERROR(
+                dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           }
           TF_RETURN_IF_ERROR(
               input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
@@ -214,7 +218,8 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
         if (reader->Contains(full_name("uninitialized"))) {
           input_impl_.reset();
         } else {
-          input_impl_ = dataset()->input_->MakeIterator(prefix());
+          TF_RETURN_IF_ERROR(
+              dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
         }
         return Status::OK();
index 5dd6ff8..972ed8f 100644 (file)
@@ -90,7 +90,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Scan")}));
@@ -149,9 +149,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
      public:
       explicit Iterator(const Params& params)
           : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             state_(params.dataset->initial_state_) {}
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
@@ -250,7 +253,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
 
      private:
       mutex mu_;
-      const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
       std::vector<Tensor> state_ GUARDED_BY(mu_);
     };
 
index 2f6bf83..dad58ef 100644 (file)
@@ -85,7 +85,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
         bool first_call = false;
         if (!input_impl_ && epoch_ == 0) {
           first_call = true;
-          input_impl_ = dataset()->input_->MakeIterator(prefix());
+          TF_RETURN_IF_ERROR(
+              dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
         }
         while (input_impl_ && num_elements_ < dataset()->buffer_size_) {
           if (ctx->env()->NowMicros() >
@@ -114,7 +115,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
             epoch_++;
             int64 n = slices_.back()->end;
             slices_.emplace_back(new Slice{n, n});
-            input_impl_ = dataset()->input_->MakeIterator(prefix());
+            TF_RETURN_IF_ERROR(
+                dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           }
           if (!end_of_input_sequence) {
             buffer_[slices_.back()->end % dataset()->buffer_size_] =
@@ -211,7 +213,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
 
         // Restore the input iterator if it wasn't already exhausted.
         if (!reader->Contains(full_name("end_of_input_sequence"))) {
-          input_impl_ = dataset()->input_->MakeIterator(prefix());
+          TF_RETURN_IF_ERROR(
+              dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
           TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
         } else {
           input_impl_.reset();
@@ -361,7 +364,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
                              ", ", seed2_, ")::ReshufflingDataset");
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       int64 iterator_seed;
       int64 iterator_seed2;
@@ -399,7 +402,7 @@ class ShuffleDatasetOp : public ShuffleDatasetOpBase {
                              ", ", seed2_, ")::FixedSeedDataset");
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
           {this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
@@ -482,7 +485,7 @@ class ShuffleAndRepeatDatasetOp : public ShuffleDatasetOpBase {
                              seed_, ", ", seed2_, ", ", count_, ")::Dataset");
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
           {this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
index d636c37..0177839 100644 (file)
@@ -47,14 +47,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       if (count_ < 0) {
         return std::unique_ptr<IteratorBase>(
             new EmptyIterator({this, strings::StrCat(prefix, "::EmptySkip")}));
-      } else if (count_ == 0) {
-        // Pass through.
-        return input_->MakeIterator(prefix);
       } else {
         return std::unique_ptr<IteratorBase>(new FiniteIterator(
             {this, strings::StrCat(prefix, "::FiniteSkip")}));
@@ -108,9 +105,11 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
     class FiniteIterator : public DatasetIterator<Dataset> {
      public:
       explicit FiniteIterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            i_(0),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params), i_(0) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 78c8363..e4b2820 100644 (file)
@@ -33,10 +33,9 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
                    DatasetBase** output) override {
     int64 window_size = 0;
     int64 stride = 1;
-    OP_REQUIRES_OK(ctx,
-                   ParseScalarArgument<int64>(ctx, "window_size", &window_size));
-    OP_REQUIRES_OK(ctx,
-                   ParseScalarArgument<int64>(ctx, "stride", &stride));
+    OP_REQUIRES_OK(
+        ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+    OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
     OP_REQUIRES(
         ctx, window_size > 0,
         errors::InvalidArgument("Window size must be greater than zero."));
@@ -50,8 +49,12 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
  private:
   class Dataset : public GraphDatasetBase {
    public:
-    Dataset(OpKernelContext* ctx, int64 window_size, int64 stride, const DatasetBase* input)
-        : GraphDatasetBase(ctx), window_size_(window_size), stride_(stride), input_(input) {
+    Dataset(OpKernelContext* ctx, int64 window_size, int64 stride,
+            const DatasetBase* input)
+        : GraphDatasetBase(ctx),
+          window_size_(window_size),
+          stride_(stride),
+          input_(input) {
       input_->Ref();
 
       const auto& input_shapes = input_->output_shapes();
@@ -64,7 +67,7 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new Iterator(
           Iterator::Params{this, strings::StrCat(prefix, "::Slide")}));
@@ -79,7 +82,8 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
     }
 
     string DebugString() override {
-      return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_, ")::Dataset");
+      return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_,
+                             ")::Dataset");
     }
 
    protected:
@@ -101,8 +105,11 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index fcf17ad..4cc638b 100644 (file)
@@ -39,7 +39,7 @@ class Dataset : public GraphDatasetBase {
                  {-1},
                  {sparse_tensor.dims() - 1}}) {}
 
-  std::unique_ptr<IteratorBase> MakeIterator(
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
     return std::unique_ptr<IteratorBase>(
         new Iterator({this, strings::StrCat(prefix, "::SparseTensorSlice")}));
index 634b3c2..4742ed3 100644 (file)
@@ -88,7 +88,7 @@ class SqlDatasetOp : public DatasetOpKernel {
           output_types_(output_types),
           output_shapes_(output_shapes) {}
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Sql")}));
index eb96b8a..fd490c7 100644 (file)
@@ -53,7 +53,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
       stats_aggregator_resource_->Unref();
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new Iterator(
           {this, strings::StrCat(prefix, "::SetStatsAggregator")}));
@@ -82,8 +82,11 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 633cd85..8dc7618 100644 (file)
@@ -56,7 +56,7 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::LatencyStats")}));
@@ -86,8 +86,11 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -150,7 +153,7 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(new Iterator(
           {this, strings::StrCat(prefix, "::BytesProducedStats")}));
@@ -182,8 +185,11 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
index 3bea46a..209207d 100644 (file)
@@ -47,12 +47,9 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
 
     ~Dataset() override { input_->Unref(); }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
-      if (count_ < 0) {
-        // Pass through
-        return input_->MakeIterator(prefix);
-      } else if (count_ == 0) {
+      if (count_ == 0) {
         return std::unique_ptr<IteratorBase>(
             new EmptyIterator({this, strings::StrCat(prefix, "::EmptyTake")}));
       } else {
@@ -109,9 +106,11 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
     class FiniteIterator : public DatasetIterator<Dataset> {
      public:
       explicit FiniteIterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            i_(0),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+          : DatasetIterator<Dataset>(params), i_(0) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
 
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
@@ -121,7 +120,7 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
           *end_of_sequence = true;
           return Status::OK();
         }
-        while (i_ < dataset()->count_) {
+        while (dataset()->count_ < 0 || i_ < dataset()->count_) {
           TF_RETURN_IF_ERROR(
               input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
           if (!*end_of_sequence) {
index 8c8994b..8f4586b 100644 (file)
@@ -53,7 +53,7 @@ class TensorDatasetOp : public DatasetOpKernel {
       }
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::FromTensor")}));
index e271a42..e9f486d 100644 (file)
@@ -81,7 +81,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
 
   ~PrependFromQueueAndPaddedBatchDataset() override { input_->Unref(); }
 
-  std::unique_ptr<IteratorBase> MakeIterator(
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
     return std::unique_ptr<IteratorBase>(new Iterator(
         {this, strings::StrCat(prefix, "::PrependFromQueueAndPaddedBatch")}));
@@ -152,15 +152,19 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
       : public DatasetIterator<PrependFromQueueAndPaddedBatchDataset> {
    public:
     explicit Iterator(const Params& params)
-        : DatasetIterator<PrependFromQueueAndPaddedBatchDataset>(params),
-          queue_(new TensorQueue(/*input_impl*/
-                                 params.dataset->input_->MakeIterator(
-                                     params.prefix),
-                                 params.dataset->dtypes_,
-                                 params.dataset->shapes_)) {}
+        : DatasetIterator<PrependFromQueueAndPaddedBatchDataset>(params) {}
 
     ~Iterator() override { queue_->Unref(); }
 
+    Status Initialize(IteratorContext* ctx) override {
+      std::unique_ptr<IteratorBase> iterator;
+      TF_RETURN_IF_ERROR(
+          dataset()->input_->MakeIterator(ctx, prefix(), &iterator));
+      queue_ = new TensorQueue(std::move(iterator), dataset()->dtypes_,
+                               dataset()->shapes_);
+      return Status::OK();
+    }
+
     Status GetNextInternal(IteratorContext* ctx,
                            std::vector<Tensor>* out_tensors,
                            bool* end_of_sequence) override {
@@ -372,7 +376,8 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
         if (reader->Contains(iter->full_name("input_exhausted"))) {
           input_impl_.reset();
         } else {
-          input_impl_ = iter->dataset_input()->MakeIterator(iter->prefix());
+          TF_RETURN_IF_ERROR(iter->dataset_input()->MakeIterator(
+              ctx, iter->prefix(), &input_impl_));
           TF_RETURN_IF_ERROR(iter->RestoreParent(ctx, reader, input_impl_));
         }
         entries_.clear();
@@ -469,7 +474,7 @@ class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
     };
 
    private:
-    TensorQueue* const queue_;
+    TensorQueue* queue_;
   };
 
  private:
index 95708cc..fd87803 100644 (file)
@@ -70,7 +70,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
       }
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::TensorSlice")}));
index 2b383e5..28f2350 100644 (file)
@@ -49,7 +49,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
       }
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Unbatch")}));
@@ -80,9 +80,12 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
           : DatasetIterator<Dataset>(params),
             current_index_(0),
             current_batch_size_(0),
-            input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
             shapes_(params.dataset->output_shapes().size()) {}
 
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
       Status GetNextInternal(IteratorContext* ctx,
                              std::vector<Tensor>* out_tensors,
                              bool* end_of_sequence) override {
index e24bdea..e7470f8 100644 (file)
@@ -26,7 +26,7 @@ class WindowDataset : public DatasetBase {
         output_types_(std::move(output_types)),
         output_shapes_(std::move(output_shapes)) {}
 
-  std::unique_ptr<IteratorBase> MakeIterator(
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
     return std::unique_ptr<IteratorBase>(
         new Iterator({this, strings::StrCat(prefix, "::Window")}));
index 656fee1..80d9a5b 100644 (file)
@@ -70,9 +70,13 @@ class ToTFRecordOp : public AsyncOpKernel {
       DatasetBase* dataset;
       OP_REQUIRES_OK_ASYNC(
           ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
-      auto iterator = dataset->MakeIterator("ToTFRecordOpIterator");
-
       IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+      std::unique_ptr<IteratorBase> iterator;
+      OP_REQUIRES_OK_ASYNC(
+          ctx,
+          dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator),
+          done);
+
       std::vector<Tensor> components;
       components.reserve(dataset->output_dtypes().size());
       bool end_of_sequence;
index 0f79eac..d5343cd 100644 (file)
@@ -60,7 +60,7 @@ class ZipDatasetOp : public DatasetOpKernel {
       }
     }
 
-    std::unique_ptr<IteratorBase> MakeIterator(
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
         const string& prefix) const override {
       return std::unique_ptr<IteratorBase>(
           new Iterator({this, strings::StrCat(prefix, "::Zip")}));
@@ -95,13 +95,16 @@ class ZipDatasetOp : public DatasetOpKernel {
     class Iterator : public DatasetIterator<Dataset> {
      public:
       explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params) {
-        input_impls_.reserve(params.dataset->inputs_.size());
-        size_t idx = 0;
-        for (const auto& input : params.dataset->inputs_) {
-          input_impls_.emplace_back(input->MakeIterator(
-              strings::StrCat(params.prefix, "[", idx++, "]")));
+          : DatasetIterator<Dataset>(params) {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        mutex_lock l(mu_);
+        input_impls_.resize(dataset()->inputs_.size());
+        for (size_t i = 0; i < input_impls_.size(); ++i) {
+          TF_RETURN_IF_ERROR(dataset()->inputs_[i]->MakeIterator(
+              ctx, strings::StrCat(prefix(), "[", i, "]"), &input_impls_[i]));
         }
+        return Status::OK();
       }
 
       Status GetNextInternal(IteratorContext* ctx,