internal
authorLukasz Kaiser <lukaszkaiser@google.com>
Fri, 20 Apr 2018 00:39:09 +0000 (17:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 20 Apr 2018 00:41:31 +0000 (17:41 -0700)
END_PUBLIC

BEGIN_PUBLIC
Automated g4 rollback of changelist 193571934

PiperOrigin-RevId: 193602050

tensorflow/core/lib/io/record_reader.cc
tensorflow/core/lib/io/record_reader.h
tensorflow/core/lib/io/recordio_test.cc
tensorflow/core/lib/io/zlib_inputstream.cc
tensorflow/core/lib/io/zlib_inputstream.h

index c24628b..6de850b 100644 (file)
@@ -56,55 +56,110 @@ RecordReaderOptions RecordReaderOptions::CreateRecordReaderOptions(
 
 RecordReader::RecordReader(RandomAccessFile* file,
                            const RecordReaderOptions& options)
-    : options_(options),
-      input_stream_(new RandomAccessInputStream(file)),
-      last_read_failed_(false) {
+    : src_(file), options_(options) {
   if (options.buffer_size > 0) {
-    input_stream_.reset(new BufferedInputStream(input_stream_.release(),
-                                                options.buffer_size, true));
+    input_stream_.reset(new BufferedInputStream(file, options.buffer_size));
+  } else {
+    input_stream_.reset(new RandomAccessInputStream(file));
   }
   if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
 // We don't have zlib available on all embedded platforms, so fail.
 #if defined(IS_SLIM_BUILD)
     LOG(FATAL) << "Zlib compression is unsupported on mobile platforms.";
 #else   // IS_SLIM_BUILD
-    input_stream_.reset(new ZlibInputStream(
-        input_stream_.release(), options.zlib_options.input_buffer_size,
-        options.zlib_options.output_buffer_size, options.zlib_options, true));
+    zlib_input_stream_.reset(new ZlibInputStream(
+        input_stream_.get(), options.zlib_options.input_buffer_size,
+        options.zlib_options.output_buffer_size, options.zlib_options));
 #endif  // IS_SLIM_BUILD
   } else if (options.compression_type == RecordReaderOptions::NONE) {
     // Nothing to do.
   } else {
-    LOG(FATAL) << "Unrecognized compression type :" << options.compression_type;
+    LOG(FATAL) << "Unspecified compression type :" << options.compression_type;
   }
 }
 
 // Read n+4 bytes from file, verify that checksum of first n bytes is
 // stored in the last 4 bytes and store the first n bytes in *result.
-//
-// offset corresponds to the user-provided value to ReadRecord()
-// and is used only in error messages.
-Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
+// May use *storage as backing store.
+Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
+                                     StringPiece* result, string* storage) {
   if (n >= SIZE_MAX - sizeof(uint32)) {
     return errors::DataLoss("record size too large");
   }
 
   const size_t expected = n + sizeof(uint32);
-  TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result));
+  storage->resize(expected);
+
+#if !defined(IS_SLIM_BUILD)
+  if (zlib_input_stream_) {
+    // If we have a zlib compressed buffer, we assume that the
+    // file is being read sequentially, and we use the underlying
+    // implementation to read the data.
+    //
+    // No checks are done to validate that the file is being read
+    // sequentially.  At some point the zlib input buffer may support
+    // seeking, possibly inefficiently.
+    TF_RETURN_IF_ERROR(zlib_input_stream_->ReadNBytes(expected, storage));
+
+    if (storage->size() != expected) {
+      if (storage->empty()) {
+        return errors::OutOfRange("eof");
+      } else {
+        return errors::DataLoss("truncated record at ", offset);
+      }
+    }
 
-  if (result->size() != expected) {
-    if (result->empty()) {
-      return errors::OutOfRange("eof");
+    uint32 masked_crc = core::DecodeFixed32(storage->data() + n);
+    if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) {
+      return errors::DataLoss("corrupted record at ", offset);
+    }
+    *result = StringPiece(storage->data(), n);
+  } else {
+#endif  // IS_SLIM_BUILD
+    if (options_.buffer_size > 0) {
+      // If we have a buffer, we assume that the file is being read
+      // sequentially, and we use the underlying implementation to read the
+      // data.
+      //
+      // No checks are done to validate that the file is being read
+      // sequentially.
+      TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, storage));
+
+      if (storage->size() != expected) {
+        if (storage->empty()) {
+          return errors::OutOfRange("eof");
+        } else {
+          return errors::DataLoss("truncated record at ", offset);
+        }
+      }
+
+      const uint32 masked_crc = core::DecodeFixed32(storage->data() + n);
+      if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) {
+        return errors::DataLoss("corrupted record at ", offset);
+      }
+      *result = StringPiece(storage->data(), n);
     } else {
-      return errors::DataLoss("truncated record at ", offset);
+      // This version supports reading from arbitrary offsets
+      // since we are accessing the random access file directly.
+      StringPiece data;
+      TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0]));
+      if (data.size() != expected) {
+        if (data.empty()) {
+          return errors::OutOfRange("eof");
+        } else {
+          return errors::DataLoss("truncated record at ", offset);
+        }
+      }
+      const uint32 masked_crc = core::DecodeFixed32(data.data() + n);
+      if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
+        return errors::DataLoss("corrupted record at ", offset);
+      }
+      *result = StringPiece(data.data(), n);
     }
+#if !defined(IS_SLIM_BUILD)
   }
+#endif  // IS_SLIM_BUILD
 
-  const uint32 masked_crc = core::DecodeFixed32(result->data() + n);
-  if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) {
-    return errors::DataLoss("corrupted record at ", offset);
-  }
-  result->resize(n);
   return Status::OK();
 }
 
@@ -112,42 +167,50 @@ Status RecordReader::ReadRecord(uint64* offset, string* record) {
   static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
   static const size_t kFooterSize = sizeof(uint32);
 
-  // Position the input stream.
-  int64 curr_pos = input_stream_->Tell();
-  int64 desired_pos = static_cast<int64>(*offset);
-  if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ ||
-      (curr_pos == desired_pos && last_read_failed_)) {
-    last_read_failed_ = false;
-    TF_RETURN_IF_ERROR(input_stream_->Reset());
-    TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos));
-  } else if (curr_pos < desired_pos) {
-    TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos));
-  }
-  DCHECK_EQ(desired_pos, input_stream_->Tell());
-
   // Read header data.
-  Status s = ReadChecksummed(*offset, sizeof(uint64), record);
+  StringPiece lbuf;
+  Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record);
   if (!s.ok()) {
-    last_read_failed_ = true;
     return s;
   }
-  const uint64 length = core::DecodeFixed64(record->data());
+  const uint64 length = core::DecodeFixed64(lbuf.data());
 
   // Read data
-  s = ReadChecksummed(*offset + kHeaderSize, length, record);
+  StringPiece data;
+  s = ReadChecksummed(*offset + kHeaderSize, length, &data, record);
   if (!s.ok()) {
-    last_read_failed_ = true;
     if (errors::IsOutOfRange(s)) {
       s = errors::DataLoss("truncated record at ", *offset);
     }
     return s;
   }
 
+  if (record->data() != data.data()) {
+    // RandomAccessFile placed the data in some other location.
+    memmove(&(*record)[0], data.data(), data.size());
+  }
+
+  record->resize(data.size());
+
   *offset += kHeaderSize + length + kFooterSize;
-  DCHECK_EQ(*offset, input_stream_->Tell());
   return Status::OK();
 }
 
+Status RecordReader::SkipNBytes(uint64 offset) {
+#if !defined(IS_SLIM_BUILD)
+  if (zlib_input_stream_) {
+    TF_RETURN_IF_ERROR(zlib_input_stream_->SkipNBytes(offset));
+  } else {
+#endif
+    if (options_.buffer_size > 0) {
+      TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(offset));
+    }
+#if !defined(IS_SLIM_BUILD)
+  }
+#endif
+  return Status::OK();
+}  // namespace io
+
 SequentialRecordReader::SequentialRecordReader(
     RandomAccessFile* file, const RecordReaderOptions& options)
     : underlying_(file, options), offset_(0) {}
index f6d587d..26278e0 100644 (file)
@@ -69,14 +69,25 @@ class RecordReader {
   // Read the record at "*offset" into *record and update *offset to
   // point to the offset of the next record.  Returns OK on success,
   // OUT_OF_RANGE for end of file, or something else for an error.
+  //
+  // Note: if buffering is used (with or without compression), access must be
+  // sequential.
   Status ReadRecord(uint64* offset, string* record);
 
+  // Skip the records till "offset". Returns OK on success,
+  // OUT_OF_RANGE for end of file, or something else for an error.
+  Status SkipNBytes(uint64 offset);
+
  private:
-  Status ReadChecksummed(uint64 offset, size_t n, string* result);
+  Status ReadChecksummed(uint64 offset, size_t n, StringPiece* result,
+                         string* storage);
 
+  RandomAccessFile* src_;
   RecordReaderOptions options_;
   std::unique_ptr<InputStreamInterface> input_stream_;
-  bool last_read_failed_;
+#if !defined(IS_SLIM_BUILD)
+  std::unique_ptr<ZlibInputStream> zlib_input_stream_;
+#endif  // IS_SLIM_BUILD
 
   TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
 };
@@ -110,6 +121,7 @@ class SequentialRecordReader {
       return errors::InvalidArgument(
           "Trying to seek offset: ", offset,
           " which is less than the current offset: ", offset_);
+    TF_RETURN_IF_ERROR(underlying_.SkipNBytes(offset - offset_));
     offset_ = offset;
     return Status::OK();
   }
index da514bd..6323576 100644 (file)
@@ -26,11 +26,10 @@ limitations under the License.
 
 namespace tensorflow {
 namespace io {
-namespace {
 
 // Construct a string of the specified length made out of the supplied
 // partial string.
-string BigString(const string& partial_string, size_t n) {
+static string BigString(const string& partial_string, size_t n) {
   string result;
   while (result.size() < n) {
     result.append(partial_string);
@@ -40,66 +39,62 @@ string BigString(const string& partial_string, size_t n) {
 }
 
 // Construct a string from a number
-string NumberString(int n) {
+static string NumberString(int n) {
   char buf[50];
   snprintf(buf, sizeof(buf), "%d.", n);
   return string(buf);
 }
 
 // Return a skewed potentially long string
-string RandomSkewedString(int i, random::SimplePhilox* rnd) {
+static string RandomSkewedString(int i, random::SimplePhilox* rnd) {
   return BigString(NumberString(i), rnd->Skewed(17));
 }
 
-class StringDest : public WritableFile {
- public:
-  explicit StringDest(string* contents) : contents_(contents) {}
-
-  Status Close() override { return Status::OK(); }
-  Status Flush() override { return Status::OK(); }
-  Status Sync() override { return Status::OK(); }
-  Status Append(const StringPiece& slice) override {
-    contents_->append(slice.data(), slice.size());
-    return Status::OK();
-  }
-
+class RecordioTest : public ::testing::Test {
  private:
-  string* contents_;
-};
-
-class StringSource : public RandomAccessFile {
- public:
-  explicit StringSource(string* contents)
-      : contents_(contents), force_error_(false) {}
-
-  Status Read(uint64 offset, size_t n, StringPiece* result,
-              char* scratch) const override {
-    if (force_error_) {
-      force_error_ = false;
-      return errors::DataLoss("read error");
+  class StringDest : public WritableFile {
+   public:
+    string contents_;
+
+    Status Close() override { return Status::OK(); }
+    Status Flush() override { return Status::OK(); }
+    Status Sync() override { return Status::OK(); }
+    Status Append(const StringPiece& slice) override {
+      contents_.append(slice.data(), slice.size());
+      return Status::OK();
     }
-
-    if (offset >= contents_->size()) {
-      return errors::OutOfRange("end of file");
-    }
-
-    if (contents_->size() < offset + n) {
-      n = contents_->size() - offset;
+  };
+
+  class StringSource : public RandomAccessFile {
+   public:
+    StringPiece contents_;
+    mutable bool force_error_;
+    mutable bool returned_partial_;
+    StringSource() : force_error_(false), returned_partial_(false) {}
+
+    Status Read(uint64 offset, size_t n, StringPiece* result,
+                char* scratch) const override {
+      EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error";
+
+      if (force_error_) {
+        force_error_ = false;
+        returned_partial_ = true;
+        return errors::DataLoss("read error");
+      }
+
+      if (offset >= contents_.size()) {
+        return errors::OutOfRange("end of file");
+      }
+
+      if (contents_.size() < offset + n) {
+        n = contents_.size() - offset;
+        returned_partial_ = true;
+      }
+      *result = StringPiece(contents_.data() + offset, n);
+      return Status::OK();
     }
-    *result = StringPiece(contents_->data() + offset, n);
-    return Status::OK();
-  }
-
-  void force_error() { force_error_ = true; }
-
- private:
-  string* contents_;
-  mutable bool force_error_;
-};
+  };
 
-class RecordioTest : public ::testing::Test {
- private:
-  string contents_;
   StringDest dest_;
   StringSource source_;
   bool reading_;
@@ -109,9 +104,7 @@ class RecordioTest : public ::testing::Test {
 
  public:
   RecordioTest()
-      : dest_(&contents_),
-        source_(&contents_),
-        reading_(false),
+      : reading_(false),
         readpos_(0),
         writer_(new RecordWriter(&dest_)),
         reader_(new RecordReader(&source_)) {}
@@ -126,11 +119,12 @@ class RecordioTest : public ::testing::Test {
     TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
   }
 
-  size_t WrittenBytes() const { return contents_.size(); }
+  size_t WrittenBytes() const { return dest_.contents_.size(); }
 
   string Read() {
     if (!reading_) {
       reading_ = true;
+      source_.contents_ = StringPiece(dest_.contents_);
     }
     string record;
     Status s = reader_->ReadRecord(&readpos_, &record);
@@ -143,20 +137,26 @@ class RecordioTest : public ::testing::Test {
     }
   }
 
-  void IncrementByte(int offset, int delta) { contents_[offset] += delta; }
+  void IncrementByte(int offset, int delta) {
+    dest_.contents_[offset] += delta;
+  }
 
-  void SetByte(int offset, char new_byte) { contents_[offset] = new_byte; }
+  void SetByte(int offset, char new_byte) {
+    dest_.contents_[offset] = new_byte;
+  }
 
-  void ShrinkSize(int bytes) { contents_.resize(contents_.size() - bytes); }
+  void ShrinkSize(int bytes) {
+    dest_.contents_.resize(dest_.contents_.size() - bytes);
+  }
 
   void FixChecksum(int header_offset, int len) {
     // Compute crc of type/len/data
-    uint32_t crc = crc32c::Value(&contents_[header_offset + 6], 1 + len);
+    uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len);
     crc = crc32c::Mask(crc);
-    core::EncodeFixed32(&contents_[header_offset], crc);
+    core::EncodeFixed32(&dest_.contents_[header_offset], crc);
   }
 
-  void ForceError() { source_.force_error(); }
+  void ForceError() { source_.force_error_ = true; }
 
   void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
 
@@ -165,6 +165,7 @@ class RecordioTest : public ::testing::Test {
     Write("bar");
     Write(BigString("x", 10000));
     reading_ = true;
+    source_.contents_ = StringPiece(dest_.contents_);
     uint64 offset = WrittenBytes() + offset_past_end;
     string record;
     Status s = reader_->ReadRecord(&offset, &record);
@@ -216,100 +217,16 @@ TEST_F(RecordioTest, RandomRead) {
   ASSERT_EQ("EOF", Read());
 }
 
-void TestNonSequentialReads(const RecordWriterOptions& writer_options,
-                            const RecordReaderOptions& reader_options) {
-  string contents;
-  StringDest dst(&contents);
-  RecordWriter writer(&dst, writer_options);
-  for (int i = 0; i < 10; ++i) {
-    TF_ASSERT_OK(writer.WriteRecord(NumberString(i))) << i;
-  }
-  TF_ASSERT_OK(writer.Close());
-
-  StringSource file(&contents);
-  RecordReader reader(&file, reader_options);
-
-  string record;
-  // First read sequentially to fill in the offsets table.
-  uint64 offsets[10] = {0};
-  uint64 offset = 0;
-  for (int i = 0; i < 10; ++i) {
-    offsets[i] = offset;
-    TF_ASSERT_OK(reader.ReadRecord(&offset, &record)) << i;
-  }
-
-  // Read randomly: First go back to record #3 then forward to #8.
-  offset = offsets[3];
-  TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
-  EXPECT_EQ("3.", record);
-  EXPECT_EQ(offsets[4], offset);
-
-  offset = offsets[8];
-  TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
-  EXPECT_EQ("8.", record);
-  EXPECT_EQ(offsets[9], offset);
-}
-
-TEST_F(RecordioTest, NonSequentialReads) {
-  TestNonSequentialReads(RecordWriterOptions(), RecordReaderOptions());
-}
-
-TEST_F(RecordioTest, NonSequentialReadsWithReadBuffer) {
-  RecordReaderOptions options;
-  options.buffer_size = 1 << 10;
-  TestNonSequentialReads(RecordWriterOptions(), options);
-}
-
-TEST_F(RecordioTest, NonSequentialReadsWithCompression) {
-  TestNonSequentialReads(
-      RecordWriterOptions::CreateRecordWriterOptions("ZLIB"),
-      RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
-}
-
 // Tests of all the error paths in log_reader.cc follow:
-void AssertHasSubstr(StringPiece s, StringPiece expected) {
+static void AssertHasSubstr(StringPiece s, StringPiece expected) {
   EXPECT_TRUE(str_util::StrContains(s, expected))
       << s << " does not contain " << expected;
 }
 
-void TestReadError(const RecordWriterOptions& writer_options,
-                   const RecordReaderOptions& reader_options) {
-  const string wrote = BigString("well hello there!", 100);
-  string contents;
-  StringDest dst(&contents);
-  TF_ASSERT_OK(RecordWriter(&dst, writer_options).WriteRecord(wrote));
-
-  StringSource file(&contents);
-  RecordReader reader(&file, reader_options);
-
-  uint64 offset = 0;
-  string read;
-  file.force_error();
-  Status status = reader.ReadRecord(&offset, &read);
-  ASSERT_TRUE(errors::IsDataLoss(status));
-  ASSERT_EQ(0, offset);
-
-  // A failed Read() shouldn't update the offset, and thus a retry shouldn't
-  // lose the record.
-  status = reader.ReadRecord(&offset, &read);
-  ASSERT_TRUE(status.ok()) << status;
-  EXPECT_GT(offset, 0);
-  EXPECT_EQ(wrote, read);
-}
-
 TEST_F(RecordioTest, ReadError) {
-  TestReadError(RecordWriterOptions(), RecordReaderOptions());
-}
-
-TEST_F(RecordioTest, ReadErrorWithBuffering) {
-  RecordReaderOptions options;
-  options.buffer_size = 1 << 20;
-  TestReadError(RecordWriterOptions(), options);
-}
-
-TEST_F(RecordioTest, ReadErrorWithCompression) {
-  TestReadError(RecordWriterOptions::CreateRecordWriterOptions("ZLIB"),
-                RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
+  Write("foo");
+  ForceError();
+  AssertHasSubstr(Read(), "Data loss");
 }
 
 TEST_F(RecordioTest, CorruptLength) {
@@ -340,6 +257,5 @@ TEST_F(RecordioTest, ReadEnd) { CheckOffsetPastEndReturnsNoRecords(0); }
 
 TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); }
 
-}  // namespace
 }  // namespace io
 }  // namespace tensorflow
index bf8dcf0..984fbc2 100644 (file)
@@ -25,9 +25,8 @@ ZlibInputStream::ZlibInputStream(
     InputStreamInterface* input_stream,
     size_t input_buffer_bytes,   // size of z_stream.next_in buffer
     size_t output_buffer_bytes,  // size of z_stream.next_out buffer
-    const ZlibCompressionOptions& zlib_options, bool owns_input_stream)
-    : owns_input_stream_(owns_input_stream),
-      input_stream_(input_stream),
+    const ZlibCompressionOptions& zlib_options)
+    : input_stream_(input_stream),
       input_buffer_capacity_(input_buffer_bytes),
       output_buffer_capacity_(output_buffer_bytes),
       z_stream_input_(new Bytef[input_buffer_capacity_]),
@@ -42,14 +41,10 @@ ZlibInputStream::~ZlibInputStream() {
   if (z_stream_) {
     inflateEnd(z_stream_.get());
   }
-  if (owns_input_stream_) {
-    delete input_stream_;
-  }
 }
 
 Status ZlibInputStream::Reset() {
   TF_RETURN_IF_ERROR(input_stream_->Reset());
-  inflateEnd(z_stream_.get());
   InitZlibBuffer();
   bytes_read_ = 0;
   return Status::OK();
index 6099e24..9c7e144 100644 (file)
@@ -40,13 +40,10 @@ class ZlibInputStream : public InputStreamInterface {
   // Create a ZlibInputStream for `input_stream` with a buffer of size
   // `input_buffer_bytes` bytes for reading contents from `input_stream` and
   // another buffer with size `output_buffer_bytes` for caching decompressed
-  // contents.
-  //
-  // Takes ownership of `input_stream` iff `owns_input_stream` is true.
+  // contents. Does *not* take ownership of "input_stream".
   ZlibInputStream(InputStreamInterface* input_stream, size_t input_buffer_bytes,
                   size_t output_buffer_bytes,
-                  const ZlibCompressionOptions& zlib_options,
-                  bool owns_input_stream = false);
+                  const ZlibCompressionOptions& zlib_options);
 
   ~ZlibInputStream();
 
@@ -68,8 +65,7 @@ class ZlibInputStream : public InputStreamInterface {
  private:
   void InitZlibBuffer();
 
-  const bool owns_input_stream_;
-  InputStreamInterface* input_stream_;
+  InputStreamInterface* input_stream_;  // Not owned
   size_t input_buffer_capacity_;        // Size of z_stream_input_
   size_t output_buffer_capacity_;       // Size of z_stream_output_
   char* next_unread_byte_;              // Next unread byte in z_stream_output_