Better tests/support for Python/C++ inter-op (#15193)
authorPeter Goldsborough <psag@fb.com>
Fri, 14 Dec 2018 16:29:15 +0000 (08:29 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 14 Dec 2018 16:42:10 +0000 (08:42 -0800)
Summary:
Methods like `module.named_modules()` returns a container of `shared_ptr<nn::Module>`. Currently the `nn::Module` base class does  not have Python bindings. This PR fixes this, and adds more unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15193

Differential Revision: D13458713

Pulled By: goldsborough

fbshipit-source-id: 4091fe1b96a1be8db14c6a4307fbacc2b41ff6fe

test/cpp_extensions/cpp_frontend_extension.cpp
test/test_cpp_extensions.py
torch/csrc/api/include/torch/python.h
torch/csrc/api/src/python/init.cpp

index 2b11ffa..9c768b5 100644 (file)
@@ -3,9 +3,13 @@
 #include <cstddef>
 #include <string>
 
-struct Net : torch::nn::Module {
-  Net(int64_t in, int64_t out) : fc(in, out) {
-    register_module("fc", fc);
+struct Net : torch::nn::Cloneable<Net> {
+  Net(int64_t in, int64_t out) : in_(in), out_(out) {
+    reset();
+  }
+
+  void reset() override {
+    fc = register_module("fc", torch::nn::Linear(in_, out_));
     buffer = register_buffer("buf", torch::eye(5));
   }
 
@@ -34,7 +38,8 @@ struct Net : torch::nn::Module {
     register_module(name, torch::nn::Linear(fc->options));
   }
 
-  torch::nn::Linear fc;
+  int64_t in_, out_;
+  torch::nn::Linear fc{nullptr};
   torch::Tensor buffer;
 };
 
index b06541f..a30483c 100755 (executable)
@@ -422,14 +422,14 @@ class TestCppExtension(common.TestCase):
                 super(M, self).__init__()
                 self.x = torch.nn.Parameter(torch.tensor(1.0))
                 self.net = extension.Net(3, 5)
-                self.net.to(torch.float64)
 
             def forward(self, input):
                 return self.net.forward(input) + self.x
 
         net = extension.Net(5, 2)
-        self.assertEqual(str(net), "Net")
         net.double()
+        net.to(torch.get_default_dtype())
+        self.assertEqual(str(net), "Net")
 
         # Further embed the torch.nn.Module into a Sequential, and also add the
         # C++ module as an element of the Sequential.
@@ -442,6 +442,25 @@ class TestCppExtension(common.TestCase):
         self.assertEqual(output, sequential(input))
         self.assertEqual(list(output.shape), [2, 2])
 
+        # Do changes on the module hierarchy.
+        old_dtype = torch.get_default_dtype()
+        sequential.to(torch.float64)
+        sequential.to(torch.float32)
+        sequential.to(old_dtype)
+        self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)
+
+        # Make sure we can access these method recursively.
+        self.assertEqual(len(list(sequential.parameters())), len(net.parameters()) * 2 + 1)
+        self.assertEqual(len(list(sequential.named_parameters())), len(net.named_parameters()) * 2 + 1)
+        self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
+        self.assertEqual(len(list(sequential.modules())), 8)
+
+        # Test clone()
+        net2 = net.clone()
+        self.assertEqual(len(net.parameters()), len(net2.parameters()))
+        self.assertEqual(len(net.buffers()), len(net2.buffers()))
+        self.assertEqual(len(net.modules()), len(net2.modules()))
+
         # Try differentiating through the whole module.
         for parameter in net.parameters():
             self.assertIsNone(parameter.grad)
index 1da0c8f..eefc627 100644 (file)
@@ -139,7 +139,7 @@ py::class_<ModuleType, Extra...> add_module_bindings(
           },
           py::arg("recurse") = true)
       .def_property_readonly(
-          "_modules", [](ModuleType& module) { return module.named_children(); })
+        "_modules", [](ModuleType& module) { return module.named_children(); })
       .def("modules", [](ModuleType& module) { return module.modules(); })
       .def("named_modules",
           [](ModuleType& module, py::object /* unused */, std::string prefix) {
@@ -166,10 +166,16 @@ py::class_<ModuleType, Extra...> add_module_bindings(
              py::object device,
              py::object dtype,
              bool non_blocking) {
-            module.to(
-                detail::py_object_to_device(device),
-                detail::py_object_to_dtype(dtype),
-                non_blocking);
+              if (device.is_none()) {
+                module.to(detail::py_object_to_dtype(dtype), non_blocking);
+              } else if (dtype.is_none()) {
+                module.to(detail::py_object_to_device(device), non_blocking);
+              } else {
+                module.to(
+                    detail::py_object_to_device(device),
+                    detail::py_object_to_dtype(dtype),
+                    non_blocking);
+              }
           },
           py::arg("device"),
           py::arg("dtype"),
index a3a1386..6e1d939 100644 (file)
@@ -1,4 +1,5 @@
 #include <torch/python/init.h>
+#include <torch/python.h>
 
 #include <torch/nn/module.h>
 #include <torch/ordered_dict.h>
@@ -35,6 +36,7 @@ ITEM_TYPE_CASTER(std::shared_ptr<torch::nn::Module>, Module);
 
 namespace torch {
 namespace python {
+namespace {
 template <typename T>
 void bind_ordered_dict(py::module module, const char* dict_name) {
   using ODict = OrderedDict<std::string, T>;
@@ -56,6 +58,7 @@ void bind_ordered_dict(py::module module, const char* dict_name) {
       });
   // clang-format on
 }
+} // namespace
 
 void init_bindings(PyObject* module) {
   py::module m = py::handle(module).cast<py::module>();
@@ -65,7 +68,8 @@ void init_bindings(PyObject* module) {
   bind_ordered_dict<std::shared_ptr<nn::Module>>(cpp, "OrderedModuleDict");
 
   py::module nn = cpp.def_submodule("nn");
-  py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn, "Module");
+  add_module_bindings(
+      py::class_<nn::Module, std::shared_ptr<nn::Module>>(nn, "Module"));
 }
 } // namespace python
 } // namespace torch