From b329e03684b2c965b96e6e0ae5e60e340b1593c4 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Tue, 15 Jan 2019 09:13:16 -0800 Subject: [PATCH] Add PyTorchPredictorContainer (#15899) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15899 Add PyTorchPredictorContainer to support multiple jit script modules Reviewed By: pritamdamania87 Differential Revision: D13596139 fbshipit-source-id: 3ce0bdf2f4dbba7aa1d20e824d03e5ac98f5d887 --- caffe2/serialize/file_adapter.h | 4 +-- caffe2/serialize/istream_adapter.h | 5 ++- caffe2/serialize/read_adapter_interface.h | 4 ++- torch/csrc/jit/import.cpp | 60 +++++++++++++++++++++---------- torch/csrc/jit/import.h | 21 +++++++++++ 5 files changed, 69 insertions(+), 25 deletions(-) diff --git a/caffe2/serialize/file_adapter.h b/caffe2/serialize/file_adapter.h index cc05839..416208e 100644 --- a/caffe2/serialize/file_adapter.h +++ b/caffe2/serialize/file_adapter.h @@ -3,14 +3,14 @@ #include #include -#include +#include "c10/macros/Macros.h" #include "caffe2/serialize/istream_adapter.h" #include "caffe2/serialize/read_adapter_interface.h" namespace caffe2 { namespace serialize { -class FileAdapter final : public ReadAdapterInterface { +class CAFFE2_API FileAdapter final : public ReadAdapterInterface { public: C10_DISABLE_COPY_AND_ASSIGN(FileAdapter); explicit FileAdapter(const std::string& file_name); diff --git a/caffe2/serialize/istream_adapter.h b/caffe2/serialize/istream_adapter.h index 4d597e1..b7a0444 100644 --- a/caffe2/serialize/istream_adapter.h +++ b/caffe2/serialize/istream_adapter.h @@ -2,15 +2,14 @@ #include -#include - +#include "c10/macros/Macros.h" #include "caffe2/serialize/read_adapter_interface.h" namespace caffe2 { namespace serialize { // this is a reader implemented by std::istream -class IStreamAdapter final : public ReadAdapterInterface { +class CAFFE2_API IStreamAdapter final : public ReadAdapterInterface { public: C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter); explicit IStreamAdapter(std::istream* istream); diff --git a/caffe2/serialize/read_adapter_interface.h b/caffe2/serialize/read_adapter_interface.h index e153f81..556c005 100644 --- a/caffe2/serialize/read_adapter_interface.h +++ b/caffe2/serialize/read_adapter_interface.h @@ -3,13 +3,15 @@ #include #include +#include "c10/macros/Macros.h" + namespace caffe2 { namespace serialize { // this is the interface for the (file/stream/memory) reader in // PyTorchStreamReader. with this interface, we can extend the support // besides standard istream -class ReadAdapterInterface { +class CAFFE2_API ReadAdapterInterface { public: virtual size_t size() const = 0; virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "") diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 437236c..5d129f2 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -8,10 +8,13 @@ #include #include -#include -#include -#include -#include +#include "caffe2/core/common.h" +#include "caffe2/core/types.h" +#include "caffe2/proto/caffe2_pb.h" +#include "caffe2/proto/torch_pb.h" +#include "caffe2/serialize/file_adapter.h" +#include "caffe2/serialize/inline_container.h" +#include "caffe2/serialize/istream_adapter.h" #include @@ -23,6 +26,10 @@ namespace torch { namespace jit { +using caffe2::serialize::ReadAdapterInterface; +using caffe2::serialize::IStreamAdapter; +using caffe2::serialize::FileAdapter; + namespace { // this is a deserializer class which loads script modules from pt files. the @@ -34,9 +41,8 @@ namespace { class ScriptModuleDeserializer final { public: ScriptModuleDeserializer(const std::string& filename); - ScriptModuleDeserializer(std::istream* is); - + explicit ScriptModuleDeserializer(std::unique_ptr rai); void deserialize( ModuleLookup module_lookup, c10::optional device); @@ -68,6 +74,9 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename) ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is) : reader_(is) {} +ScriptModuleDeserializer::ScriptModuleDeserializer(std::unique_ptr rai) + : reader_(std::move(rai)) {} + void ScriptModuleDeserializer::deserialize( ModuleLookup module_lookup, c10::optional device) { @@ -229,9 +238,34 @@ void import_ir_module( deserializer.deserialize(module_lookup, device); } +void import_ir_module( + ModuleLookup module_lookup, + std::unique_ptr rai, + c10::optional device) { + ScriptModuleDeserializer deserializer(std::move(rai)); + deserializer.deserialize(module_lookup, device); +} + std::shared_ptr load( std::istream& in, c10::optional device) { + std::unique_ptr rai = + caffe2::make_unique(&in); + auto module = load(std::move(rai), device); + return module; +} + +std::shared_ptr load( + const std::string& filename, + c10::optional device) { + std::unique_ptr rai = caffe2::make_unique(filename); + auto module = load(std::move(rai), device); + return module; +} + +std::shared_ptr load( + std::unique_ptr rai, + c10::optional device) { auto module = std::make_shared(); auto module_lookup = [&](const std::vector& qualified_name) { @@ -245,23 +279,11 @@ std::shared_ptr load( return curr; }; - ScriptModuleDeserializer deserializer(&in); + ScriptModuleDeserializer deserializer(std::move(rai)); deserializer.deserialize(module_lookup, device); return module; } -std::shared_ptr load( - const std::string& filename, - c10::optional device) { - std::ifstream in(filename, std::ios_base::binary); - - AT_CHECK(!in.fail(), "load: could not open file ", filename); - - auto module = load(in, device); - - return module; -} - } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/import.h b/torch/csrc/jit/import.h index 2252ba4..a765726 100644 --- a/torch/csrc/jit/import.h +++ b/torch/csrc/jit/import.h @@ -5,6 +5,12 @@ #include +namespace caffe2 { +namespace serialize { +class ReadAdapterInterface; +} // namespace serialize +} // namespace caffe2 + namespace torch { namespace jit { @@ -21,6 +27,11 @@ TORCH_API void import_ir_module( std::istream& in, c10::optional device = c10::nullopt); +TORCH_API void import_ir_module( + ModuleLookup module_lookup, + std::unique_ptr rai, + c10::optional device = c10::nullopt); + /// Loads a serialized `script::Module` from the given `istream`. /// /// The istream must contain a serialized `script::Module`, exported via @@ -38,5 +49,15 @@ TORCH_API std::shared_ptr load( const std::string& filename, c10::optional device = c10::nullopt); +/// Loads a serialized `script::Module` from the given `rai`. +/// +/// The reader adapter, which is for customized input stream, must contain a +/// serialized `script::Module`, exported either via `ScriptModule.save()` in +/// Python or `torch::jit::ExportModule` in C++. +TORCH_API std::shared_ptr load( + std::unique_ptr rai, + c10::optional device = c10::nullopt); + + } // namespace jit } // namespace torch -- 2.7.4