Add PyTorchPredictorContainer (#15899)
authorLu Fang <lufang@fb.com>
Tue, 15 Jan 2019 17:13:16 +0000 (09:13 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 15 Jan 2019 17:18:18 +0000 (09:18 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15899

Add PyTorchPredictorContainer to support multiple jit script modules

Reviewed By: pritamdamania87

Differential Revision: D13596139

fbshipit-source-id: 3ce0bdf2f4dbba7aa1d20e824d03e5ac98f5d887

caffe2/serialize/file_adapter.h
caffe2/serialize/istream_adapter.h
caffe2/serialize/read_adapter_interface.h
torch/csrc/jit/import.cpp
torch/csrc/jit/import.h

index cc05839..416208e 100644 (file)
@@ -3,14 +3,14 @@
 #include <fstream>
 #include <memory>
 
-#include <c10/macros/Macros.h>
+#include "c10/macros/Macros.h"
 #include "caffe2/serialize/istream_adapter.h"
 #include "caffe2/serialize/read_adapter_interface.h"
 
 namespace caffe2 {
 namespace serialize {
 
-class FileAdapter final : public ReadAdapterInterface {
+class CAFFE2_API FileAdapter final : public ReadAdapterInterface {
  public:
   C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
   explicit FileAdapter(const std::string& file_name);
index 4d597e1..b7a0444 100644 (file)
@@ -2,15 +2,14 @@
 
 #include <istream>
 
-#include <c10/macros/Macros.h>
-
+#include "c10/macros/Macros.h"
 #include "caffe2/serialize/read_adapter_interface.h"
 
 namespace caffe2 {
 namespace serialize {
 
 // this is a reader implemented by std::istream
-class IStreamAdapter final : public ReadAdapterInterface {
+class CAFFE2_API IStreamAdapter final : public ReadAdapterInterface {
  public:
   C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
   explicit IStreamAdapter(std::istream* istream);
index e153f81..556c005 100644 (file)
@@ -3,13 +3,15 @@
 #include <cstddef>
 #include <cstdint>
 
+#include "c10/macros/Macros.h"
+
 namespace caffe2 {
 namespace serialize {
 
 // this is the interface for the (file/stream/memory) reader in
 // PyTorchStreamReader. with this interface, we can extend the support
 // besides standard istream
-class ReadAdapterInterface {
+class CAFFE2_API ReadAdapterInterface {
  public:
   virtual size_t size() const = 0;
   virtual size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
index 437236c..5d129f2 100644 (file)
@@ -8,10 +8,13 @@
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/utils/functional.h>
 
-#include <caffe2/core/types.h>
-#include <caffe2/proto/caffe2_pb.h>
-#include <caffe2/proto/torch_pb.h>
-#include <caffe2/serialize/inline_container.h>
+#include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
+#include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/torch_pb.h"
+#include "caffe2/serialize/file_adapter.h"
+#include "caffe2/serialize/inline_container.h"
+#include "caffe2/serialize/istream_adapter.h"
 
 #include <ATen/ATen.h>
 
 namespace torch {
 namespace jit {
 
+using caffe2::serialize::ReadAdapterInterface;
+using caffe2::serialize::IStreamAdapter;
+using caffe2::serialize::FileAdapter;
+
 namespace {
 
 // this is a deserializer class which loads script modules from pt files. the
@@ -34,9 +41,8 @@ namespace {
 class ScriptModuleDeserializer final {
  public:
   ScriptModuleDeserializer(const std::string& filename);
-
   ScriptModuleDeserializer(std::istream* is);
-
+  explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
   void deserialize(
       ModuleLookup module_lookup,
       c10::optional<at::Device> device);
@@ -68,6 +74,9 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
 ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
     : reader_(is) {}
 
+ScriptModuleDeserializer::ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai)
+    : reader_(std::move(rai)) {}
+
 void ScriptModuleDeserializer::deserialize(
     ModuleLookup module_lookup,
     c10::optional<at::Device> device) {
@@ -229,9 +238,34 @@ void import_ir_module(
   deserializer.deserialize(module_lookup, device);
 }
 
+void import_ir_module(
+    ModuleLookup module_lookup,
+    std::unique_ptr<ReadAdapterInterface> rai,
+    c10::optional<at::Device> device) {
+  ScriptModuleDeserializer deserializer(std::move(rai));
+  deserializer.deserialize(module_lookup, device);
+}
+
 std::shared_ptr<script::Module> load(
     std::istream& in,
     c10::optional<at::Device> device) {
+  std::unique_ptr<IStreamAdapter> rai =
+    caffe2::make_unique<IStreamAdapter>(&in);
+  auto module = load(std::move(rai), device);
+  return module;
+}
+
+std::shared_ptr<script::Module> load(
+    const std::string& filename,
+    c10::optional<at::Device> device) {
+  std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
+  auto module = load(std::move(rai), device);
+  return module;
+}
+
+std::shared_ptr<script::Module> load(
+    std::unique_ptr<ReadAdapterInterface> rai,
+    c10::optional<c10::Device> device) {
   auto module = std::make_shared<script::Module>();
 
   auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
@@ -245,23 +279,11 @@ std::shared_ptr<script::Module> load(
     return curr;
   };
 
-  ScriptModuleDeserializer deserializer(&in);
+  ScriptModuleDeserializer deserializer(std::move(rai));
   deserializer.deserialize(module_lookup, device);
 
   return module;
 }
 
-std::shared_ptr<script::Module> load(
-    const std::string& filename,
-    c10::optional<at::Device> device) {
-  std::ifstream in(filename, std::ios_base::binary);
-
-  AT_CHECK(!in.fail(), "load: could not open file ", filename);
-
-  auto module = load(in, device);
-
-  return module;
-}
-
 } // namespace jit
 } // namespace torch
index 2252ba4..a765726 100644 (file)
@@ -5,6 +5,12 @@
 
 #include <istream>
 
+namespace caffe2 {
+namespace serialize {
+class ReadAdapterInterface;
+} // namespace serialize
+} // namespace caffe2
+
 namespace torch {
 namespace jit {
 
@@ -21,6 +27,11 @@ TORCH_API void import_ir_module(
     std::istream& in,
     c10::optional<c10::Device> device = c10::nullopt);
 
+TORCH_API void import_ir_module(
+    ModuleLookup module_lookup,
+    std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
+    c10::optional<c10::Device> device = c10::nullopt);
+
 /// Loads a serialized `script::Module` from the given `istream`.
 ///
 /// The istream must contain a serialized `script::Module`, exported via
@@ -38,5 +49,15 @@ TORCH_API std::shared_ptr<script::Module> load(
     const std::string& filename,
     c10::optional<c10::Device> device = c10::nullopt);
 
+/// Loads a serialized `script::Module` from the given `rai`.
+///
+/// The reader adapter, which is for customized input stream, must contain a
+/// serialized `script::Module`, exported either via `ScriptModule.save()` in
+/// Python or `torch::jit::ExportModule` in C++.
+TORCH_API std::shared_ptr<script::Module> load(
+    std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
+    c10::optional<c10::Device> device = c10::nullopt);
+
+
 } // namespace jit
 } // namespace torch