set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
list(APPEND Caffe2_CPU_SRCS
${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8/miniz.c
- ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc)
+ ${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc
+ ${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc)
list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
--- /dev/null
+#include "caffe2/serialize/file_adapter.h"
+#include <c10/util/Exception.h>
+#include "caffe2/core/common.h"
+
+namespace caffe2 {
+namespace serialize {
+
+FileAdapter::FileAdapter(const std::string& file_name) {
+ file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
+ if (!file_stream_) {
+ AT_ERROR("open file failed, file path: ", file_name);
+ }
+ istream_adapter_ = caffe2::make_unique<IStreamAdapter>(&file_stream_);
+}
+
+size_t FileAdapter::size() const {
+ return istream_adapter_->size();
+}
+
+size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
+ const {
+ return istream_adapter_->read(pos, buf, n, what);
+}
+
+FileAdapter::~FileAdapter() {}
+
+} // namespace serialize
+} // namespace caffe2
--- /dev/null
+#pragma once
+
+#include <fstream>
+#include <memory>
+
+#include <c10/macros/Macros.h>
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
+
+namespace caffe2 {
+namespace serialize {
+
+class FileAdapter final : public ReadAdapterInterface {
+ public:
+ C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
+ explicit FileAdapter(const std::string& file_name);
+ size_t size() const override;
+ size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+ const override;
+ ~FileAdapter();
+
+ private:
+ std::ifstream file_stream_;
+ std::unique_ptr<IStreamAdapter> istream_adapter_;
+};
+
+} // namespace serialize
+} // namespace caffe2
#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>
+#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
+#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
#include "miniz.h"
-namespace torch { namespace jit {
+namespace caffe2 {
+namespace serialize {
size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
}
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
- in_->seekg(pos);
- if(!*in_)
- return 0;
- in_->read(static_cast<char*>(buf), n);
- if(!*in_)
- return 0;
- return n;
+ return in_->read(pos, buf, n, "reading file");
}
-PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in)
-: ar_(new mz_zip_archive), in_(in) {
- memset(ar_.get(), 0, sizeof(mz_zip_archive));
+PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
+ : ar_(caffe2::make_unique<mz_zip_archive>()),
+ in_(caffe2::make_unique<FileAdapter>(file_name)) {
+ init();
+}
- if (!in_) {
- file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
- in_ = &file_stream_;
- valid("opening archive");
- }
+PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
+ : ar_(caffe2::make_unique<mz_zip_archive>()),
+ in_(caffe2::make_unique<IStreamAdapter>(in)) {
+ init();
+}
+
+PyTorchStreamReader::PyTorchStreamReader(
+ std::unique_ptr<ReadAdapterInterface> in)
+ : ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
+ init();
+}
- in_->seekg(0, in_->end);
- size_t size = in_->tellg();
+void PyTorchStreamReader::init() {
+ AT_ASSERT(in_ != nullptr);
+ AT_ASSERT(ar_ != nullptr);
+ memset(ar_.get(), 0, sizeof(mz_zip_archive));
+
+ size_t size = in_->size();
// check for the old magic number,
constexpr size_t kMagicValueLength = 8;
mz_zip_reader_init(ar_.get(), size, 0);
valid("reading zip archive");
-
// figure out the archive_name (i.e. the zip folder all the other files are in)
// all lookups to getRecord will be prefixed by this folder
int n = mz_zip_reader_get_num_files(ar_.get());
if (err != MZ_ZIP_NO_ERROR) {
CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
}
- if (!*in_) {
- CAFFE_THROW("PytorchStreamReader failed ", what, ".");
- }
}
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
valid("retriving file meta-data");
- in_->seekg(stat.m_local_header_ofs);
- valid("seeking to file header");
uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
- in_->read(reinterpret_cast<char*>(local_header), MZ_ZIP_LOCAL_DIR_HEADER_SIZE);
- valid("reading file header");
+ in_->read(
+ stat.m_local_header_ofs,
+ local_header,
+ MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
+ "reading file header");
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
return n;
}
-PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name, std::ostream* out)
-: ar_(new mz_zip_archive), archive_name_(basename(file_name)), out_(out) {
+PyTorchStreamWriter::PyTorchStreamWriter(
+ std::string file_name,
+ std::ostream* out)
+ : ar_(caffe2::make_unique<mz_zip_archive>()),
+ archive_name_(basename(file_name)),
+ out_(out) {
memset(ar_.get(), 0, sizeof(mz_zip_archive));
if (archive_name_.size() == 0) {
}
}
-}} // namespace torch::jit
+} // namespace serialize
+} // namespace caffe2
#include <c10/core/Backend.h>
#include "caffe2/core/logging.h"
+#include "caffe2/serialize/istream_adapter.h"
+#include "caffe2/serialize/read_adapter_interface.h"
extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
// model.json as the last file when writing after we have accumulated all
// other information.
-namespace torch { namespace jit {
+namespace caffe2 {
+namespace serialize {
constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
class CAFFE2_API PyTorchStreamReader final {
public:
- PyTorchStreamReader(std::string archive_name, std::istream* in=nullptr);
- PyTorchStreamReader(std::istream* in)
- : PyTorchStreamReader("archive", in) {}
+ explicit PyTorchStreamReader(const std::string& file_name);
+ explicit PyTorchStreamReader(std::istream* in);
+ explicit PyTorchStreamReader(std::unique_ptr<ReadAdapterInterface> in);
// return dataptr, size
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
~PyTorchStreamReader();
private:
- size_t read(uint64_t pos, char* buf, size_t n);
- void valid(const char* what);
- size_t getFileID(const std::string& name);
-
- friend size_t istream_read_func(void *pOpaque, uint64_t file_ofs, void *pBuf, size_t n);
- std::unique_ptr<mz_zip_archive> ar_;
- std::string archive_name_;
- std::istream* in_;
- std::ifstream file_stream_;
+ void init();
+ size_t read(uint64_t pos, char* buf, size_t n);
+ void valid(const char* what);
+ size_t getFileID(const std::string& name);
+
+ friend size_t
+ istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
+ std::unique_ptr<mz_zip_archive> ar_;
+ std::string archive_name_;
+ std::unique_ptr<ReadAdapterInterface> in_;
};
class CAFFE2_API PyTorchStreamWriter final {
friend size_t ostream_write_func(void *pOpaque, uint64_t file_ofs, const void *pBuf, size_t n);
};
-}} // namespace torch::jit
+} // namespace serialize
+} // namespace caffe2
#include "caffe2/serialize/inline_container.h"
-namespace at {
+namespace caffe2 {
+namespace serialize {
namespace {
TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
std::ostringstream oss;
// write records through writers
- torch::jit::PyTorchStreamWriter writer(&oss);
+ PyTorchStreamWriter writer(&oss);
std::array<char, 127> data1;
for (int i = 0; i < data1.size(); ++i) {
std::istringstream iss(the_file);
// read records through readers
- torch::jit::PyTorchStreamReader reader(&iss);
+ PyTorchStreamReader reader(&iss);
at::DataPtr data_ptr;
int64_t size;
std::tie(data_ptr, size) = reader.getRecord("key1");
}
} // namespace
-} // namespace at
+} // namespace serialize
+} // namespace caffe2
--- /dev/null
+#include "caffe2/serialize/istream_adapter.h"
+#include <c10/util/Exception.h>
+
+namespace caffe2 {
+namespace serialize {
+
+IStreamAdapter::IStreamAdapter(std::istream* istream) : istream_(istream) {}
+
+size_t IStreamAdapter::size() const {
+ auto prev_pos = istream_->tellg();
+ validate("getting the current position");
+ istream_->seekg(0, istream_->end);
+ validate("seeking to end");
+ auto result = istream_->tellg();
+ validate("getting size");
+ istream_->seekg(prev_pos);
+ validate("seeking to the original position");
+ return result;
+}
+
+size_t IStreamAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
+ const {
+ istream_->seekg(pos);
+ validate(what);
+ istream_->read(static_cast<char*>(buf), n);
+ validate(what);
+ return n;
+}
+
+void IStreamAdapter::validate(const char* what) const {
+ if (!*istream_) {
+ AT_ERROR("istream reader failed: ", what, ".");
+ }
+}
+
+IStreamAdapter::~IStreamAdapter() {}
+
+} // namespace serialize
+} // namespace caffe2
--- /dev/null
+#pragma once
+
+#include <istream>
+
+#include <c10/macros/Macros.h>
+
+#include "caffe2/serialize/read_adapter_interface.h"
+
+namespace caffe2 {
+namespace serialize {
+
+// this is a reader implemented by std::istream
+class IStreamAdapter final : public ReadAdapterInterface {
+ public:
+ C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
+ explicit IStreamAdapter(std::istream* istream);
+ size_t size() const override;
+ size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+ const override;
+ ~IStreamAdapter();
+
+ private:
+ std::istream* istream_;
+ void validate(const char* what) const;
+};
+
+} // namespace serialize
+} // namespace caffe2
--- /dev/null
+#include "caffe2/serialize/read_adapter_interface.h"
+
+namespace caffe2 {
+namespace serialize {
+
+ReadAdapterInterface::~ReadAdapterInterface() {}
+
+} // namespace serialize
+} // namespace caffe2
--- /dev/null
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+
+namespace caffe2 {
+namespace serialize {
+
+// this is the interface for the (file/stream/memory) reader in
+// PyTorchStreamReader. with this interface, we can extend the support
+// besides standard istream
+class ReadAdapterInterface {
+ public:
+ virtual size_t size() const = 0;
+ virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
+ const = 0;
+ virtual ~ReadAdapterInterface();
+};
+
+} // namespace serialize
+} // namespace caffe2
torch::ParameterDef* param_def);
std::ofstream ofs_;
- PyTorchStreamWriter writer_;
+ caffe2::serialize::PyTorchStreamWriter writer_;
// all tensors that will be stored
std::vector<at::Tensor> tensor_table_;
void loadTensorTable(torch::ModelDef* model_def);
- PyTorchStreamReader reader_;
+ caffe2::serialize::PyTorchStreamReader reader_;
// this is a hack to make sure the script module created in C++ is the
// same as created in Python
ModuleLookup moduleLookup_;
namespace torch {
namespace jit {
+using caffe2::serialize::PyTorchStreamReader;
+using caffe2::serialize::PyTorchStreamWriter;
+
// TODO: make a fake future for python
namespace detail {
class Future {};