[Support] Add WriteThroughMemoryBuffer.
authorZachary Turner <zturner@google.com>
Thu, 8 Mar 2018 20:34:47 +0000 (20:34 +0000)
committerZachary Turner <zturner@google.com>
Thu, 8 Mar 2018 20:34:47 +0000 (20:34 +0000)
This is like MemoryBuffer (read-only) and WritableMemoryBuffer
(writable private), but where the underlying file can be modified
after writing.  This is useful when you want to open a file, make
some targeted edits, and then write it back out.

Differential Revision: https://reviews.llvm.org/D44230

llvm-svn: 327057

llvm/include/llvm/Support/FileSystem.h
llvm/include/llvm/Support/MemoryBuffer.h
llvm/lib/Support/MemoryBuffer.cpp
llvm/lib/Support/Unix/Path.inc
llvm/lib/Support/Windows/Path.inc
llvm/unittests/Support/MemoryBufferTest.cpp

index f84e281..499682c 100644 (file)
@@ -678,15 +678,20 @@ enum OpenFlags : unsigned {
   /// with F_Excl.
   F_Append = 2,
 
+  /// F_NoTrunc - When opening a file, if it already exists don't truncate
+  /// the file contents.  F_Append implies F_NoTrunc, but F_Append seeks to
+  /// the end of the file, which F_NoTrunc doesn't.
+  F_NoTrunc = 4,
+
   /// The file should be opened in text mode on platforms that make this
   /// distinction.
-  F_Text = 4,
+  F_Text = 8,
 
   /// Open the file for read and write.
-  F_RW = 8,
+  F_RW = 16,
 
   /// Delete the file on close. Only makes a difference on windows.
-  F_Delete = 16
+  F_Delete = 32
 };
 
 /// @brief Create a uniquely named file.
index 9e13715..2997ae4 100644 (file)
@@ -20,6 +20,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/CBindingWrapping.h"
 #include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/FileSystem.h"
 #include <cstddef>
 #include <cstdint>
 #include <memory>
@@ -49,7 +50,8 @@ protected:
   void init(const char *BufStart, const char *BufEnd,
             bool RequiresNullTerminator);
 
-  static constexpr bool Writable = false;
+  static constexpr sys::fs::mapped_file_region::mapmode Mapmode =
+      sys::fs::mapped_file_region::readonly;
 
 public:
   MemoryBuffer(const MemoryBuffer &) = delete;
@@ -148,15 +150,16 @@ public:
   MemoryBufferRef getMemBufferRef() const;
 };
 
-/// This class is an extension of MemoryBuffer, which allows writing to the
-/// underlying contents.  It only supports creation methods that are guaranteed
-/// to produce a writable buffer.  For example, mapping a file read-only is not
-/// supported.
+/// This class is an extension of MemoryBuffer, which allows copy-on-write
+/// access to the underlying contents.  It only supports creation methods that
+/// are guaranteed to produce a writable buffer.  For example, mapping a file
+/// read-only is not supported.
 class WritableMemoryBuffer : public MemoryBuffer {
 protected:
   WritableMemoryBuffer() = default;
 
-  static constexpr bool Writable = true;
+  static constexpr sys::fs::mapped_file_region::mapmode Mapmode =
+      sys::fs::mapped_file_region::priv;
 
 public:
   using MemoryBuffer::getBuffer;
@@ -209,6 +212,54 @@ private:
   using MemoryBuffer::getSTDIN;
 };
 
+/// This class is an extension of MemoryBuffer, which allows write access to
+/// the underlying contents and committing those changes to the original source.
+/// It only supports creation methods that are guaranteed to produce a writable
+/// buffer.  For example, mapping a file read-only is not supported.
+class WriteThroughMemoryBuffer : public MemoryBuffer {
+protected:
+  WriteThroughMemoryBuffer() = default;
+
+  static constexpr sys::fs::mapped_file_region::mapmode Mapmode =
+      sys::fs::mapped_file_region::readwrite;
+
+public:
+  using MemoryBuffer::getBuffer;
+  using MemoryBuffer::getBufferEnd;
+  using MemoryBuffer::getBufferStart;
+
+  // const_cast is well-defined here, because the underlying buffer is
+  // guaranteed to have been initialized with a mutable buffer.
+  char *getBufferStart() {
+    return const_cast<char *>(MemoryBuffer::getBufferStart());
+  }
+  char *getBufferEnd() {
+    return const_cast<char *>(MemoryBuffer::getBufferEnd());
+  }
+  MutableArrayRef<char> getBuffer() {
+    return {getBufferStart(), getBufferEnd()};
+  }
+
+  static ErrorOr<std::unique_ptr<WriteThroughMemoryBuffer>>
+  getFile(const Twine &Filename, int64_t FileSize = -1);
+
+  /// Map a subrange of the specified file as a ReadWriteMemoryBuffer.
+  static ErrorOr<std::unique_ptr<WriteThroughMemoryBuffer>>
+  getFileSlice(const Twine &Filename, uint64_t MapSize, uint64_t Offset);
+
+private:
+  // Hide these base class factory function so one can't write
+  //   WritableMemoryBuffer::getXXX()
+  // and be surprised that he got a read-only Buffer.
+  using MemoryBuffer::getFileAsStream;
+  using MemoryBuffer::getFileOrSTDIN;
+  using MemoryBuffer::getMemBuffer;
+  using MemoryBuffer::getMemBufferCopy;
+  using MemoryBuffer::getOpenFile;
+  using MemoryBuffer::getOpenFileSlice;
+  using MemoryBuffer::getSTDIN;
+};
+
 class MemoryBufferRef {
   StringRef Buffer;
   StringRef Identifier;
index 9cea9a2..9f9987b 100644 (file)
@@ -184,10 +184,8 @@ class MemoryBufferMMapFile : public MB {
 public:
   MemoryBufferMMapFile(bool RequiresNullTerminator, int FD, uint64_t Len,
                        uint64_t Offset, std::error_code &EC)
-      : MFR(FD,
-            MB::Writable ? sys::fs::mapped_file_region::priv
-                         : sys::fs::mapped_file_region::readonly,
-            getLegalMapSize(Len, Offset), getLegalMapOffset(Offset), EC) {
+      : MFR(FD, MB::Mapmode, getLegalMapSize(Len, Offset),
+            getLegalMapOffset(Offset), EC) {
     if (!EC) {
       const char *Start = getStart(Len, Offset);
       MemoryBuffer::init(Start, Start + Len, RequiresNullTerminator);
@@ -361,6 +359,59 @@ static bool shouldUseMmap(int FD,
   return true;
 }
 
+static ErrorOr<std::unique_ptr<WriteThroughMemoryBuffer>>
+getReadWriteFile(const Twine &Filename, int64_t FileSize, uint64_t MapSize,
+                 uint64_t Offset) {
+  int FD;
+  std::error_code EC = sys::fs::openFileForWrite(
+      Filename, FD, sys::fs::F_RW | sys::fs::F_NoTrunc);
+
+  if (EC)
+    return EC;
+
+  // Default is to map the full file.
+  if (MapSize == uint64_t(-1)) {
+    // If we don't know the file size, use fstat to find out.  fstat on an open
+    // file descriptor is cheaper than stat on a random path.
+    if (FileSize == uint64_t(-1)) {
+      sys::fs::file_status Status;
+      std::error_code EC = sys::fs::status(FD, Status);
+      if (EC)
+        return EC;
+
+      // If this not a file or a block device (e.g. it's a named pipe
+      // or character device), we can't mmap it, so error out.
+      sys::fs::file_type Type = Status.type();
+      if (Type != sys::fs::file_type::regular_file &&
+          Type != sys::fs::file_type::block_file)
+        return make_error_code(errc::invalid_argument);
+
+      FileSize = Status.getSize();
+    }
+    MapSize = FileSize;
+  }
+
+  std::unique_ptr<WriteThroughMemoryBuffer> Result(
+      new (NamedBufferAlloc(Filename))
+          MemoryBufferMMapFile<WriteThroughMemoryBuffer>(false, FD, MapSize,
+                                                         Offset, EC));
+  if (EC)
+    return EC;
+  return std::move(Result);
+}
+
+ErrorOr<std::unique_ptr<WriteThroughMemoryBuffer>>
+WriteThroughMemoryBuffer::getFile(const Twine &Filename, int64_t FileSize) {
+  return getReadWriteFile(Filename, FileSize, FileSize, 0);
+}
+
+/// Map a subrange of the specified file as a WritableMemoryBuffer.
+ErrorOr<std::unique_ptr<WriteThroughMemoryBuffer>>
+WriteThroughMemoryBuffer::getFileSlice(const Twine &Filename, uint64_t MapSize,
+                                       uint64_t Offset) {
+  return getReadWriteFile(Filename, -1, MapSize, Offset);
+}
+
 template <typename MB>
 static ErrorOr<std::unique_ptr<MB>>
 getOpenFileImpl(int FD, const Twine &Filename, uint64_t FileSize,
index 088774a..1d5e56d 100644 (file)
@@ -792,7 +792,7 @@ std::error_code openFileForWrite(const Twine &Name, int &ResultFD,
 
   if (Flags & F_Append)
     OpenFlags |= O_APPEND;
-  else
+  else if (!(Flags & F_NoTrunc))
     OpenFlags |= O_TRUNC;
 
   if (Flags & F_Excl)
index 58c555d..1fcb759 100644 (file)
@@ -1101,7 +1101,7 @@ std::error_code openFileForWrite(const Twine &Name, int &ResultFD,
   DWORD CreationDisposition;
   if (Flags & F_Excl)
     CreationDisposition = CREATE_NEW;
-  else if (Flags & F_Append)
+  else if ((Flags & F_Append) || (Flags & F_NoTrunc))
     CreationDisposition = OPEN_ALWAYS;
   else
     CreationDisposition = CREATE_ALWAYS;
index 64a7bb6..bafdffe 100644 (file)
@@ -260,4 +260,33 @@ TEST_F(MemoryBufferTest, writableSlice) {
   for (size_t i = 0; i < MB.getBufferSize(); i += 0x10)
     EXPECT_EQ("0123456789abcdef", MB.getBuffer().substr(i, 0x10)) << "i: " << i;
 }
+
+TEST_F(MemoryBufferTest, writeThroughFile) {
+  // Create a file initialized with some data
+  int FD;
+  SmallString<64> TestPath;
+  sys::fs::createTemporaryFile("MemoryBufferTest_WriteThrough", "temp", FD,
+                               TestPath);
+  FileRemover Cleanup(TestPath);
+  raw_fd_ostream OF(FD, true);
+  OF << "0123456789abcdef";
+  OF.close();
+  {
+    auto MBOrError = WriteThroughMemoryBuffer::getFile(TestPath);
+    ASSERT_FALSE(MBOrError.getError());
+    // Write some data.  It should be mapped readwrite, so that upon completion
+    // the original file contents are modified.
+    WriteThroughMemoryBuffer &MB = **MBOrError;
+    ASSERT_EQ(16, MB.getBufferSize());
+    char *Start = MB.getBufferStart();
+    ASSERT_EQ(MB.getBufferEnd(), MB.getBufferStart() + MB.getBufferSize());
+    ::memset(Start, 'x', MB.getBufferSize());
+  }
+
+  auto MBOrError = MemoryBuffer::getFile(TestPath);
+  ASSERT_FALSE(MBOrError.getError());
+  auto &MB = **MBOrError;
+  ASSERT_EQ(16, MB.getBufferSize());
+  EXPECT_EQ("xxxxxxxxxxxxxxxx", MB.getBuffer());
+}
 }