// Round-trips an XOR model through save/load, then through CUDA and a second
// save/load, asserting the loss stays low each time.
// NOTE(review): this is a diff hunk with elided context lines — model2, model3
// and the initial `loss` are declared in the omitted portion.
TEST(SerializeTest, XOR_CUDA) {
torch::manual_seed(0);
// We better be able to save and load an XOR model!
// The lambda gains an `is_cuda` flag defaulting to false, so existing call
// sites keep their behavior while CUDA call sites can opt in.
- auto getLoss = [](Sequential model, uint32_t batch_size) {
+ auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda=false) {
auto inputs = torch::empty({batch_size, 2});
auto labels = torch::empty({batch_size});
+ // Move the freshly created CPU tensors to CUDA so they match the device
+ // of a model that has been moved with model->to(torch::kCUDA).
+ if (is_cuda) {
+ inputs = inputs.cuda();
+ labels = labels.cuda();
+ }
for (size_t i = 0; i < batch_size; i++) {
inputs[i] = torch::randint(2, {2}, torch::kInt64);
labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
ASSERT_LT(loss.item<float>(), 0.1);
model2->to(torch::kCUDA);
+ // model2 now lives on CUDA, so evaluate with is_cuda=true.
+ loss = getLoss(model2, 100, true);
+ ASSERT_LT(loss.item<float>(), 0.1);
+
auto tempfile2 = torch::utils::make_tempfile();
torch::save(model2, tempfile2.name);
torch::load(model3, tempfile2.name);
+ // model3 was loaded from a model saved on CUDA; keep evaluation on CUDA.
- loss = getLoss(model3, 100);
+ loss = getLoss(model3, 100, true);
ASSERT_LT(loss.item<float>(), 0.1);
}
#pragma once
+#include <c10/util/Optional.h>
+#include <c10/Device.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/types.h>
#include <iosfwd>
#include <memory>
void read(const std::string& key, InputArchive& archive);
/// Loads the `InputArchive` from a serialized representation stored in the
- /// file at `filename`.
- void load_from(const std::string& filename);
+ /// file at `filename`. Storages are remapped using the `device` option; if
+ /// `device` is not specified, the module is loaded to its original device.
+ void load_from(const std::string& filename,
+ c10::optional<torch::Device> device = c10::nullopt);
/// Loads the `InputArchive` from a serialized representation stored in the
- /// given `stream`.
- void load_from(std::istream& stream);
+ /// given `stream`. Storages are remapped using the `device` option; if
+ /// `device` is not specified, the module is loaded to its original device.
+ void load_from(std::istream& stream,
+ c10::optional<torch::Device> device = c10::nullopt);
/// Forwards all arguments to `read()`.
/// Useful for generic code that can be re-used for both `InputArchive` and
// clang-format on
// NOTE(review): diff hunk from the interior of a read() implementation; the
// enclosing function signature is outside this view.
if (tensor.defined()) {
torch::NoGradGuard guard;
- tensor.set_(*read_tensor->slot());
+ // When the deserialized slot lives on a different device, avoid set_()
+ // (which rebinds storage in place and presumably cannot cross devices —
+ // TODO confirm) and instead replace the underlying data via set_data().
+ if (tensor.device() != read_tensor->slot()->device()) {
+ tensor.set_data(autograd::Variable(*read_tensor->slot()).data());
+ } else {
+ tensor.set_(*read_tensor->slot());
+ }
} else {
// No pre-existing tensor: take ownership of the deserialized one directly.
tensor = std::move(*read_tensor->slot());
}
}
}
// Loads the archive from the file at `filename`, forwarding the optional
// `device` to torch::jit::load so storages can be remapped at load time;
// c10::nullopt keeps the original device.
-void InputArchive::load_from(const std::string& filename) {
- module_ = torch::jit::load(filename);
+void InputArchive::load_from(const std::string& filename,
+ c10::optional<torch::Device> device /*= c10::nullopt*/) {
+ module_ = torch::jit::load(filename, device);
}
// Loads the archive from `stream`, forwarding the optional `device` to
// torch::jit::load so storages can be remapped at load time; c10::nullopt
// keeps the original device.
-void InputArchive::load_from(std::istream& stream) {
- module_ = torch::jit::load(stream);
+void InputArchive::load_from(std::istream& stream,
+ c10::optional<torch::Device> device /*= c10::nullopt*/) {
+ module_ = torch::jit::load(stream, device);
}
} // namespace serialize
} // namespace torch
deserializer.deserialize(module_lookup, device);
}
// Deserializes a script::Module from `in`, passing the caller-supplied
// `device` through to the deserializer instead of the previously hard-coded
// CPU device. NOTE(review): the module_lookup lambda body is elided context
// in this diff hunk.
-std::shared_ptr<script::Module> load(std::istream& in) {
+std::shared_ptr<script::Module> load(std::istream& in,
+ c10::optional<at::Device> device) {
auto module = std::make_shared<script::Module>();
auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
};
ScriptModuleDeserializer deserializer(&in);
- // TODO: add device support in C++ API
- deserializer.deserialize(module_lookup, c10::optional<at::Device>(at::Device("cpu")));
+ deserializer.deserialize(module_lookup, device);
return module;
}
// Opens `filename` in binary mode and delegates to the istream overload,
// forwarding the optional `device` so storages are remapped consistently
// with the stream-based path.
-std::shared_ptr<script::Module> load(const std::string& filename) {
+std::shared_ptr<script::Module> load(const std::string& filename,
+ c10::optional<at::Device> device) {
std::ifstream in(filename, std::ios_base::binary);
AT_CHECK(! in.fail(), "load: could not open file ", filename);
- auto module = load(in);
+ auto module = load(in, device);
return module;
}
// NOTE(review): both import_ir_module overloads gain a defaulted `device`
// argument (c10::nullopt = keep original device) so existing callers are
// unaffected while new callers can remap storages at load time.
TORCH_API void import_ir_module(
ModuleLookup module_lookup,
const std::string& filename,
- c10::optional<c10::Device> device);
+ c10::optional<c10::Device> device = c10::nullopt);
TORCH_API void import_ir_module(
ModuleLookup module_lookup,
std::istream& in,
- c10::optional<c10::Device> device);
+ c10::optional<c10::Device> device = c10::nullopt);
/// Loads a serialized `script::Module` from the given `istream`.
///
/// The istream must contain a serialized `script::Module`, exported via
/// `torch::jit::ExportModule` in C++.
/// If `device` is given, storages are remapped to that device; by default
/// the module is loaded to its original device.
-TORCH_API std::shared_ptr<script::Module> load(std::istream& in);
+TORCH_API std::shared_ptr<script::Module> load(std::istream& in,
+ c10::optional<c10::Device> device = c10::nullopt);
/// Loads a serialized `script::Module` from the given `filename`.
///
/// The file stored at the location given in `filename` must contain a
/// serialized `script::Module`, exported either via `ScriptModule.save()` in
/// Python or `torch::jit::ExportModule` in C++.
/// If `device` is given, storages are remapped to that device; by default
/// the module is loaded to its original device.
-TORCH_API std::shared_ptr<script::Module> load(const std::string& filename);
+TORCH_API std::shared_ptr<script::Module> load(const std::string& filename,
+ c10::optional<c10::Device> device = c10::nullopt);
} // namespace jit
} // namespace torch