From 079d63d59b75bdfd25f7371efda25ec5f6739b78 Mon Sep 17 00:00:00 2001 From: Sourabh Bajaj Date: Wed, 11 Apr 2018 15:20:11 -0700 Subject: [PATCH] GCS Filesystem should not cache checkpoint file as we need to read the updated checkpoints from the contents. PiperOrigin-RevId: 192517819 --- tensorflow/core/platform/cloud/gcs_file_system.cc | 8 ++++ .../core/platform/cloud/gcs_file_system_test.cc | 48 ++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 3c0dc13..6ed1d5d 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -301,6 +301,14 @@ class GcsRandomAccessFile : public RandomAccessFile { TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch, &bytes_transferred)); *result = StringPiece(scratch, bytes_transferred); + string checkpoint_ending = "/checkpoint"; + // Check if the file is the checkpoint file as we should not be caching + // that. As it's contents are updated and used for iterating checkpoints. + if (std::equal(checkpoint_ending.rbegin(), checkpoint_ending.rend(), + filename_.rbegin())) { + // Remove the checkpoint file from the cache + file_block_cache_->RemoveFile(filename_); + } if (bytes_transferred < n) { // This is not an error per se. The RandomAccessFile interface expects // that Read returns OutOfRange if fewer bytes were read than requested. diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 2fbde9b..e9eca04 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -198,6 +198,54 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) { EXPECT_EQ("0123", result); } +TEST(GcsFileSystemTest, NewRandomAccessFile_CheckpointFile_WithBlockCache) { + // Our underlying file in this test changes as new data comes in + std::vector requests( + {new FakeHttpRequest( + "Uri: https://storage.googleapis.com/bucket/checkpoint\n" + "Auth Token: fake_token\n" + "Range: 0-8\n" + "Timeouts: 5 1 20\n", + "012345678"), + new FakeHttpRequest( + "Uri: https://storage.googleapis.com/bucket/checkpoint\n" + "Auth Token: fake_token\n" + "Range: 0-8\n" + "Timeouts: 5 1 20\n", + "abcdefghi")}); + GcsFileSystem fs( + std::unique_ptr(new FakeAuthProvider), + std::unique_ptr( + new FakeHttpRequestFactory(&requests)), + 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + kTestTimeoutConfig, nullptr /* gcs additional header */); + + char scratch[100]; + StringPiece result; + { + // We are instantiating this in an enclosed scope to make sure after the + // unique ptr goes out of scope, we can still access result. + std::unique_ptr file; + TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/checkpoint", &file)); + + // Read the first chunk. The cache will be populated with the first block of + // 9 bytes. + scratch[5] = 'x'; + TF_EXPECT_OK(file->Read(0, 4, &result, scratch)); + EXPECT_EQ("0123", result); + EXPECT_EQ(scratch[5], 'x'); // Make sure we only copied 4 bytes. + + // The second chunk should not be in cache so we make a new request + // As the checkpoint file should not be cached + TF_EXPECT_OK(file->Read(0, 4, &result, scratch)); + EXPECT_EQ("abcd", result); + EXPECT_EQ(scratch[5], 'x'); // Make sure we only copied 4 bytes. + } +} + TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) { // Our underlying file in this test is a 15 byte file with contents // "0123456789abcde". -- 2.7.4