RecordReader::RecordReader(RandomAccessFile* file,
const RecordReaderOptions& options)
- : options_(options),
- input_stream_(new RandomAccessInputStream(file)),
- last_read_failed_(false) {
+ : src_(file), options_(options) {
if (options.buffer_size > 0) {
- input_stream_.reset(new BufferedInputStream(input_stream_.release(),
- options.buffer_size, true));
+ input_stream_.reset(new BufferedInputStream(file, options.buffer_size));
+ } else {
+ input_stream_.reset(new RandomAccessInputStream(file));
}
if (options.compression_type == RecordReaderOptions::ZLIB_COMPRESSION) {
// We don't have zlib available on all embedded platforms, so fail.
#if defined(IS_SLIM_BUILD)
LOG(FATAL) << "Zlib compression is unsupported on mobile platforms.";
#else // IS_SLIM_BUILD
- input_stream_.reset(new ZlibInputStream(
- input_stream_.release(), options.zlib_options.input_buffer_size,
- options.zlib_options.output_buffer_size, options.zlib_options, true));
+ zlib_input_stream_.reset(new ZlibInputStream(
+ input_stream_.get(), options.zlib_options.input_buffer_size,
+ options.zlib_options.output_buffer_size, options.zlib_options));
#endif // IS_SLIM_BUILD
} else if (options.compression_type == RecordReaderOptions::NONE) {
// Nothing to do.
} else {
-    LOG(FATAL) << "Unrecognized compression type :" << options.compression_type;
+    LOG(FATAL) << "Unrecognized compression type: " << options.compression_type;
}
}
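// A minimal construction sketch, assuming `file` is an open RandomAccessFile
// that outlives the readers; the buffer size is illustrative. With default
// options, reads go straight to the RandomAccessFile and arbitrary offsets
// work; a nonzero buffer_size or ZLIB compression makes access sequential.
static void ExampleConstructReaders(RandomAccessFile* file) {
  RecordReaderOptions buffered_opts;
  buffered_opts.buffer_size = 1 << 20;  // selects the BufferedInputStream path
  RecordReader buffered(file, buffered_opts);

  // CreateRecordReaderOptions("ZLIB") requests zlib decompression, so this
  // reader layers a ZlibInputStream over its input stream.
  RecordReader compressed(
      file, RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
}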
// Read n+4 bytes from file, verify that checksum of first n bytes is
// stored in the last 4 bytes and store the first n bytes in *result.
-//
-// offset corresponds to the user-provided value to ReadRecord()
-// and is used only in error messages.
-Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
+// May use *storage as backing store.
+Status RecordReader::ReadChecksummed(uint64 offset, size_t n,
+ StringPiece* result, string* storage) {
if (n >= SIZE_MAX - sizeof(uint32)) {
return errors::DataLoss("record size too large");
}
const size_t expected = n + sizeof(uint32);
- TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, result));
+ storage->resize(expected);
+
+#if !defined(IS_SLIM_BUILD)
+ if (zlib_input_stream_) {
+ // If we have a zlib compressed buffer, we assume that the
+ // file is being read sequentially, and we use the underlying
+ // implementation to read the data.
+ //
+ // No checks are done to validate that the file is being read
+ // sequentially. At some point the zlib input buffer may support
+ // seeking, possibly inefficiently.
+ TF_RETURN_IF_ERROR(zlib_input_stream_->ReadNBytes(expected, storage));
+
+ if (storage->size() != expected) {
+ if (storage->empty()) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
+ }
- if (result->size() != expected) {
- if (result->empty()) {
- return errors::OutOfRange("eof");
+    const uint32 masked_crc = core::DecodeFixed32(storage->data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(storage->data(), n);
+ } else {
+#endif // IS_SLIM_BUILD
+ if (options_.buffer_size > 0) {
+ // If we have a buffer, we assume that the file is being read
+ // sequentially, and we use the underlying implementation to read the
+ // data.
+ //
+ // No checks are done to validate that the file is being read
+ // sequentially.
+ TF_RETURN_IF_ERROR(input_stream_->ReadNBytes(expected, storage));
+
+ if (storage->size() != expected) {
+ if (storage->empty()) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
+ }
+
+ const uint32 masked_crc = core::DecodeFixed32(storage->data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(storage->data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(storage->data(), n);
} else {
- return errors::DataLoss("truncated record at ", offset);
+ // This version supports reading from arbitrary offsets
+ // since we are accessing the random access file directly.
+ StringPiece data;
+ TF_RETURN_IF_ERROR(src_->Read(offset, expected, &data, &(*storage)[0]));
+ if (data.size() != expected) {
+ if (data.empty()) {
+ return errors::OutOfRange("eof");
+ } else {
+ return errors::DataLoss("truncated record at ", offset);
+ }
+ }
+ const uint32 masked_crc = core::DecodeFixed32(data.data() + n);
+ if (crc32c::Unmask(masked_crc) != crc32c::Value(data.data(), n)) {
+ return errors::DataLoss("corrupted record at ", offset);
+ }
+ *result = StringPiece(data.data(), n);
}
+#if !defined(IS_SLIM_BUILD)
}
+#endif // IS_SLIM_BUILD
- const uint32 masked_crc = core::DecodeFixed32(result->data() + n);
- if (crc32c::Unmask(masked_crc) != crc32c::Value(result->data(), n)) {
- return errors::DataLoss("corrupted record at ", offset);
- }
- result->resize(n);
return Status::OK();
}
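// A small sketch of the masked-crc convention checked above: crc32c::Mask()
// keeps a stored checksum from looking like the checksum of bytes that
// themselves embed checksums, and crc32c::Unmask() inverts it exactly.
static bool ExampleCrcRoundTrip(const char* data, size_t n) {
  const uint32 crc = crc32c::Value(data, n);
  return crc32c::Unmask(crc32c::Mask(crc)) == crc;  // always true
}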
static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
static const size_t kFooterSize = sizeof(uint32);
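// Together these sizes give the on-disk layout that ReadRecord below walks:
//
//   uint64  length
//   uint32  masked crc32c of length   (kHeaderSize = 12 bytes in total)
//   byte    data[length]
//   uint32  masked crc32c of data     (kFooterSize = 4 bytes)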
- // Position the input stream.
- int64 curr_pos = input_stream_->Tell();
- int64 desired_pos = static_cast<int64>(*offset);
- if (curr_pos > desired_pos || curr_pos < 0 /* EOF */ ||
- (curr_pos == desired_pos && last_read_failed_)) {
- last_read_failed_ = false;
- TF_RETURN_IF_ERROR(input_stream_->Reset());
- TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos));
- } else if (curr_pos < desired_pos) {
- TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(desired_pos - curr_pos));
- }
- DCHECK_EQ(desired_pos, input_stream_->Tell());
-
// Read header data.
- Status s = ReadChecksummed(*offset, sizeof(uint64), record);
+ StringPiece lbuf;
+ Status s = ReadChecksummed(*offset, sizeof(uint64), &lbuf, record);
if (!s.ok()) {
- last_read_failed_ = true;
return s;
}
- const uint64 length = core::DecodeFixed64(record->data());
+ const uint64 length = core::DecodeFixed64(lbuf.data());
// Read data
- s = ReadChecksummed(*offset + kHeaderSize, length, record);
+ StringPiece data;
+ s = ReadChecksummed(*offset + kHeaderSize, length, &data, record);
if (!s.ok()) {
- last_read_failed_ = true;
if (errors::IsOutOfRange(s)) {
s = errors::DataLoss("truncated record at ", *offset);
}
return s;
}
+ if (record->data() != data.data()) {
+ // RandomAccessFile placed the data in some other location.
+ memmove(&(*record)[0], data.data(), data.size());
+ }
+
+ record->resize(data.size());
+
*offset += kHeaderSize + length + kFooterSize;
- DCHECK_EQ(*offset, input_stream_->Tell());
return Status::OK();
}
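// A minimal read-loop sketch, assuming `file` holds records produced by
// RecordWriter: each successful ReadRecord advances `offset` past the
// footer, and OUT_OF_RANGE signals a clean end of file.
static Status ExampleReadAll(RandomAccessFile* file) {
  RecordReader reader(file);
  uint64 offset = 0;
  string record;
  Status s;
  while ((s = reader.ReadRecord(&offset, &record)).ok()) {
    // Process `record` here.
  }
  return errors::IsOutOfRange(s) ? Status::OK() : s;
}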
+Status RecordReader::SkipNBytes(uint64 offset) {
+#if !defined(IS_SLIM_BUILD)
+ if (zlib_input_stream_) {
+ TF_RETURN_IF_ERROR(zlib_input_stream_->SkipNBytes(offset));
+ } else {
+#endif
+ if (options_.buffer_size > 0) {
+ TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(offset));
+ }
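+    // With neither compression nor buffering there is no stream state to
+    // advance; reads go through RandomAccessFile::Read at explicit offsets.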
+#if !defined(IS_SLIM_BUILD)
+ }
+#endif
+ return Status::OK();
+}
+
SequentialRecordReader::SequentialRecordReader(
RandomAccessFile* file, const RecordReaderOptions& options)
: underlying_(file, options), offset_(0) {}
// Read the record at "*offset" into *record and update *offset to
// point to the offset of the next record. Returns OK on success,
// OUT_OF_RANGE for end of file, or something else for an error.
+ //
+ // Note: if buffering is used (with or without compression), access must be
+ // sequential.
Status ReadRecord(uint64* offset, string* record);
+  // Skip "offset" bytes from the current position. Returns OK on success,
+  // OUT_OF_RANGE for end of file, or something else for an error.
+ Status SkipNBytes(uint64 offset);
+
private:
- Status ReadChecksummed(uint64 offset, size_t n, string* result);
+ Status ReadChecksummed(uint64 offset, size_t n, StringPiece* result,
+ string* storage);
+  RandomAccessFile* src_;  // Not owned.
RecordReaderOptions options_;
std::unique_ptr<InputStreamInterface> input_stream_;
- bool last_read_failed_;
+#if !defined(IS_SLIM_BUILD)
+ std::unique_ptr<ZlibInputStream> zlib_input_stream_;
+#endif // IS_SLIM_BUILD
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
return errors::InvalidArgument(
"Trying to seek offset: ", offset,
" which is less than the current offset: ", offset_);
+ TF_RETURN_IF_ERROR(underlying_.SkipNBytes(offset - offset_));
offset_ = offset;
return Status::OK();
}
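// A minimal sequential-use sketch: SequentialRecordReader tracks the offset
// itself and, as SeekOffset above shows, can only move forward, since the
// underlying streams cannot rewind. Assumes the class's ReadRecord(string*)
// wrapper over RecordReader::ReadRecord.
static Status ExampleSequential(RandomAccessFile* file) {
  SequentialRecordReader reader(file, RecordReaderOptions());
  string record;
  TF_RETURN_IF_ERROR(reader.ReadRecord(&record));  // first record
  return reader.ReadRecord(&record);               // second record
}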
namespace tensorflow {
namespace io {
-namespace {
// Construct a string of the specified length made out of the supplied
// partial string.
-string BigString(const string& partial_string, size_t n) {
+static string BigString(const string& partial_string, size_t n) {
string result;
while (result.size() < n) {
result.append(partial_string);
}
// Construct a string from a number
-string NumberString(int n) {
+static string NumberString(int n) {
char buf[50];
snprintf(buf, sizeof(buf), "%d.", n);
return string(buf);
}
// Return a skewed potentially long string
-string RandomSkewedString(int i, random::SimplePhilox* rnd) {
+static string RandomSkewedString(int i, random::SimplePhilox* rnd) {
return BigString(NumberString(i), rnd->Skewed(17));
}
-class StringDest : public WritableFile {
- public:
- explicit StringDest(string* contents) : contents_(contents) {}
-
- Status Close() override { return Status::OK(); }
- Status Flush() override { return Status::OK(); }
- Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& slice) override {
- contents_->append(slice.data(), slice.size());
- return Status::OK();
- }
-
+class RecordioTest : public ::testing::Test {
private:
- string* contents_;
-};
-
-class StringSource : public RandomAccessFile {
- public:
- explicit StringSource(string* contents)
- : contents_(contents), force_error_(false) {}
-
- Status Read(uint64 offset, size_t n, StringPiece* result,
- char* scratch) const override {
- if (force_error_) {
- force_error_ = false;
- return errors::DataLoss("read error");
+ class StringDest : public WritableFile {
+ public:
+ string contents_;
+
+ Status Close() override { return Status::OK(); }
+ Status Flush() override { return Status::OK(); }
+ Status Sync() override { return Status::OK(); }
+ Status Append(const StringPiece& slice) override {
+ contents_.append(slice.data(), slice.size());
+ return Status::OK();
}
-
- if (offset >= contents_->size()) {
- return errors::OutOfRange("end of file");
- }
-
- if (contents_->size() < offset + n) {
- n = contents_->size() - offset;
+ };
+
+ class StringSource : public RandomAccessFile {
+ public:
+ StringPiece contents_;
+ mutable bool force_error_;
+ mutable bool returned_partial_;
+ StringSource() : force_error_(false), returned_partial_(false) {}
+
+ Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const override {
+ EXPECT_FALSE(returned_partial_) << "must not Read() after eof/error";
+
+ if (force_error_) {
+ force_error_ = false;
+ returned_partial_ = true;
+ return errors::DataLoss("read error");
+ }
+
+ if (offset >= contents_.size()) {
+ return errors::OutOfRange("end of file");
+ }
+
+ if (contents_.size() < offset + n) {
+ n = contents_.size() - offset;
+ returned_partial_ = true;
+ }
+ *result = StringPiece(contents_.data() + offset, n);
+ return Status::OK();
}
- *result = StringPiece(contents_->data() + offset, n);
- return Status::OK();
- }
-
- void force_error() { force_error_ = true; }
-
- private:
- string* contents_;
- mutable bool force_error_;
-};
+ };
-class RecordioTest : public ::testing::Test {
- private:
- string contents_;
StringDest dest_;
StringSource source_;
bool reading_;
public:
RecordioTest()
- : dest_(&contents_),
- source_(&contents_),
- reading_(false),
+ : reading_(false),
readpos_(0),
writer_(new RecordWriter(&dest_)),
reader_(new RecordReader(&source_)) {}
TF_ASSERT_OK(writer_->WriteRecord(StringPiece(msg)));
}
- size_t WrittenBytes() const { return contents_.size(); }
+ size_t WrittenBytes() const { return dest_.contents_.size(); }
string Read() {
if (!reading_) {
reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
}
string record;
Status s = reader_->ReadRecord(&readpos_, &record);
}
}
- void IncrementByte(int offset, int delta) { contents_[offset] += delta; }
+ void IncrementByte(int offset, int delta) {
+ dest_.contents_[offset] += delta;
+ }
- void SetByte(int offset, char new_byte) { contents_[offset] = new_byte; }
+ void SetByte(int offset, char new_byte) {
+ dest_.contents_[offset] = new_byte;
+ }
- void ShrinkSize(int bytes) { contents_.resize(contents_.size() - bytes); }
+ void ShrinkSize(int bytes) {
+ dest_.contents_.resize(dest_.contents_.size() - bytes);
+ }
void FixChecksum(int header_offset, int len) {
// Compute crc of type/len/data
- uint32_t crc = crc32c::Value(&contents_[header_offset + 6], 1 + len);
+ uint32_t crc = crc32c::Value(&dest_.contents_[header_offset + 6], 1 + len);
crc = crc32c::Mask(crc);
- core::EncodeFixed32(&contents_[header_offset], crc);
+ core::EncodeFixed32(&dest_.contents_[header_offset], crc);
}
- void ForceError() { source_.force_error(); }
+ void ForceError() { source_.force_error_ = true; }
void StartReadingAt(uint64_t initial_offset) { readpos_ = initial_offset; }
Write("bar");
Write(BigString("x", 10000));
reading_ = true;
+ source_.contents_ = StringPiece(dest_.contents_);
uint64 offset = WrittenBytes() + offset_past_end;
string record;
Status s = reader_->ReadRecord(&offset, &record);
ASSERT_EQ("EOF", Read());
}
-void TestNonSequentialReads(const RecordWriterOptions& writer_options,
- const RecordReaderOptions& reader_options) {
- string contents;
- StringDest dst(&contents);
- RecordWriter writer(&dst, writer_options);
- for (int i = 0; i < 10; ++i) {
- TF_ASSERT_OK(writer.WriteRecord(NumberString(i))) << i;
- }
- TF_ASSERT_OK(writer.Close());
-
- StringSource file(&contents);
- RecordReader reader(&file, reader_options);
-
- string record;
- // First read sequentially to fill in the offsets table.
- uint64 offsets[10] = {0};
- uint64 offset = 0;
- for (int i = 0; i < 10; ++i) {
- offsets[i] = offset;
- TF_ASSERT_OK(reader.ReadRecord(&offset, &record)) << i;
- }
-
- // Read randomly: First go back to record #3 then forward to #8.
- offset = offsets[3];
- TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
- EXPECT_EQ("3.", record);
- EXPECT_EQ(offsets[4], offset);
-
- offset = offsets[8];
- TF_ASSERT_OK(reader.ReadRecord(&offset, &record));
- EXPECT_EQ("8.", record);
- EXPECT_EQ(offsets[9], offset);
-}
-
-TEST_F(RecordioTest, NonSequentialReads) {
- TestNonSequentialReads(RecordWriterOptions(), RecordReaderOptions());
-}
-
-TEST_F(RecordioTest, NonSequentialReadsWithReadBuffer) {
- RecordReaderOptions options;
- options.buffer_size = 1 << 10;
- TestNonSequentialReads(RecordWriterOptions(), options);
-}
-
-TEST_F(RecordioTest, NonSequentialReadsWithCompression) {
- TestNonSequentialReads(
- RecordWriterOptions::CreateRecordWriterOptions("ZLIB"),
- RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
-}
-
// Tests of all the error paths in record_reader.cc follow:
-void AssertHasSubstr(StringPiece s, StringPiece expected) {
+static void AssertHasSubstr(StringPiece s, StringPiece expected) {
EXPECT_TRUE(str_util::StrContains(s, expected))
<< s << " does not contain " << expected;
}
-void TestReadError(const RecordWriterOptions& writer_options,
- const RecordReaderOptions& reader_options) {
- const string wrote = BigString("well hello there!", 100);
- string contents;
- StringDest dst(&contents);
- TF_ASSERT_OK(RecordWriter(&dst, writer_options).WriteRecord(wrote));
-
- StringSource file(&contents);
- RecordReader reader(&file, reader_options);
-
- uint64 offset = 0;
- string read;
- file.force_error();
- Status status = reader.ReadRecord(&offset, &read);
- ASSERT_TRUE(errors::IsDataLoss(status));
- ASSERT_EQ(0, offset);
-
- // A failed Read() shouldn't update the offset, and thus a retry shouldn't
- // lose the record.
- status = reader.ReadRecord(&offset, &read);
- ASSERT_TRUE(status.ok()) << status;
- EXPECT_GT(offset, 0);
- EXPECT_EQ(wrote, read);
-}
-
TEST_F(RecordioTest, ReadError) {
- TestReadError(RecordWriterOptions(), RecordReaderOptions());
-}
-
-TEST_F(RecordioTest, ReadErrorWithBuffering) {
- RecordReaderOptions options;
- options.buffer_size = 1 << 20;
- TestReadError(RecordWriterOptions(), options);
-}
-
-TEST_F(RecordioTest, ReadErrorWithCompression) {
- TestReadError(RecordWriterOptions::CreateRecordWriterOptions("ZLIB"),
- RecordReaderOptions::CreateRecordReaderOptions("ZLIB"));
+ Write("foo");
+ ForceError();
+ AssertHasSubstr(Read(), "Data loss");
}
TEST_F(RecordioTest, CorruptLength) {
TEST_F(RecordioTest, ReadPastEnd) { CheckOffsetPastEndReturnsNoRecords(5); }
-} // namespace
} // namespace io
} // namespace tensorflow
InputStreamInterface* input_stream,
size_t input_buffer_bytes, // size of z_stream.next_in buffer
size_t output_buffer_bytes, // size of z_stream.next_out buffer
- const ZlibCompressionOptions& zlib_options, bool owns_input_stream)
- : owns_input_stream_(owns_input_stream),
- input_stream_(input_stream),
+ const ZlibCompressionOptions& zlib_options)
+ : input_stream_(input_stream),
input_buffer_capacity_(input_buffer_bytes),
output_buffer_capacity_(output_buffer_bytes),
z_stream_input_(new Bytef[input_buffer_capacity_]),
if (z_stream_) {
inflateEnd(z_stream_.get());
}
- if (owns_input_stream_) {
- delete input_stream_;
- }
}
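// Since ownership was dropped above, the caller keeps the wrapped stream
// alive for the lifetime of the ZlibInputStream. A minimal sketch, with
// illustrative buffer sizes:
static Status ExampleWrapStream(RandomAccessFile* file, string* out) {
  RandomAccessInputStream base(file);
  ZlibInputStream zlib(&base, 64 << 10, 64 << 10,
                       ZlibCompressionOptions::DEFAULT());
  return zlib.ReadNBytes(16, out);  // decompresses the first 16 bytes
}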
Status ZlibInputStream::Reset() {
TF_RETURN_IF_ERROR(input_stream_->Reset());
- inflateEnd(z_stream_.get());
InitZlibBuffer();
bytes_read_ = 0;
return Status::OK();
// Create a ZlibInputStream for `input_stream` with a buffer of size
// `input_buffer_bytes` bytes for reading contents from `input_stream` and
// another buffer with size `output_buffer_bytes` for caching decompressed
- // contents.
- //
- // Takes ownership of `input_stream` iff `owns_input_stream` is true.
+ // contents. Does *not* take ownership of "input_stream".
ZlibInputStream(InputStreamInterface* input_stream, size_t input_buffer_bytes,
size_t output_buffer_bytes,
- const ZlibCompressionOptions& zlib_options,
- bool owns_input_stream = false);
+ const ZlibCompressionOptions& zlib_options);
~ZlibInputStream();
private:
void InitZlibBuffer();
- const bool owns_input_stream_;
- InputStreamInterface* input_stream_;
+ InputStreamInterface* input_stream_; // Not owned
size_t input_buffer_capacity_; // Size of z_stream_input_
size_t output_buffer_capacity_; // Size of z_stream_output_
char* next_unread_byte_; // Next unread byte in z_stream_output_