From 1c055f0679ea6cdae28b3c78c3bf98cb40f00e13 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 27 Mar 2018 03:23:58 -0700
Subject: [PATCH] Avoid reading the input file twice for
 InitializableLookupTable in combination with HashTable.

Before this CL, TextFileLineIterator::total_size() was called for
HashTable::DoPrepare, even though HashTable::DoPrepare ignores the size
parameter. To have a result ready for TextFileLineIterator::total_size(),
Init() called GetNumLinesInTextFile(), which read the whole file, only to
throw the result away.

This CL:
- adds DoLazyPrepare, which receives a functor that computes the size only
  if it is needed.
- adds HashTable::DoLazyPrepare, which does not call this functor.
- modifies TextFileLineIterator::Init() to no longer call
  GetNumLinesInTextFile() when vocab_size was given as -1.
- modifies TextFileLineIterator::total_size() to call GetNumLinesInTextFile()
  lazily on the first call if vocab_size_ was passed as -1.

PiperOrigin-RevId: 190593744
---
 .../core/kernels/initializable_lookup_table.cc |  2 +-
 .../core/kernels/initializable_lookup_table.h  | 12 +++++++++++
 tensorflow/core/kernels/lookup_table_op.h      |  5 +++++
 tensorflow/core/kernels/lookup_util.cc         | 24 +++++++++++++++-------
 4 files changed, 35 insertions(+), 8 deletions(-)

diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc
index 9c428cd..06d53eb 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.cc
+++ b/tensorflow/core/kernels/initializable_lookup_table.cc
@@ -44,7 +44,7 @@ Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
     return errors::FailedPrecondition("Table already initialized.");
   }
 
-  TF_RETURN_IF_ERROR(DoPrepare(iter.total_size()));
+  TF_RETURN_IF_ERROR(DoLazyPrepare([&iter]() { return iter.total_size(); }));
   while (iter.Valid()) {
     TF_RETURN_IF_ERROR(DoInsert(iter.keys(), iter.values()));
     iter.Next();
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index e9eae9f..b16c76d 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -114,6 +114,7 @@ class InitializableLookupTable : public LookupInterface {
     virtual Status status() const = 0;
 
     // Returns the total number of elements that the iterator will produce.
+    // It might return -1 in case of error.
     virtual int64 total_size() const = 0;
 
    private:
@@ -129,6 +130,17 @@ class InitializableLookupTable : public LookupInterface {
   // number of expected elements.
   virtual Status DoPrepare(size_t expected_num_elements) = 0;
 
+  // Same as DoPrepare() but derived implementations might choose to skip
+  // calling get_expected_num_elements if size is not needed for DoPrepare.
+  virtual Status DoLazyPrepare(
+      std::function<int64(void)> get_expected_num_elements) {
+    int64 expected_num_elements = get_expected_num_elements();
+    if (expected_num_elements < 0) {
+      return errors::FailedPrecondition("Got negative expected_num_elements.");
+    }
+    return DoPrepare(expected_num_elements);
+  }
+
   // Populates the table in batches given keys and values as tensors into the
   // underlying data structure.
   virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 5ba9b93..3657fd5 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -191,6 +191,11 @@ class HashTable : public InitializableLookupTable {
     return Status::OK();
   };
 
+  Status DoLazyPrepare(std::function<int64(void)> unused) override {
+    constexpr size_t kUnusedSize = 0;
+    return DoPrepare(kUnusedSize);
+  }
+
   Status DoInsert(const Tensor& keys, const Tensor& values) override {
     if (!table_) {
       return errors::FailedPrecondition("HashTable is not prepared.");
diff --git a/tensorflow/core/kernels/lookup_util.cc b/tensorflow/core/kernels/lookup_util.cc
index c7ce1c3..27031d9 100644
--- a/tensorflow/core/kernels/lookup_util.cc
+++ b/tensorflow/core/kernels/lookup_util.cc
@@ -75,9 +75,6 @@ class TextFileLineIterator
   Status Init(const string& filename, int64 vocab_size, char delimiter,
               DataType key_dtype, int64 key_index, DataType value_dtype,
               int64 value_index, Env* env) {
-    if (vocab_size == -1) {
-      TF_RETURN_IF_ERROR(GetNumLinesInTextFile(env, filename, &vocab_size));
-    }
     filename_ = filename;
     vocab_size_ = vocab_size;
     delimiter_ = delimiter;
@@ -85,6 +82,7 @@ class TextFileLineIterator
     value_ = Tensor(value_dtype, TensorShape({}));
     key_index_ = key_index;
     value_index_ = value_index;
+    env_ = env;
 
     status_ = env->NewRandomAccessFile(filename_, &file_);
     if (!status_.ok()) return status_;
@@ -103,15 +101,15 @@ class TextFileLineIterator
     string line;
     status_ = input_buffer_->ReadLine(&line);
     if (!status_.ok()) {
-      if (errors::IsOutOfRange(status_) && next_id_ != vocab_size_) {
+      if (errors::IsOutOfRange(status_) && next_id_ != total_size()) {
         status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
-                                          ": expected ", vocab_size_,
+                                          ": expected ", total_size(),
                                           " but got ", next_id_);
       }
       valid_ = false;
       return;
     }
-    if (next_id_ >= vocab_size_) {
+    if (vocab_size_ != -1 && next_id_ >= vocab_size_) {
       LOG(WARNING) << "Truncated " << filename_ << " before its end at "
                    << vocab_size_ << " records.";
       LOG(WARNING) << "next_id_ : " << next_id_;
@@ -162,7 +160,18 @@ class TextFileLineIterator
 
   Status status() const override { return status_; }
 
-  int64 total_size() const override { return vocab_size_; }
+  int64 total_size() const override {
+    if (vocab_size_ == -1) {
+      int64 new_size;
+      Status status = GetNumLinesInTextFile(env_, filename_, &new_size);
+      if (!status.ok()) {
+        LOG(WARNING) << "Unable to get line count: " << status;
+        new_size = -1;
+      }
+      *const_cast<int64*>(&vocab_size_) = new_size;
+    }
+    return vocab_size_;
+  }
 
  private:
   Tensor key_;
@@ -170,6 +179,7 @@ class TextFileLineIterator
   bool valid_;  // true if the iterator points to an existing range.
   int64 key_index_;
   int64 value_index_;
+  Env* env_;
   int64 next_id_;
   int64 vocab_size_;
   string filename_;
-- 
2.7.4
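
For readers who want the idea of the patch without the TensorFlow plumbing, below is a minimal, self-contained C++ sketch of the lazy-prepare pattern it introduces. The names TableBase, SimpleHashTable, LazyPrepare, and Prepare are hypothetical stand-ins, not TensorFlow APIs: the base class receives a functor for the expected element count and only evaluates it when the concrete table actually needs the number, so a hash-table-style implementation can skip the expensive count entirely.

// Sketch only: hypothetical names, no TensorFlow dependencies.
#include <cstdint>
#include <functional>
#include <iostream>

class TableBase {
 public:
  virtual ~TableBase() = default;

  // Default behavior: evaluate the functor eagerly, mirroring the old
  // DoPrepare(iter.total_size()) path.
  virtual bool LazyPrepare(std::function<int64_t()> expected_size) {
    int64_t n = expected_size();  // potentially expensive (e.g. counts lines)
    return n >= 0 && Prepare(static_cast<size_t>(n));
  }

  // Concrete tables pre-size themselves here.
  virtual bool Prepare(size_t expected_size) = 0;
};

class SimpleHashTable : public TableBase {
 public:
  // This table does not pre-size itself, so it never invokes the functor.
  bool LazyPrepare(std::function<int64_t()> /*unused*/) override {
    return Prepare(0);
  }
  bool Prepare(size_t /*expected_size*/) override { return true; }
};

int main() {
  SimpleHashTable table;
  // The expensive size computation is never run for SimpleHashTable.
  bool ok = table.LazyPrepare([]() -> int64_t {
    std::cout << "counting lines (expensive)...\n";
    return 1000;
  });
  std::cout << (ok ? "prepared without counting\n" : "failed\n");
  return 0;
}

Built with any C++11 compiler, the program prints "prepared without counting" and never the "counting lines" message, which is the same effect the patch achieves: HashTable::DoLazyPrepare ignores the size functor, so TextFileLineIterator no longer has to read the vocabulary file once just to produce a count that gets thrown away.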