#include <cstddef>
#include <string>
-struct Net : torch::nn::Module {
- Net(int64_t in, int64_t out) : fc(in, out) {
- register_module("fc", fc);
+struct Net : torch::nn::Cloneable<Net> {
+ Net(int64_t in, int64_t out) : in_(in), out_(out) {
+ reset();
+ }
+
+ void reset() override {
+ fc = register_module("fc", torch::nn::Linear(in_, out_));
buffer = register_buffer("buf", torch::eye(5));
}
register_module(name, torch::nn::Linear(fc->options));
}
- torch::nn::Linear fc;
+ int64_t in_, out_;
+ torch::nn::Linear fc{nullptr};
torch::Tensor buffer;
};
super(M, self).__init__()
self.x = torch.nn.Parameter(torch.tensor(1.0))
self.net = extension.Net(3, 5)
- self.net.to(torch.float64)
def forward(self, input):
return self.net.forward(input) + self.x
net = extension.Net(5, 2)
- self.assertEqual(str(net), "Net")
net.double()
+ net.to(torch.get_default_dtype())
+ self.assertEqual(str(net), "Net")
# Further embed the torch.nn.Module into a Sequential, and also add the
# C++ module as an element of the Sequential.
self.assertEqual(output, sequential(input))
self.assertEqual(list(output.shape), [2, 2])
+ # Apply dtype changes across the whole module hierarchy.
+ old_dtype = torch.get_default_dtype()
+ sequential.to(torch.float64)
+ sequential.to(torch.float32)
+ sequential.to(old_dtype)
+ self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)
+
+ # Make sure we can access these methods recursively.
+ self.assertEqual(len(list(sequential.parameters())), len(net.parameters()) * 2 + 1)
+ self.assertEqual(len(list(sequential.named_parameters())), len(net.named_parameters()) * 2 + 1)
+ self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
+ self.assertEqual(len(list(sequential.modules())), 8)
+
+ # Test clone()
+ net2 = net.clone()
+ self.assertEqual(len(net.parameters()), len(net2.parameters()))
+ self.assertEqual(len(net.buffers()), len(net2.buffers()))
+ self.assertEqual(len(net.modules()), len(net2.modules()))
+
# Try differentiating through the whole module.
for parameter in net.parameters():
self.assertIsNone(parameter.grad)
},
py::arg("recurse") = true)
.def_property_readonly(
- "_modules", [](ModuleType& module) { return module.named_children(); })
+ "_modules", [](ModuleType& module) { return module.named_children(); })
.def("modules", [](ModuleType& module) { return module.modules(); })
.def("named_modules",
[](ModuleType& module, py::object /* unused */, std::string prefix) {
py::object device,
py::object dtype,
bool non_blocking) {
- module.to(
- detail::py_object_to_device(device),
- detail::py_object_to_dtype(dtype),
- non_blocking);
+ if (device.is_none()) {
+ module.to(detail::py_object_to_dtype(dtype), non_blocking);
+ } else if (dtype.is_none()) {
+ module.to(detail::py_object_to_device(device), non_blocking);
+ } else {
+ module.to(
+ detail::py_object_to_device(device),
+ detail::py_object_to_dtype(dtype),
+ non_blocking);
+ }
},
py::arg("device"),
py::arg("dtype"),
#include <torch/python/init.h>
+#include <torch/python.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
namespace torch {
namespace python {
+namespace {
template <typename T>
void bind_ordered_dict(py::module module, const char* dict_name) {
using ODict = OrderedDict<std::string, T>;
});
// clang-format on
}
+} // namespace
void init_bindings(PyObject* module) {
py::module m = py::handle(module).cast<py::module>();
bind_ordered_dict<std::shared_ptr<nn::Module>>(cpp, "OrderedModuleDict");
py::module nn = cpp.def_submodule("nn");
- py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn, "Module");
+ add_module_bindings(
+ py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn, "Module"));
}
} // namespace python
} // namespace torch