Avoid reading the input file twice for InitializableLookupTable in combination with...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Mar 2018 10:23:58 +0000 (03:23 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 10:26:34 +0000 (03:26 -0700)
Before this cl, TextFileLineIterator::total_size() was called for HashTable::DoPrepare, even though HashTable::DoPrepare ignores the size parameter.
In order to have a result ready for TextFileLineIterator::total_size(), Init() called GetNumLinesInTextFile(), which read the whole file. Just to throw away the result :-/

This cl:
- adds a DoLazyPrepare, that gets a functor to get the size, only if needed.
- add HashTable::DoLazyPrepare which does not call this functor.
- modify TextFileLineIterator::Init() to not call GetNumLinesInTextFile() anymore, when vocab_size was given as -1.
- modify TextFileLineIterator::total_size() to call GetNumLinesInTextFile() lazily on the first call, if vocab_size_ was passed as -1.

PiperOrigin-RevId: 190593744

tensorflow/core/kernels/initializable_lookup_table.cc
tensorflow/core/kernels/initializable_lookup_table.h
tensorflow/core/kernels/lookup_table_op.h
tensorflow/core/kernels/lookup_util.cc

index 9c428cd..06d53eb 100644 (file)
@@ -44,7 +44,7 @@ Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
     return errors::FailedPrecondition("Table already initialized.");
   }
 
-  TF_RETURN_IF_ERROR(DoPrepare(iter.total_size()));
+  TF_RETURN_IF_ERROR(DoLazyPrepare([&iter]() { return iter.total_size(); }));
   while (iter.Valid()) {
     TF_RETURN_IF_ERROR(DoInsert(iter.keys(), iter.values()));
     iter.Next();
index e9eae9f..b16c76d 100644 (file)
@@ -114,6 +114,7 @@ class InitializableLookupTable : public LookupInterface {
     virtual Status status() const = 0;
 
     // Returns the total number of elements that the iterator will produce.
+    // It might return -1 in case of error.
     virtual int64 total_size() const = 0;
 
    private:
@@ -129,6 +130,17 @@ class InitializableLookupTable : public LookupInterface {
   // number of expected elements.
   virtual Status DoPrepare(size_t expected_num_elements) = 0;
 
+  // Same as DoPrepare() but derived implementations might choose to skip
+  // calling get_expected_num_elements if size is not needed for DoPrepare.
+  virtual Status DoLazyPrepare(
+      std::function<int64(void)> get_expected_num_elements) {
+    int64 expected_num_elements = get_expected_num_elements();
+    if (expected_num_elements < 0) {
+      return errors::FailedPrecondition("Got negative expected_num_elements.");
+    }
+    return DoPrepare(expected_num_elements);
+  }
+
   // Populates the table in batches given keys and values as tensors into the
   // underlying data structure.
   virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
index 5ba9b93..3657fd5 100644 (file)
@@ -191,6 +191,11 @@ class HashTable : public InitializableLookupTable {
     return Status::OK();
   };
 
+  Status DoLazyPrepare(std::function<int64(void)> unused) override {
+    constexpr size_t kUnusedSize = 0;
+    return DoPrepare(kUnusedSize);
+  }
+
   Status DoInsert(const Tensor& keys, const Tensor& values) override {
     if (!table_) {
       return errors::FailedPrecondition("HashTable is not prepared.");
index c7ce1c3..27031d9 100644 (file)
@@ -75,9 +75,6 @@ class TextFileLineIterator
   Status Init(const string& filename, int64 vocab_size, char delimiter,
               DataType key_dtype, int64 key_index, DataType value_dtype,
               int64 value_index, Env* env) {
-    if (vocab_size == -1) {
-      TF_RETURN_IF_ERROR(GetNumLinesInTextFile(env, filename, &vocab_size));
-    }
     filename_ = filename;
     vocab_size_ = vocab_size;
     delimiter_ = delimiter;
@@ -85,6 +82,7 @@ class TextFileLineIterator
     value_ = Tensor(value_dtype, TensorShape({}));
     key_index_ = key_index;
     value_index_ = value_index;
+    env_ = env;
 
     status_ = env->NewRandomAccessFile(filename_, &file_);
     if (!status_.ok()) return status_;
@@ -103,15 +101,15 @@ class TextFileLineIterator
     string line;
     status_ = input_buffer_->ReadLine(&line);
     if (!status_.ok()) {
-      if (errors::IsOutOfRange(status_) && next_id_ != vocab_size_) {
+      if (errors::IsOutOfRange(status_) && next_id_ != total_size()) {
         status_ = errors::InvalidArgument("Invalid vocab_size in ", filename_,
-                                          ": expected ", vocab_size_,
+                                          ": expected ", total_size(),
                                           " but got ", next_id_);
       }
       valid_ = false;
       return;
     }
-    if (next_id_ >= vocab_size_) {
+    if (vocab_size_ != -1 && next_id_ >= vocab_size_) {
       LOG(WARNING) << "Truncated " << filename_ << " before its end at "
                    << vocab_size_ << " records.";
       LOG(WARNING) << "next_id_  : " << next_id_;
@@ -162,7 +160,18 @@ class TextFileLineIterator
 
   Status status() const override { return status_; }
 
-  int64 total_size() const override { return vocab_size_; }
+  int64 total_size() const override {
+    if (vocab_size_ == -1) {
+      int64 new_size;
+      Status status = GetNumLinesInTextFile(env_, filename_, &new_size);
+      if (!status.ok()) {
+        LOG(WARNING) << "Unable to get line count: " << status;
+        new_size = -1;
+      }
+      *const_cast<int64*>(&vocab_size_) = new_size;
+    }
+    return vocab_size_;
+  }
 
  private:
   Tensor key_;
@@ -170,6 +179,7 @@ class TextFileLineIterator
   bool valid_;  // true if the iterator points to an existing range.
   int64 key_index_;
   int64 value_index_;
+  Env* env_;
   int64 next_id_;
   int64 vocab_size_;
   string filename_;