From a918f1d9af8a0a405d84dff5fae64e269770bf49 Mon Sep 17 00:00:00 2001
From: Lu Fang
Date: Fri, 4 Jan 2019 22:47:35 -0800
Subject: [PATCH] Adding a hook (wrapper) for non-std stream reader in PyTorchStreamReader (#15551)

Summary:
Implementing a custom std::istream is cumbersome, because it is tightly coupled to the underlying storage stream buffer. This PR introduces ReadAdapterInterface, and PyTorchStreamReader now reads through that interface instead of through std::istream directly. IStreamAdapter is provided as a wrapper around std::istream, so the existing user-facing interface stays unchanged. A usage sketch for the new interface follows the patch below.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15551

Reviewed By: zrphercule

Differential Revision: D13568907

Pulled By: houseroad

fbshipit-source-id: 93708cb801248a6c101f35cb14d1631029365c3c
---
 caffe2/serialize/CMakeLists.txt            |  5 ++-
 caffe2/serialize/file_adapter.cc           | 28 ++++++++++++
 caffe2/serialize/file_adapter.h            | 28 ++++++++++++
 caffe2/serialize/inline_container.cc       | 71 ++++++++++++++++++------------
 caffe2/serialize/inline_container.h        | 33 ++++++++------
 caffe2/serialize/inline_container_test.cc  | 10 +++--
 caffe2/serialize/istream_adapter.cc        | 39 ++++++++++++++++
 caffe2/serialize/istream_adapter.h         | 28 ++++++++++++
 caffe2/serialize/read_adapter_interface.cc |  9 ++++
 caffe2/serialize/read_adapter_interface.h  | 21 +++++++++
 torch/csrc/jit/export.cpp                  |  2 +-
 torch/csrc/jit/import.cpp                  |  2 +-
 torch/csrc/jit/init.cpp                    |  3 ++
 13 files changed, 229 insertions(+), 50 deletions(-)
 create mode 100644 caffe2/serialize/file_adapter.cc
 create mode 100644 caffe2/serialize/file_adapter.h
 create mode 100644 caffe2/serialize/istream_adapter.cc
 create mode 100644 caffe2/serialize/istream_adapter.h
 create mode 100644 caffe2/serialize/read_adapter_interface.cc
 create mode 100644 caffe2/serialize/read_adapter_interface.h

diff --git a/caffe2/serialize/CMakeLists.txt b/caffe2/serialize/CMakeLists.txt
index b0ed70d..bcda33c 100644
--- a/caffe2/serialize/CMakeLists.txt
+++ b/caffe2/serialize/CMakeLists.txt
@@ -3,7 +3,10 @@ file(GLOB tmp *_test.cc)
 set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
 
 list(APPEND Caffe2_CPU_SRCS
     ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8/miniz.c
-    ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc)
+    ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc)
 
 list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8)
 set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
diff --git a/caffe2/serialize/file_adapter.cc b/caffe2/serialize/file_adapter.cc
new file mode 100644
index 0000000..dd30ce3
--- /dev/null
+++ b/caffe2/serialize/file_adapter.cc
@@ -0,0 +1,28 @@
+#include "caffe2/serialize/file_adapter.h"
+#include
+#include "caffe2/core/common.h"
+
+namespace caffe2 {
+namespace serialize {
+
+FileAdapter::FileAdapter(const std::string& file_name) {
+  file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
+  if (!file_stream_) {
+    AT_ERROR("open file failed, file path: ", file_name);
+  }
+  istream_adapter_ = caffe2::make_unique<IStreamAdapter>(&file_stream_);
+}
+
+size_t FileAdapter::size() const {
+  return istream_adapter_->size();
+}
+
+size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
+    const {
+  return istream_adapter_->read(pos, buf, n, what);
+}
+
+FileAdapter::~FileAdapter() {}
+
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/file_adapter.h b/caffe2/serialize/file_adapter.h
new file mode 100644
index 0000000..cc05839
--- /dev/null
+++ b/caffe2/serialize/file_adapter.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include
+#include
+
+#include
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
+
+namespace caffe2 {
+namespace serialize {
+
+class FileAdapter final : public ReadAdapterInterface {
+ public:
+  C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
+  explicit FileAdapter(const std::string& file_name);
+  size_t size() const override;
+  size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+      const override;
+  ~FileAdapter();
+
+ private:
+  std::ifstream file_stream_;
+  std::unique_ptr<IStreamAdapter> istream_adapter_;
+};
+
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc
index 667359d..9d1c426 100644
--- a/caffe2/serialize/inline_container.cc
+++ b/caffe2/serialize/inline_container.cc
@@ -8,12 +8,17 @@
 #include
 #include
 
+#include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
+#include "caffe2/serialize/file_adapter.h"
 #include "caffe2/serialize/inline_container.h"
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
 
 #include "miniz.h"
 
-namespace torch { namespace jit {
+namespace caffe2 {
+namespace serialize {
 
 size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
   auto self = static_cast<PyTorchStreamReader*>(pOpaque);
@@ -42,27 +47,33 @@ static std::string basename(const std::string& name) {
 }
 
 size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
-  in_->seekg(pos);
-  if(!*in_)
-    return 0;
-  in_->read(static_cast<char*>(buf), n);
-  if(!*in_)
-    return 0;
-  return n;
+  return in_->read(pos, buf, n, "reading file");
 }
 
-PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in)
-: ar_(new mz_zip_archive), in_(in) {
-  memset(ar_.get(), 0, sizeof(mz_zip_archive));
+PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
+    : ar_(caffe2::make_unique<mz_zip_archive>()),
+      in_(caffe2::make_unique<FileAdapter>(file_name)) {
+  init();
+}
 
-  if (!in_) {
-    file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
-    in_ = &file_stream_;
-    valid("opening archive");
-  }
+PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
+    : ar_(caffe2::make_unique<mz_zip_archive>()),
+      in_(caffe2::make_unique<IStreamAdapter>(in)) {
+  init();
+}
+
+PyTorchStreamReader::PyTorchStreamReader(
+    std::unique_ptr<ReadAdapterInterface> in)
+    : ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
+  init();
+}
 
-  in_->seekg(0, in_->end);
-  size_t size = in_->tellg();
+void PyTorchStreamReader::init() {
+  AT_ASSERT(in_ != nullptr);
+  AT_ASSERT(ar_ != nullptr);
+  memset(ar_.get(), 0, sizeof(mz_zip_archive));
+
+  size_t size = in_->size();
 
   // check for the old magic number,
   constexpr size_t kMagicValueLength = 8;
@@ -81,7 +92,6 @@ PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in
   mz_zip_reader_init(ar_.get(), size, 0);
   valid("reading zip archive");
-
   // figure out the archive_name (i.e. the zip folder all the other files are in)
   // all lookups to getRecord will be prefixed by this folder
   int n = mz_zip_reader_get_num_files(ar_.get());
 
@@ -126,9 +136,6 @@ void PyTorchStreamReader::valid(const char* what) {
   if (err != MZ_ZIP_NO_ERROR) {
     CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
   }
-  if (!*in_) {
-    CAFFE_THROW("PytorchStreamReader failed ", what, ".");
-  }
 }
 
 constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
@@ -191,11 +198,12 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
   mz_zip_archive_file_stat stat;
   mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
   valid("retriving file meta-data");
-  in_->seekg(stat.m_local_header_ofs);
-  valid("seeking to file header");
   uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
-  in_->read(reinterpret_cast<char*>(local_header), MZ_ZIP_LOCAL_DIR_HEADER_SIZE);
-  valid("reading file header");
+  in_->read(
+      stat.m_local_header_ofs,
+      local_header,
+      MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
+      "reading file header");
   size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
   size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
   return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
@@ -226,8 +234,12 @@ size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, s
   return n;
 }
 
-PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name, std::ostream* out)
-: ar_(new mz_zip_archive), archive_name_(basename(file_name)), out_(out) {
+PyTorchStreamWriter::PyTorchStreamWriter(
+    std::string file_name,
+    std::ostream* out)
+    : ar_(caffe2::make_unique<mz_zip_archive>()),
+      archive_name_(basename(file_name)),
+      out_(out) {
   memset(ar_.get(), 0, sizeof(mz_zip_archive));
 
   if (archive_name_.size() == 0) {
@@ -302,4 +314,5 @@ PyTorchStreamWriter::~PyTorchStreamWriter() {
   }
 }
 
-}} // namespace torch::jit
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h
index 8df0c03..28f7149 100644
--- a/caffe2/serialize/inline_container.h
+++ b/caffe2/serialize/inline_container.h
@@ -11,6 +11,8 @@
 #include
 
 #include "caffe2/core/logging.h"
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
 
 extern "C" {
 typedef struct mz_zip_archive mz_zip_archive;
@@ -84,7 +86,8 @@ typedef struct mz_zip_archive mz_zip_archive;
 // model.json as the last file when writing after we have accumulated all
 // other information.
 
-namespace torch { namespace jit {
+namespace caffe2 {
+namespace serialize {
 
 constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
 constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
@@ -97,9 +100,9 @@ constexpr uint64_t kFieldAlignment = 64;
 
 class CAFFE2_API PyTorchStreamReader final {
  public:
-  PyTorchStreamReader(std::string archive_name, std::istream* in=nullptr);
-  PyTorchStreamReader(std::istream* in)
-    : PyTorchStreamReader("archive", in) {}
+  explicit PyTorchStreamReader(const std::string& file_name);
+  explicit PyTorchStreamReader(std::istream* in);
+  explicit PyTorchStreamReader(std::unique_ptr<ReadAdapterInterface> in);
 
   // return dataptr, size
   std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
@@ -109,15 +112,16 @@ class CAFFE2_API PyTorchStreamReader final {
   ~PyTorchStreamReader();
 
  private:
-  size_t read(uint64_t pos, char* buf, size_t n);
-  void valid(const char* what);
-  size_t getFileID(const std::string& name);
-
-  friend size_t istream_read_func(void *pOpaque, uint64_t file_ofs, void *pBuf, size_t n);
-  std::unique_ptr<mz_zip_archive> ar_;
-  std::string archive_name_;
-  std::istream* in_;
-  std::ifstream file_stream_;
+  void init();
+  size_t read(uint64_t pos, char* buf, size_t n);
+  void valid(const char* what);
+  size_t getFileID(const std::string& name);
+
+  friend size_t
+  istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
+  std::unique_ptr<mz_zip_archive> ar_;
+  std::string archive_name_;
+  std::unique_ptr<ReadAdapterInterface> in_;
 };
 
 class CAFFE2_API PyTorchStreamWriter final {
@@ -150,4 +154,5 @@ class CAFFE2_API PyTorchStreamWriter final {
   friend size_t ostream_write_func(void *pOpaque, uint64_t file_ofs, const void *pBuf, size_t n);
 };
 
-}} // namespace torch::jit
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc
index 341e6f9..6b4c969 100644
--- a/caffe2/serialize/inline_container_test.cc
+++ b/caffe2/serialize/inline_container_test.cc
@@ -6,7 +6,8 @@
 
 #include "caffe2/serialize/inline_container.h"
 
-namespace at {
+namespace caffe2 {
+namespace serialize {
 namespace {
 
 TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
@@ -14,7 +15,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   std::ostringstream oss;
 
   // write records through writers
-  torch::jit::PyTorchStreamWriter writer(&oss);
+  PyTorchStreamWriter writer(&oss);
 
   std::array data1;
   for (int i = 0; i < data1.size(); ++i) {
@@ -37,7 +38,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   std::istringstream iss(the_file);
 
   // read records through readers
-  torch::jit::PyTorchStreamReader reader(&iss);
+  PyTorchStreamReader reader(&iss);
   at::DataPtr data_ptr;
   int64_t size;
   std::tie(data_ptr, size) = reader.getRecord("key1");
@@ -58,4 +59,5 @@
 }
 
 } // namespace
-} // namespace at
+} // namespace serialize
+} // namespace caffe2
diff --git a/caffe2/serialize/istream_adapter.cc b/caffe2/serialize/istream_adapter.cc
new file mode 100644
index 0000000..8561253
--- /dev/null
+++ b/caffe2/serialize/istream_adapter.cc
@@ -0,0 +1,39 @@
+#include "caffe2/serialize/istream_adapter.h"
+#include
+
+namespace caffe2 {
+namespace serialize {
+
+IStreamAdapter::IStreamAdapter(std::istream* istream) : istream_(istream) {}
+
+size_t IStreamAdapter::size() const {
+  auto prev_pos = istream_->tellg();
+  validate("getting the current position");
+  istream_->seekg(0, istream_->end);
+  validate("seeking to end");
+  auto result = istream_->tellg();
+  validate("getting size");
+  istream_->seekg(prev_pos);
+  validate("seeking to the original position");
original position"); + return result; +} + +size_t IStreamAdapter::read(uint64_t pos, void* buf, size_t n, const char* what) + const { + istream_->seekg(pos); + validate(what); + istream_->read(static_cast(buf), n); + validate(what); + return n; +} + +void IStreamAdapter::validate(const char* what) const { + if (!*istream_) { + AT_ERROR("istream reader failed: ", what, "."); + } +} + +IStreamAdapter::~IStreamAdapter() {} + +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/istream_adapter.h b/caffe2/serialize/istream_adapter.h new file mode 100644 index 0000000..4d597e1 --- /dev/null +++ b/caffe2/serialize/istream_adapter.h @@ -0,0 +1,28 @@ +#pragma once + +#include + +#include + +#include "caffe2/serialize/read_adapter_interface.h" + +namespace caffe2 { +namespace serialize { + +// this is a reader implemented by std::istream +class IStreamAdapter final : public ReadAdapterInterface { + public: + C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter); + explicit IStreamAdapter(std::istream* istream); + size_t size() const override; + size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") + const override; + ~IStreamAdapter(); + + private: + std::istream* istream_; + void validate(const char* what) const; +}; + +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/read_adapter_interface.cc b/caffe2/serialize/read_adapter_interface.cc new file mode 100644 index 0000000..739b24f --- /dev/null +++ b/caffe2/serialize/read_adapter_interface.cc @@ -0,0 +1,9 @@ +#include "caffe2/serialize/read_adapter_interface.h" + +namespace caffe2 { +namespace serialize { + +ReadAdapterInterface::~ReadAdapterInterface() {} + +} // namespace serialize +} // namespace caffe2 diff --git a/caffe2/serialize/read_adapter_interface.h b/caffe2/serialize/read_adapter_interface.h new file mode 100644 index 0000000..e153f81 --- /dev/null +++ b/caffe2/serialize/read_adapter_interface.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include + +namespace caffe2 { +namespace serialize { + +// this is the interface for the (file/stream/memory) reader in +// PyTorchStreamReader. 
+// besides standard istream
+class ReadAdapterInterface {
+ public:
+  virtual size_t size() const = 0;
+  virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+      const = 0;
+  virtual ~ReadAdapterInterface();
+};
+
+} // namespace serialize
+} // namespace caffe2
diff --git a/torch/csrc/jit/export.cpp b/torch/csrc/jit/export.cpp
index d6cf849..44be035 100644
--- a/torch/csrc/jit/export.cpp
+++ b/torch/csrc/jit/export.cpp
@@ -497,7 +497,7 @@ class ScriptModuleSerializer final {
       torch::ParameterDef* param_def);
 
   std::ofstream ofs_;
-  PyTorchStreamWriter writer_;
+  caffe2::serialize::PyTorchStreamWriter writer_;
 
   // all tensors that will be stored
   std::vector tensor_table_;
diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp
index 56c5859..437236c 100644
--- a/torch/csrc/jit/import.cpp
+++ b/torch/csrc/jit/import.cpp
@@ -50,7 +50,7 @@ class ScriptModuleDeserializer final {
 
   void loadTensorTable(torch::ModelDef* model_def);
 
-  PyTorchStreamReader reader_;
+  caffe2::serialize::PyTorchStreamReader reader_;
   // this is a hack to make sure the script module created in C++ is the
   // same as created in Python
   ModuleLookup moduleLookup_;
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 38dbe69..f9b4e88 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -55,6 +55,9 @@
 namespace torch {
 namespace jit {
 
+using caffe2::serialize::PyTorchStreamReader;
+using caffe2::serialize::PyTorchStreamWriter;
+
 // TODO: make a fake future for python
 namespace detail {
 class Future {};
-- 
2.7.4
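
Usage sketch (not part of the patch). The example below illustrates the hook this patch adds: a caller can implement ReadAdapterInterface over any storage and hand it to the new PyTorchStreamReader constructor that takes a std::unique_ptr<ReadAdapterInterface>. It is a minimal sketch, assuming only what the patch defines; the MemoryReadAdapter class and the buffer/length names are hypothetical and do not exist in the tree.

#include <algorithm>
#include <cstring>

#include "caffe2/core/common.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/read_adapter_interface.h"

namespace caffe2 {
namespace serialize {

// Hypothetical adapter that serves reads out of a caller-owned memory buffer.
class MemoryReadAdapter final : public ReadAdapterInterface {
 public:
  MemoryReadAdapter(const void* data, size_t size) : data_(data), size_(size) {}

  size_t size() const override {
    return size_;
  }

  size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
      const override {
    // Out-of-range reads return 0 bytes instead of failing.
    if (pos >= size_) {
      return 0;
    }
    // Copy at most the bytes that remain after pos.
    size_t to_copy = std::min(n, size_ - static_cast<size_t>(pos));
    memcpy(buf, static_cast<const char*>(data_) + pos, to_copy);
    return to_copy;
  }

 private:
  const void* data_;
  size_t size_;
};

} // namespace serialize
} // namespace caffe2

// Usage (assuming `buffer` points to a serialized archive of `length` bytes):
//   caffe2::serialize::PyTorchStreamReader reader(
//       caffe2::make_unique<caffe2::serialize::MemoryReadAdapter>(buffer, length));
//   // records are then fetched as before, e.g. reader.getRecord(...)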