virtual Status status() const = 0;
// Returns the total number of elements that the iterator will produce.
+ // Returns -1 if that number could not be determined because of an error.
virtual int64 total_size() const = 0;
private:
// number of expected elements.
// DoLazyPrepare()'s default implementation forwards a non-negative count here.
virtual Status DoPrepare(size_t expected_num_elements) = 0;
+ // Like DoPrepare(), but receives the element count as a callback so that an
+ // override which does not need the size can skip get_expected_num_elements.
+ virtual Status DoLazyPrepare(
+ std::function<int64(void)> get_expected_num_elements) {
+ int64 expected_num_elements = get_expected_num_elements();  // default always asks
+ if (expected_num_elements < 0) {  // reject before the value reaches a size_t
+ return errors::FailedPrecondition("Got negative expected_num_elements.");
+ }
+ return DoPrepare(expected_num_elements);  // result of DoPrepare is returned as-is
+ }
+
// Populates the table in batches: inserts the given keys and values tensors
// into the underlying data structure.
virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
Status Init(const string& filename, int64 vocab_size, char delimiter,
DataType key_dtype, int64 key_index, DataType value_dtype,
int64 value_index, Env* env) {
- if (vocab_size == -1) {
- TF_RETURN_IF_ERROR(GetNumLinesInTextFile(env, filename, &vocab_size));
- }
filename_ = filename;
vocab_size_ = vocab_size;
delimiter_ = delimiter;
value_ = Tensor(value_dtype, TensorShape({}));
key_index_ = key_index;
value_index_ = value_index;
+ env_ = env;
status_ = env->NewRandomAccessFile(filename_, &file_);
if (!status_.ok()) return status_;
string line;
status_ = input_buffer_->ReadLine(&line);
if (!status_.ok()) {
- if (errors::IsOutOfRange(status_) && next_id_ != vocab_size_) {
+ if (errors::IsOutOfRange(status_) && next_id_ != total_size()) {
status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
- ": expected ", vocab_size_,
+ ": expected ", total_size(),
" but got ", next_id_);
}
valid_ = false;
return;
}
- if (next_id_ >= vocab_size_) {
+ if (vocab_size_ != -1 && next_id_ >= vocab_size_) {
LOG(WARNING) << "Truncated " << filename_ << " before its end at "
<< vocab_size_ << " records.";
LOG(WARNING) << "next_id_ : " << next_id_;
// Returns the current value of status_ without modifying the iterator.
Status status() const override { return status_; }
- int64 total_size() const override { return vocab_size_; }
// Resolves the size lazily: when vocab_size_ is -1 it asks
// GetNumLinesInTextFile for the line count of filename_ and stores the result
// in vocab_size_; a failed count is logged and leaves -1 as the return value.
// NOTE(review): the const_cast store is undefined behavior if this object was
// itself declared const; a `mutable` vocab_size_ would avoid the cast.
+ int64 total_size() const override {
+ if (vocab_size_ == -1) {
+ int64 new_size;
+ Status status = GetNumLinesInTextFile(env_, filename_, &new_size);
+ if (!status.ok()) {
+ LOG(WARNING) << "Unable to get line count: " << status;
+ new_size = -1;  // vocab_size_ stays -1, so the next call counts again
+ }
+ *const_cast<int64*>(&vocab_size_) = new_size;  // caches the result
+ }
+ return vocab_size_;
+ }
private:
Tensor key_;
bool valid_; // true if the iterator points to an existing range.
int64 key_index_;    // copied from Init()'s key_index argument
int64 value_index_;  // copied from Init()'s value_index argument
+ Env* env_;  // stored by Init(); total_size() hands it to GetNumLinesInTextFile
int64 next_id_;  // compared with total_size() when reading hits end of file; presumably the count of lines read so far (confirm)
int64 vocab_size_;  // -1 means not yet known; total_size() then counts lines
string filename_;  // file opened in Init() and line-counted by total_size()