From 4327a2d70afba3cbb099ddca002fb2a2949f4579 Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Fri, 14 Dec 2018 08:29:15 -0800
Subject: [PATCH] Better tests/support for Python/C++ inter-op (#15193)

Summary:
Methods like `module.named_modules()` return a container of
`shared_ptr<nn::Module>`. Currently the `nn::Module` base class does not
have Python bindings. This PR fixes this, and adds more unit tests.
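For orientation, here is a minimal sketch of the Python side of this
inter-op. It is illustrative only: the `load()` call, the `extension`
module name, and the source path follow the pattern used in
`test/test_cpp_extensions.py`, but are assumptions rather than part of
this patch.

```python
import torch
import torch.utils.cpp_extension

# JIT-compile the C++ frontend extension (illustrative name and path).
extension = torch.utils.cpp_extension.load(
    name="cpp_frontend_extension",
    sources=["test/cpp_extensions/cpp_frontend_extension.cpp"],
)

net = extension.Net(3, 5)            # a torch::nn::Module defined in C++
net.double()                         # conversion methods work as in Python
net.to(torch.get_default_dtype())    # dtype-only to() now dispatches correctly

# The bound C++ module can sit inside a Python module hierarchy; recursive
# accessors such as parameters() and named_modules() traverse into it.
sequential = torch.nn.Sequential(torch.nn.Linear(2, 3), extension.Net(3, 5))
print(len(list(sequential.parameters())))
```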
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15193

Differential Revision: D13458713

Pulled By: goldsborough

fbshipit-source-id: 4091fe1b96a1be8db14c6a4307fbacc2b41ff6fe
---
 test/cpp_extensions/cpp_frontend_extension.cpp | 13 +++++++++----
 test/test_cpp_extensions.py                    | 23 +++++++++++++++++++++--
 torch/csrc/api/include/torch/python.h          | 16 +++++++++++-----
 torch/csrc/api/src/python/init.cpp             |  6 +++++-
 4 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/test/cpp_extensions/cpp_frontend_extension.cpp b/test/cpp_extensions/cpp_frontend_extension.cpp
index 2b11ffa..9c768b5 100644
--- a/test/cpp_extensions/cpp_frontend_extension.cpp
+++ b/test/cpp_extensions/cpp_frontend_extension.cpp
@@ -3,9 +3,13 @@
 #include <cstddef>
 #include <string>
 
-struct Net : torch::nn::Module {
-  Net(int64_t in, int64_t out) : fc(in, out) {
-    register_module("fc", fc);
+struct Net : torch::nn::Cloneable<Net> {
+  Net(int64_t in, int64_t out) : in_(in), out_(out) {
+    reset();
+  }
+
+  void reset() override {
+    fc = register_module("fc", torch::nn::Linear(in_, out_));
     buffer = register_buffer("buf", torch::eye(5));
   }
 
@@ -34,7 +38,8 @@ struct Net : torch::nn::Module {
     register_module(name, torch::nn::Linear(fc->options));
   }
 
-  torch::nn::Linear fc;
+  int64_t in_, out_;
+  torch::nn::Linear fc{nullptr};
   torch::Tensor buffer;
 };
diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index b06541f..a30483c 100755
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -422,14 +422,14 @@ class TestCppExtension(common.TestCase):
                 super(M, self).__init__()
                 self.x = torch.nn.Parameter(torch.tensor(1.0))
                 self.net = extension.Net(3, 5)
-                self.net.to(torch.float64)
 
             def forward(self, input):
                 return self.net.forward(input) + self.x
 
         net = extension.Net(5, 2)
-        self.assertEqual(str(net), "Net")
         net.double()
+        net.to(torch.get_default_dtype())
+        self.assertEqual(str(net), "Net")
 
         # Further embed the torch.nn.Module into a Sequential, and also add the
         # C++ module as an element of the Sequential.
@@ -442,6 +442,25 @@ class TestCppExtension(common.TestCase):
         self.assertEqual(output, sequential(input))
         self.assertEqual(list(output.shape), [2, 2])
 
+        # Make changes to the module hierarchy.
+        old_dtype = torch.get_default_dtype()
+        sequential.to(torch.float64)
+        sequential.to(torch.float32)
+        sequential.to(old_dtype)
+        self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)
+
+        # Make sure we can access these methods recursively.
+        self.assertEqual(len(list(sequential.parameters())), len(net.parameters()) * 2 + 1)
+        self.assertEqual(len(list(sequential.named_parameters())), len(net.named_parameters()) * 2 + 1)
+        self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
+        self.assertEqual(len(list(sequential.modules())), 8)
+
+        # Test clone().
+        net2 = net.clone()
+        self.assertEqual(len(net.parameters()), len(net2.parameters()))
+        self.assertEqual(len(net.buffers()), len(net2.buffers()))
+        self.assertEqual(len(net.modules()), len(net2.modules()))
+
         # Try differentiating through the whole module.
         for parameter in net.parameters():
             self.assertIsNone(parameter.grad)
diff --git a/torch/csrc/api/include/torch/python.h b/torch/csrc/api/include/torch/python.h
index 1da0c8f..eefc627 100644
--- a/torch/csrc/api/include/torch/python.h
+++ b/torch/csrc/api/include/torch/python.h
@@ -139,7 +139,7 @@ py::class_<ModuleType, std::shared_ptr<ModuleType>> add_module_bindings(
           },
          py::arg("recurse") = true)
       .def_property_readonly(
-          "_modules", [](ModuleType& module) { return module.named_children(); })
+          "_modules", [](ModuleType& module) { return module.named_children(); })
       .def("modules", [](ModuleType& module) { return module.modules(); })
       .def("named_modules",
           [](ModuleType& module, py::object /* unused */, std::string prefix) {
@@ -166,10 +166,16 @@ py::class_<ModuleType, std::shared_ptr<ModuleType>> add_module_bindings(
              py::object device,
              py::object dtype,
              bool non_blocking) {
-            module.to(
-                detail::py_object_to_device(device),
-                detail::py_object_to_dtype(dtype),
-                non_blocking);
+            if (device.is_none()) {
+              module.to(detail::py_object_to_dtype(dtype), non_blocking);
+            } else if (dtype.is_none()) {
+              module.to(detail::py_object_to_device(device), non_blocking);
+            } else {
+              module.to(
+                  detail::py_object_to_device(device),
+                  detail::py_object_to_dtype(dtype),
+                  non_blocking);
+            }
           },
           py::arg("device"),
           py::arg("dtype"),
diff --git a/torch/csrc/api/src/python/init.cpp b/torch/csrc/api/src/python/init.cpp
index a3a1386..6e1d939 100644
--- a/torch/csrc/api/src/python/init.cpp
+++ b/torch/csrc/api/src/python/init.cpp
@@ -1,4 +1,5 @@
 #include <torch/python/init.h>
+#include <torch/python.h>
 
 #include <torch/nn/module.h>
 #include <torch/ordered_dict.h>
@@ -35,6 +36,7 @@ ITEM_TYPE_CASTER(std::shared_ptr<Module>, Module);
 
 namespace torch {
 namespace python {
+namespace {
 template <typename T>
 void bind_ordered_dict(py::module module, const char* dict_name) {
   using ODict = OrderedDict<std::string, T>;
@@ -56,6 +58,7 @@ void bind_ordered_dict(py::module module, const char* dict_name) {
       });
   // clang-format on
 }
+} // namespace
 
 void init_bindings(PyObject* module) {
   py::module m = py::handle(module).cast<py::module>();
@@ -65,7 +68,8 @@ void init_bindings(PyObject* module) {
   bind_ordered_dict<std::shared_ptr<Module>>(cpp, "OrderedModuleDict");
 
   py::module nn = cpp.def_submodule("nn");
-  py::class_<Module, std::shared_ptr<Module>>(nn, "Module");
+  add_module_bindings(
+      py::class_<Module, std::shared_ptr<Module>>(nn, "Module"));
 }
 } // namespace python
 } // namespace torch
-- 
2.7.4