Adding a hook (wrapper) for non-std stream reader in PyTorchStreamReader (#15551)
authorLu Fang <lufang@fb.com>
Sat, 5 Jan 2019 06:47:35 +0000 (22:47 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 5 Jan 2019 06:50:07 +0000 (22:50 -0800)
Summary:
Implementing a custom std::istream is cumbersome, because a stream is tightly coupled to the underlying storage through its stream buffer (std::streambuf).

This PR introduces ReadAdapterInterface and makes PyTorchStreamReader read through it. IStreamAdapter is provided as a wrapper around std::istream (with FileAdapter for plain files), so the existing user-facing interface stays unchanged.
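
As a rough usage sketch (not part of this PR; "model.pt" and the local names are placeholders for illustration only), the reader can now be constructed in three ways, matching the constructors added in inline_container.h:

    // Minimal sketch, assuming the headers introduced in this diff.
    #include <fstream>
    #include <memory>
    #include <utility>

    #include "caffe2/serialize/inline_container.h"
    #include "caffe2/serialize/istream_adapter.h"
    #include "caffe2/serialize/read_adapter_interface.h"

    void load_example() {
      using namespace caffe2::serialize;

      // 1) From a file path: a FileAdapter is created internally.
      PyTorchStreamReader from_file("model.pt");

      // 2) From a std::istream: wrapped in an IStreamAdapter internally.
      std::ifstream ifs("model.pt", std::ifstream::in | std::ifstream::binary);
      PyTorchStreamReader from_stream(&ifs);

      // 3) From any ReadAdapterInterface implementation supplied by the caller.
      std::unique_ptr<ReadAdapterInterface> adapter(new IStreamAdapter(&ifs));
      PyTorchStreamReader from_adapter(std::move(adapter));
    }
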
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15551

Reviewed By: zrphercule

Differential Revision: D13568907

Pulled By: houseroad

fbshipit-source-id: 93708cb801248a6c101f35cb14d1631029365c3c

13 files changed:
caffe2/serialize/CMakeLists.txt
caffe2/serialize/file_adapter.cc [new file with mode: 0644]
caffe2/serialize/file_adapter.h [new file with mode: 0644]
caffe2/serialize/inline_container.cc
caffe2/serialize/inline_container.h
caffe2/serialize/inline_container_test.cc
caffe2/serialize/istream_adapter.cc [new file with mode: 0644]
caffe2/serialize/istream_adapter.h [new file with mode: 0644]
caffe2/serialize/read_adapter_interface.cc [new file with mode: 0644]
caffe2/serialize/read_adapter_interface.h [new file with mode: 0644]
torch/csrc/jit/export.cpp
torch/csrc/jit/import.cpp
torch/csrc/jit/init.cpp

index b0ed70d..bcda33c 100644 (file)
@@ -3,7 +3,10 @@ file(GLOB tmp *_test.cc)
 set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
 list(APPEND Caffe2_CPU_SRCS
   ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8/miniz.c
-  ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc)
+  ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc
+  ${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc
+  ${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc
+  ${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc)
 list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8)
 
 set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
diff --git a/caffe2/serialize/file_adapter.cc b/caffe2/serialize/file_adapter.cc
new file mode 100644 (file)
index 0000000..dd30ce3
--- /dev/null
@@ -0,0 +1,28 @@
+#include "caffe2/serialize/file_adapter.h"
+#include <c10/util/Exception.h>
+#include "caffe2/core/common.h"
+
+namespace caffe2 {
+namespace serialize {
+
+FileAdapter::FileAdapter(const std::string& file_name) {
+  file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
+  if (!file_stream_) {
+    AT_ERROR("open file failed, file path: ", file_name);
+  }
+  istream_adapter_ = caffe2::make_unique<IStreamAdapter>(&file_stream_);
+}
+
+size_t FileAdapter::size() const {
+  return istream_adapter_->size();
+}
+
+size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
+    const {
+  return istream_adapter_->read(pos, buf, n, what);
+}
+
+FileAdapter::~FileAdapter() {}
+
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/file_adapter.h b/caffe2/serialize/file_adapter.h
new file mode 100644 (file)
index 0000000..cc05839
--- /dev/null
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <fstream>
+#include <memory>
+
+#include <c10/macros/Macros.h>
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
+
+namespace caffe2 {
+namespace serialize {
+
+class FileAdapter final : public ReadAdapterInterface {
+ public:
+  C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
+  explicit FileAdapter(const std::string& file_name);
+  size_t size() const override;
+  size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+      const override;
+  ~FileAdapter();
+
+ private:
+  std::ifstream file_stream_;
+  std::unique_ptr<IStreamAdapter> istream_adapter_;
+};
+
+} // namespace serialize
+} // namespace caffe2
index 667359d..9d1c426 100644 (file)
@@ -8,12 +8,17 @@
 #include <c10/core/Allocator.h>
 #include <c10/core/Backend.h>
 
+#include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
+#include "caffe2/serialize/file_adapter.h"
 #include "caffe2/serialize/inline_container.h"
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
 
 #include "miniz.h"
 
-namespace torch { namespace jit {
+namespace caffe2 {
+namespace serialize {
 
 size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
   auto self = static_cast<PyTorchStreamReader*>(pOpaque);
@@ -42,27 +47,33 @@ static std::string basename(const std::string& name) {
 }
 
 size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
-  in_->seekg(pos);
-  if(!*in_)
-    return 0;
-  in_->read(static_cast<char*>(buf), n);
-  if(!*in_)
-    return 0;
-  return n;
+  return in_->read(pos, buf, n, "reading file");
 }
 
-PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in)
-: ar_(new mz_zip_archive), in_(in) {
-  memset(ar_.get(), 0, sizeof(mz_zip_archive));
+PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
+    : ar_(caffe2::make_unique<mz_zip_archive>()),
+      in_(caffe2::make_unique<FileAdapter>(file_name)) {
+  init();
+}
 
-  if (!in_) {
-    file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
-    in_ = &file_stream_;
-    valid("opening archive");
-  }
+PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
+    : ar_(caffe2::make_unique<mz_zip_archive>()),
+      in_(caffe2::make_unique<IStreamAdapter>(in)) {
+  init();
+}
+
+PyTorchStreamReader::PyTorchStreamReader(
+    std::unique_ptr<ReadAdapterInterface> in)
+    : ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
+  init();
+}
 
-  in_->seekg(0, in_->end);
-  size_t size = in_->tellg();
+void PyTorchStreamReader::init() {
+  AT_ASSERT(in_ != nullptr);
+  AT_ASSERT(ar_ != nullptr);
+  memset(ar_.get(), 0, sizeof(mz_zip_archive));
+
+  size_t size = in_->size();
 
   // check for the old magic number,
   constexpr size_t kMagicValueLength = 8;
@@ -81,7 +92,6 @@ PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in
   mz_zip_reader_init(ar_.get(), size, 0);
   valid("reading zip archive");
 
-
   // figure out the archive_name (i.e. the zip folder all the other files are in)
   // all lookups to getRecord will be prefixed by this folder
   int n = mz_zip_reader_get_num_files(ar_.get());
@@ -126,9 +136,6 @@ void PyTorchStreamReader::valid(const char* what) {
   if (err != MZ_ZIP_NO_ERROR) {
     CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
   }
-  if (!*in_) {
-    CAFFE_THROW("PytorchStreamReader failed ", what, ".");
-  }
 }
 
 constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
@@ -191,11 +198,12 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
   mz_zip_archive_file_stat stat;
   mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
   valid("retriving file meta-data");
-  in_->seekg(stat.m_local_header_ofs);
-  valid("seeking to file header");
   uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
-  in_->read(reinterpret_cast<char*>(local_header), MZ_ZIP_LOCAL_DIR_HEADER_SIZE);
-  valid("reading file header");
+  in_->read(
+      stat.m_local_header_ofs,
+      local_header,
+      MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
+      "reading file header");
   size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
   size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
   return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
@@ -226,8 +234,12 @@ size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, s
   return n;
 }
 
-PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name, std::ostream* out)
-: ar_(new mz_zip_archive), archive_name_(basename(file_name)), out_(out) {
+PyTorchStreamWriter::PyTorchStreamWriter(
+    std::string file_name,
+    std::ostream* out)
+    : ar_(caffe2::make_unique<mz_zip_archive>()),
+      archive_name_(basename(file_name)),
+      out_(out) {
   memset(ar_.get(), 0, sizeof(mz_zip_archive));
 
   if (archive_name_.size() == 0) {
@@ -302,4 +314,5 @@ PyTorchStreamWriter::~PyTorchStreamWriter() {
   }
 }
 
-}}  // namespace torch::jit
+} // namespace serialize
+} // namespace caffe2
index 8df0c03..28f7149 100644 (file)
@@ -11,6 +11,8 @@
 #include <c10/core/Backend.h>
 
 #include "caffe2/core/logging.h"
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
 
 extern "C" {
 typedef struct mz_zip_archive mz_zip_archive;
@@ -84,7 +86,8 @@ typedef struct mz_zip_archive mz_zip_archive;
 // model.json as the last file when writing after we have accumulated all
 // other information.
 
-namespace torch { namespace jit {
+namespace caffe2 {
+namespace serialize {
 
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
@@ -97,9 +100,9 @@ constexpr uint64_t kFieldAlignment = 64;
 
 class CAFFE2_API PyTorchStreamReader final {
  public:
-  PyTorchStreamReader(std::string archive_name, std::istream* in=nullptr);
-  PyTorchStreamReader(std::istream* in)
-  : PyTorchStreamReader("archive", in) {}
+  explicit PyTorchStreamReader(const std::string& file_name);
+  explicit PyTorchStreamReader(std::istream* in);
+  explicit PyTorchStreamReader(std::unique_ptr<ReadAdapterInterface> in);
 
   // return dataptr, size
   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
@@ -109,15 +112,16 @@ class CAFFE2_API PyTorchStreamReader final {
   ~PyTorchStreamReader();
 
  private:
-   size_t read(uint64_t pos, char* buf, size_t n);
-   void valid(const char* what);
-   size_t getFileID(const std::string& name);
-
-   friend size_t istream_read_func(void *pOpaque, uint64_t file_ofs, void *pBuf, size_t n);
-   std::unique_ptr<mz_zip_archive> ar_;
-   std::string archive_name_;
-   std::istream* in_;
-   std::ifstream file_stream_;
+  void init();
+  size_t read(uint64_t pos, char* buf, size_t n);
+  void valid(const char* what);
+  size_t getFileID(const std::string& name);
+
+  friend size_t
+  istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
+  std::unique_ptr<mz_zip_archive> ar_;
+  std::string archive_name_;
+  std::unique_ptr<ReadAdapterInterface> in_;
 };
 
 class CAFFE2_API PyTorchStreamWriter final {
@@ -150,4 +154,5 @@ class CAFFE2_API PyTorchStreamWriter final {
    friend size_t ostream_write_func(void *pOpaque, uint64_t file_ofs, const void *pBuf, size_t n);
 };
 
-}}  // namespace torch::jit
+} // namespace serialize
+} // namespace caffe2
index 341e6f9..6b4c969 100644 (file)
@@ -6,7 +6,8 @@
 
 #include "caffe2/serialize/inline_container.h"
 
-namespace at {
+namespace caffe2 {
+namespace serialize {
 namespace {
 
 TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
@@ -14,7 +15,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
 
   std::ostringstream oss;
   // write records through writers
-  torch::jit::PyTorchStreamWriter writer(&oss);
+  PyTorchStreamWriter writer(&oss);
   std::array<char, 127> data1;
 
   for (int i = 0; i < data1.size(); ++i) {
@@ -37,7 +38,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   std::istringstream iss(the_file);
 
   // read records through readers
-  torch::jit::PyTorchStreamReader reader(&iss);
+  PyTorchStreamReader reader(&iss);
   at::DataPtr data_ptr;
   int64_t size;
   std::tie(data_ptr, size) = reader.getRecord("key1");
@@ -58,4 +59,5 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
 }
 
 } // namespace
-} // namespace at
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/istream_adapter.cc b/caffe2/serialize/istream_adapter.cc
new file mode 100644 (file)
index 0000000..8561253
--- /dev/null
@@ -0,0 +1,39 @@
+#include "caffe2/serialize/istream_adapter.h"
+#include <c10/util/Exception.h>
+
+namespace caffe2 {
+namespace serialize {
+
+IStreamAdapter::IStreamAdapter(std::istream* istream) : istream_(istream) {}
+
+size_t IStreamAdapter::size() const {
+  auto prev_pos = istream_->tellg();
+  validate("getting the current position");
+  istream_->seekg(0, istream_->end);
+  validate("seeking to end");
+  auto result = istream_->tellg();
+  validate("getting size");
+  istream_->seekg(prev_pos);
+  validate("seeking to the original position");
+  return result;
+}
+
+size_t IStreamAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
+    const {
+  istream_->seekg(pos);
+  validate(what);
+  istream_->read(static_cast<char*>(buf), n);
+  validate(what);
+  return n;
+}
+
+void IStreamAdapter::validate(const char* what) const {
+  if (!*istream_) {
+    AT_ERROR("istream reader failed: ", what, ".");
+  }
+}
+
+IStreamAdapter::~IStreamAdapter() {}
+
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/istream_adapter.h b/caffe2/serialize/istream_adapter.h
new file mode 100644 (file)
index 0000000..4d597e1
--- /dev/null
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <istream>
+
+#include <c10/macros/Macros.h>
+
+#include "caffe2/serialize/read_adapter_interface.h"
+
+namespace caffe2 {
+namespace serialize {
+
+// this is a reader implemented by std::istream
+class IStreamAdapter final : public ReadAdapterInterface {
+ public:
+  C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
+  explicit IStreamAdapter(std::istream* istream);
+  size_t size() const override;
+  size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+      const override;
+  ~IStreamAdapter();
+
+ private:
+  std::istream* istream_;
+  void validate(const char* what) const;
+};
+
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/read_adapter_interface.cc b/caffe2/serialize/read_adapter_interface.cc
new file mode 100644 (file)
index 0000000..739b24f
--- /dev/null
@@ -0,0 +1,9 @@
+#include "caffe2/serialize/read_adapter_interface.h"
+
+namespace caffe2 {
+namespace serialize {
+
+ReadAdapterInterface::~ReadAdapterInterface() {}
+
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/read_adapter_interface.h b/caffe2/serialize/read_adapter_interface.h
new file mode 100644 (file)
index 0000000..e153f81
--- /dev/null
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+namespace caffe2 {
+namespace serialize {
+
+// this is the interface for the (file/stream/memory) reader in
+// PyTorchStreamReader. with this interface, we can extend the support
+// besides standard istream
+class ReadAdapterInterface {
+ public:
+  virtual size_t size() const = 0;
+  virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+      const = 0;
+  virtual ~ReadAdapterInterface();
+};
+
+} // namespace serialize
+} // namespace caffe2
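
For illustration, a custom reader only needs to implement size() and read() from the interface above; the sketch below is hypothetical (not part of this PR) and serves reads from an in-memory buffer. It could then be passed to PyTorchStreamReader through the new adapter constructor.

    // Hypothetical example: a ReadAdapterInterface backed by an in-memory buffer.
    #include <cstring>
    #include <utility>
    #include <vector>

    #include "caffe2/serialize/read_adapter_interface.h"

    class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
     public:
      explicit MemoryReadAdapter(std::vector<char> data) : data_(std::move(data)) {}

      size_t size() const override {
        return data_.size();
      }

      size_t read(uint64_t pos, void* buf, size_t n, const char* /*what*/ = "")
          const override {
        // Bounds checking is omitted in this sketch.
        std::memcpy(buf, data_.data() + pos, n);
        return n;
      }

     private:
      std::vector<char> data_;
    };
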
index d6cf849..44be035 100644 (file)
@@ -497,7 +497,7 @@ class ScriptModuleSerializer final {
       torch::ParameterDef* param_def);
 
   std::ofstream ofs_;
-  PyTorchStreamWriter writer_;
+  caffe2::serialize::PyTorchStreamWriter writer_;
 
   // all tensors that will be stored
   std::vector<at::Tensor> tensor_table_;
index 56c5859..437236c 100644 (file)
@@ -50,7 +50,7 @@ class ScriptModuleDeserializer final {
 
   void loadTensorTable(torch::ModelDef* model_def);
 
-  PyTorchStreamReader reader_;
+  caffe2::serialize::PyTorchStreamReader reader_;
   // this is a hack to make sure the script module created in C++ is the
   // same as created in Python
   ModuleLookup moduleLookup_;
index 38dbe69..f9b4e88 100644 (file)
@@ -55,6 +55,9 @@
 namespace torch {
 namespace jit {
 
+using caffe2::serialize::PyTorchStreamReader;
+using caffe2::serialize::PyTorchStreamWriter;
+
 // TODO: make a fake future for python
 namespace detail {
 class Future {};