#include <torch/csrc/jit/operator.h>
#include <torch/csrc/utils/functional.h>
-#include <caffe2/core/types.h>
-#include <caffe2/proto/caffe2_pb.h>
-#include <caffe2/proto/torch_pb.h>
-#include <caffe2/serialize/inline_container.h>
+#include "caffe2/core/common.h"
+#include "caffe2/core/types.h"
+#include "caffe2/proto/caffe2_pb.h"
+#include "caffe2/proto/torch_pb.h"
+#include "caffe2/serialize/file_adapter.h"
+#include "caffe2/serialize/inline_container.h"
+#include "caffe2/serialize/istream_adapter.h"
#include <ATen/ATen.h>
namespace torch {
namespace jit {
+using caffe2::serialize::ReadAdapterInterface;
+using caffe2::serialize::IStreamAdapter;
+using caffe2::serialize::FileAdapter;
+
namespace {
// this is a deserializer class which loads script modules from pt files. the
class ScriptModuleDeserializer final {
public:
ScriptModuleDeserializer(const std::string& filename);
-
ScriptModuleDeserializer(std::istream* is);
-
+ explicit ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai);
void deserialize(
ModuleLookup module_lookup,
c10::optional<at::Device> device);
ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
: reader_(is) {}
+ScriptModuleDeserializer::ScriptModuleDeserializer(std::unique_ptr<ReadAdapterInterface> rai)
+ : reader_(std::move(rai)) {}
+
void ScriptModuleDeserializer::deserialize(
ModuleLookup module_lookup,
c10::optional<at::Device> device) {
deserializer.deserialize(module_lookup, device);
}
+void import_ir_module(
+ ModuleLookup module_lookup,
+ std::unique_ptr<ReadAdapterInterface> rai,
+ c10::optional<at::Device> device) {
+ ScriptModuleDeserializer deserializer(std::move(rai));
+ deserializer.deserialize(module_lookup, device);
+}
+
std::shared_ptr<script::Module> load(
std::istream& in,
c10::optional<at::Device> device) {
+ std::unique_ptr<IStreamAdapter> rai =
+ caffe2::make_unique<IStreamAdapter>(&in);
+ auto module = load(std::move(rai), device);
+ return module;
+}
+
+std::shared_ptr<script::Module> load(
+ const std::string& filename,
+ c10::optional<at::Device> device) {
+ std::unique_ptr<FileAdapter> rai = caffe2::make_unique<FileAdapter>(filename);
+ auto module = load(std::move(rai), device);
+ return module;
+}
+
+std::shared_ptr<script::Module> load(
+ std::unique_ptr<ReadAdapterInterface> rai,
+ c10::optional<c10::Device> device) {
auto module = std::make_shared<script::Module>();
auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
return curr;
};
- ScriptModuleDeserializer deserializer(&in);
+ ScriptModuleDeserializer deserializer(std::move(rai));
deserializer.deserialize(module_lookup, device);
return module;
}
-std::shared_ptr<script::Module> load(
- const std::string& filename,
- c10::optional<at::Device> device) {
- std::ifstream in(filename, std::ios_base::binary);
-
- AT_CHECK(!in.fail(), "load: could not open file ", filename);
-
- auto module = load(in, device);
-
- return module;
-}
-
} // namespace jit
} // namespace torch
#include <istream>
+namespace caffe2 {
+namespace serialize {
+class ReadAdapterInterface;
+} // namespace serialize
+} // namespace caffe2
+
namespace torch {
namespace jit {
std::istream& in,
c10::optional<c10::Device> device = c10::nullopt);
+TORCH_API void import_ir_module(
+ ModuleLookup module_lookup,
+ std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
+ c10::optional<c10::Device> device = c10::nullopt);
+
/// Loads a serialized `script::Module` from the given `istream`.
///
/// The istream must contain a serialized `script::Module`, exported via
const std::string& filename,
c10::optional<c10::Device> device = c10::nullopt);
+/// Loads a serialized `script::Module` from the given `rai`.
+///
+/// The reader adapter, which is for customized input stream, must contain a
+/// serialized `script::Module`, exported either via `ScriptModule.save()` in
+/// Python or `torch::jit::ExportModule` in C++.
+TORCH_API std::shared_ptr<script::Module> load(
+ std::unique_ptr<caffe2::serialize::ReadAdapterInterface> rai,
+ c10::optional<c10::Device> device = c10::nullopt);
+
+
} // namespace jit
} // namespace torch