delim_(delim),
na_value_(std::move(na_value)) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::CSV")}));
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::DirectedInterleave")}));
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- selector_input_impl_(params.dataset->selector_input_->MakeIterator(
- params.prefix + ".selector")),
- num_active_inputs_(params.dataset->data_inputs_.size()) {
- data_input_impls_.reserve(params.dataset->data_inputs_.size());
- for (size_t i = 0; i < params.dataset->data_inputs_.size(); ++i) {
- const DatasetBase* data_input = params.dataset->data_inputs_[i];
- data_input_impls_.push_back(data_input->MakeIterator(
- strings::StrCat(params.prefix, "[", i, "]")));
+ num_active_inputs_(params.dataset->data_inputs_.size()) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
+ ctx, strings::StrCat(prefix(), ".selector"),
+ &selector_input_impl_));
+ data_input_impls_.resize(dataset()->data_inputs_.size());
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const DatasetBase* data_input = dataset()->data_inputs_[i];
+ TF_RETURN_IF_ERROR(data_input->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[", i, "]"),
+ &data_input_impls_[i]));
}
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
threadpool_->Unref();
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Unique")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const typename Iterator::Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
eof_(eof),
timeout_(timeout) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Kafka")}));
// in the outputs of this iterator.
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+ // Performs initialization that needs to happen outside of a constructor to
+ // properly propagate errors.
+ virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
+
// Saves the state of this iterator.
virtual Status Save(OpKernelContext* ctx, IteratorStateWriter* writer) {
return SaveInternal(writer);
// iterator will traverse all elements in this dataset from the
// start.
//
- // Ownership of the created iterator will be transferred to the caller.
- //
// The prefix identifies the sequence of iterators leading up to the newly
// created iterator.
- virtual std::unique_ptr<IteratorBase> MakeIterator(
- const string& prefix) const = 0;
+ Status MakeIterator(IteratorContext* ctx, const string& prefix,
+ std::unique_ptr<IteratorBase>* iterator) const {
+ *iterator = MakeIteratorInternal(prefix);
+ return (*iterator)->Initialize(ctx);
+ }
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
Node** node) const {
return errors::Unimplemented("AsGraphDefInternal");
}
+
+ virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const = 0;
};
// Base-class for datasets that are built by ops.
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
Iterator::Params{this, strings::StrCat(prefix, "::Batch")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~FileDataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (env_->FileExists(strings::StrCat(filename_, ".index")).ok()) {
return std::unique_ptr<IteratorBase>(new FileReaderIterator(
explicit FileWriterIterator(const Params& params)
: DatasetIterator<FileDataset>(params),
cur_index_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
writer_(params.dataset->env_, params.dataset->filename_),
lockfile_(strings::StrCat(params.dataset->filename_, ".lockfile")),
lockfile_created_(false),
iteration_completed_(false) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
~MemoryDataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
mutex_lock l(mu_);
if (cache_) {
public:
explicit MemoryWriterIterator(const Params& params)
: DatasetIterator<MemoryDataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
cache_(new std::vector<std::vector<Tensor>>) {}
~MemoryWriterIterator() override {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
to_concatenate_->Unref();
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Concatenate")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(
- strings::StrCat(params.prefix, "[0]"))) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[0]"), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
return Status::OK();
}
if (++i_ < 2) {
- input_impl_ = dataset()->to_concatenate_->MakeIterator(
- strings::StrCat(prefix(), "[1]"));
+ TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[1]"), &input_impl_));
}
}
*end_of_sequence = true;
if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2))
return errors::InvalidArgument("i_ must be in range [0, 2].");
if (i_ == 1) {
- input_impl_ = dataset()->to_concatenate_->MakeIterator(
- strings::StrCat(prefix(), "[1]"));
+ TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[1]"), &input_impl_));
} else if (i_ == 2) {
input_impl_.reset();
}
GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
// Create an iterator for the dataset that was returned by `f`.
- *out_iterator = returned_dataset->MakeIterator(
- strings::StrCat(prefix, "[", thread_index, "]"));
- return Status::OK();
+ return returned_dataset->MakeIterator(
+ ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
} // namespace dataset
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::DenseToSparseBatch")}));
class Iterator : public DatasetIterator<Dataset<T>> {
public:
explicit Iterator(const typename Iterator::Params& params)
- : DatasetIterator<Dataset<T>>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset<T>>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return DatasetIterator<Dataset<T>>::dataset()->input_->MakeIterator(
+ ctx, DatasetIterator<Dataset<T>>::prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~FilterDatasetBase() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Filter")}));
class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<FilterDatasetBase>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::FlatMap")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
current_element_iterator_.reset();
captured_func_inputs_.clear();
if (!reader->Contains(full_name("exhausted"))) {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
{
int64 temp;
output_types_(output_types),
output_shapes_(output_shapes) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Generator")}));
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::GroupByReducer")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::GroupByWindow")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
// Create an iterator for the dataset that was returned by `f`.
- current_group_iterator_ = returned_dataset->MakeIterator(prefix());
- return Status::OK();
+ return returned_dataset->MakeIterator(ctx, prefix(),
+ &current_group_iterator_);
}
mutex mu_;
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Interleave")}));
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
current_elements_(params.dataset->cycle_length_),
args_list_(params.dataset->cycle_length_) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
block_index_ = 0;
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
mutex mu_;
- const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<IteratorBase>> current_elements_
GUARDED_BY(mu_);
std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
graph_runner.Run(&graph, lib, {}, {output_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
- TF_RETURN_IF_ERROR(set_iterator(dataset->MakeIterator("Iterator")));
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
OP_REQUIRES_OK(
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref(iterator_resource);
- OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(
- dataset->MakeIterator("Iterator")));
+
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK(ctx,
+ dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
};
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- auto iterator = dataset->MakeIterator("SingleElementIterator");
-
IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
+ done);
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
// factory function.
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
- TF_RETURN_IF_ERROR(
- (*iterator)->set_iterator(dataset->MakeIterator("Iterator")));
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iter;
+ TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
+ TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();
return Status::OK();
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
batch_results_((params.dataset->num_parallel_calls_ +
params.dataset->batch_size_ - 1) /
params.dataset->batch_size_) {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
int64 call_counter_ GUARDED_BY(mu_) = 0;
- const std::unique_ptr<IteratorBase> input_impl_;
+ std::unique_ptr<IteratorBase> input_impl_;
// Identifies the next batch to be read by the caller.
int64 input_batch_ GUARDED_BY(mu_) = 0;
// Identifies the next batch to create.
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Map")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
}
private:
- const std::unique_ptr<IteratorBase> input_impl_;
+ std::unique_ptr<IteratorBase> input_impl_;
};
const DatasetBase* const input_;
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::PaddedBatch")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
if (reader->Contains(full_name("exhausted"))) {
input_impl_.reset();
} else {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK();
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::ParallelInterleave")}));
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
workers_(dataset()->num_threads()),
worker_thread_states_(dataset()->num_threads()) {}
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
// It is implemented so that it matches the deterministic interleave
// unless getting the next element would block and we are allowed to be
// sloppy.
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::ParallelMap")}));
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
invocation_results_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Prefetch")}));
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
auto_tuner_(params.dataset->buffer_size_) {}
~Iterator() override {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
// accessing the parent iterator. We keep this separate from `mu_` to
// allow prefetching to run in parallel with GetNext calls.
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
- const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
: GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Random")}));
Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step)
: GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Range")}));
use_compression_(!compression_type.empty()),
options_(options) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::TextLine")}));
footer_bytes_(footer_bytes),
buffer_size_(buffer_size) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::FixedLengthRecord")}));
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::TFRecord")}));
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (count_ < 0) {
return std::unique_ptr<IteratorBase>(new ForeverIterator(
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
return Status::OK();
}
++i_;
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
*end_of_sequence = true;
input_impl_.reset();
bool first_call = false;
if (!input_impl_) {
first_call = true;
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (reader->Contains(full_name("uninitialized"))) {
input_impl_.reset();
} else {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
}
return Status::OK();
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Scan")}));
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
state_(params.dataset->initial_state_) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
private:
mutex mu_;
- const std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
bool first_call = false;
if (!input_impl_ && epoch_ == 0) {
first_call = true;
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
while (input_impl_ && num_elements_ < dataset()->buffer_size_) {
if (ctx->env()->NowMicros() >
epoch_++;
int64 n = slices_.back()->end;
slices_.emplace_back(new Slice{n, n});
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
if (!end_of_input_sequence) {
buffer_[slices_.back()->end % dataset()->buffer_size_] =
// Restore the input iterator if it wasn't already exhausted.
if (!reader->Contains(full_name("end_of_input_sequence"))) {
- input_impl_ = dataset()->input_->MakeIterator(prefix());
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
} else {
input_impl_.reset();
", ", seed2_, ")::ReshufflingDataset");
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
int64 iterator_seed;
int64 iterator_seed2;
", ", seed2_, ")::FixedSeedDataset");
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
{this, strings::StrCat(prefix, "::Shuffle")}, seed_, seed2_));
seed_, ", ", seed2_, ", ", count_, ")::Dataset");
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new ShuffleDatasetBase::Iterator(
{this, strings::StrCat(prefix, "::ShuffleAndRepeat")}, seed_,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
if (count_ < 0) {
return std::unique_ptr<IteratorBase>(
new EmptyIterator({this, strings::StrCat(prefix, "::EmptySkip")}));
- } else if (count_ == 0) {
- // Pass through.
- return input_->MakeIterator(prefix);
} else {
return std::unique_ptr<IteratorBase>(new FiniteIterator(
{this, strings::StrCat(prefix, "::FiniteSkip")}));
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
DatasetBase** output) override {
int64 window_size = 0;
int64 stride = 1;
- OP_REQUIRES_OK(ctx,
- ParseScalarArgument<int64>(ctx, "window_size", &window_size));
- OP_REQUIRES_OK(ctx,
- ParseScalarArgument<int64>(ctx, "stride", &stride));
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "stride", &stride));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
private:
class Dataset : public GraphDatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, int64 stride, const DatasetBase* input)
- : GraphDatasetBase(ctx), window_size_(window_size), stride_(stride), input_(input) {
+ Dataset(OpKernelContext* ctx, int64 window_size, int64 stride,
+ const DatasetBase* input)
+ : GraphDatasetBase(ctx),
+ window_size_(window_size),
+ stride_(stride),
+ input_(input) {
input_->Ref();
const auto& input_shapes = input_->output_shapes();
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
Iterator::Params{this, strings::StrCat(prefix, "::Slide")}));
}
string DebugString() override {
- return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_, ")::Dataset");
+ return strings::StrCat("SlideDatasetOp(", window_size_, ", ", stride_,
+ ")::Dataset");
}
protected:
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
{-1},
{sparse_tensor.dims() - 1}}) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::SparseTensorSlice")}));
output_types_(output_types),
output_shapes_(output_shapes) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Sql")}));
stats_aggregator_resource_->Unref();
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::SetStatsAggregator")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::LatencyStats")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::BytesProducedStats")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
~Dataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- if (count_ < 0) {
- // Pass through
- return input_->MakeIterator(prefix);
- } else if (count_ == 0) {
+ if (count_ == 0) {
return std::unique_ptr<IteratorBase>(
new EmptyIterator({this, strings::StrCat(prefix, "::EmptyTake")}));
} else {
class FiniteIterator : public DatasetIterator<Dataset> {
public:
explicit FiniteIterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- i_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {}
+ : DatasetIterator<Dataset>(params), i_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
*end_of_sequence = true;
return Status::OK();
}
- while (i_ < dataset()->count_) {
+ while (dataset()->count_ < 0 || i_ < dataset()->count_) {
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
if (!*end_of_sequence) {
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::FromTensor")}));
~PrependFromQueueAndPaddedBatchDataset() override { input_->Unref(); }
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new Iterator(
{this, strings::StrCat(prefix, "::PrependFromQueueAndPaddedBatch")}));
: public DatasetIterator<PrependFromQueueAndPaddedBatchDataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<PrependFromQueueAndPaddedBatchDataset>(params),
- queue_(new TensorQueue(/*input_impl*/
- params.dataset->input_->MakeIterator(
- params.prefix),
- params.dataset->dtypes_,
- params.dataset->shapes_)) {}
+ : DatasetIterator<PrependFromQueueAndPaddedBatchDataset>(params) {}
~Iterator() override { queue_->Unref(); }
+ Status Initialize(IteratorContext* ctx) override {
+ std::unique_ptr<IteratorBase> iterator;
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &iterator));
+ queue_ = new TensorQueue(std::move(iterator), dataset()->dtypes_,
+ dataset()->shapes_);
+ return Status::OK();
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
if (reader->Contains(iter->full_name("input_exhausted"))) {
input_impl_.reset();
} else {
- input_impl_ = iter->dataset_input()->MakeIterator(iter->prefix());
+ TF_RETURN_IF_ERROR(iter->dataset_input()->MakeIterator(
+ ctx, iter->prefix(), &input_impl_));
TF_RETURN_IF_ERROR(iter->RestoreParent(ctx, reader, input_impl_));
}
entries_.clear();
};
private:
- TensorQueue* const queue_;
+ TensorQueue* queue_;
};
private:
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::TensorSlice")}));
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Unbatch")}));
: DatasetIterator<Dataset>(params),
current_index_(0),
current_batch_size_(0),
- input_impl_(params.dataset->input_->MakeIterator(params.prefix)),
shapes_(params.dataset->output_shapes().size()) {}
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
output_types_(std::move(output_types)),
output_shapes_(std::move(output_shapes)) {}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Window")}));
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- auto iterator = dataset->MakeIterator("ToTFRecordOpIterator");
-
IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator),
+ done);
+
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
}
}
- std::unique_ptr<IteratorBase> MakeIterator(
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
new Iterator({this, strings::StrCat(prefix, "::Zip")}));
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {
- input_impls_.reserve(params.dataset->inputs_.size());
- size_t idx = 0;
- for (const auto& input : params.dataset->inputs_) {
- input_impls_.emplace_back(input->MakeIterator(
- strings::StrCat(params.prefix, "[", idx++, "]")));
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ input_impls_.resize(dataset()->inputs_.size());
+ for (size_t i = 0; i < input_impls_.size(); ++i) {
+ TF_RETURN_IF_ERROR(dataset()->inputs_[i]->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[", i, "]"), &input_impls_[i]));
}
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,