Fix _apply in nn.Module (#15305)
authorPeter Goldsborough <psag@fb.com>
Tue, 18 Dec 2018 00:08:05 +0000 (16:08 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 00:22:21 +0000 (16:22 -0800)
Summary:
Fixes an issue that arose from https://github.com/pytorch/pytorch/pull/13481 where `.shared_memory()` couldn't be called. Effectively undoes all changes to `nn.Module` from that PR and solve the relevant problem in a different way (the goal was to be able to call `._apply()` on the Python wrapper for a C++ module).

soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15305

Differential Revision: D13493937

Pulled By: goldsborough

fbshipit-source-id: 4cb8687f90fc8709a536c5e7eacd0dc8edf6f750

test/test_cpp_extensions.py
test/test_nn.py
torch/nn/cpp.py
torch/nn/modules/module.py

index 42dec14..73e60a0 100755 (executable)
@@ -453,7 +453,7 @@ class TestCppExtension(common.TestCase):
         sequential.to(old_dtype)
         self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)
 
-        # Make sure we can access these method recursively.
+        # Make sure we can access these methods recursively.
         self.assertEqual(len(list(sequential.parameters())), len(net.parameters()) * 2 + 1)
         self.assertEqual(len(list(sequential.named_parameters())), len(net.named_parameters()) * 2 + 1)
         self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
@@ -556,6 +556,22 @@ class TestCppExtension(common.TestCase):
             self.assertTrue(p.device.index == 0)
             self.assertEqual(cpu_parameters[i], p)
 
+        net.cpu()
+        net.add_new_parameter("a", torch.eye(5))
+        net.add_new_parameter("b", torch.eye(5))
+        net.add_new_buffer("c", torch.eye(5))
+        net.add_new_buffer("d", torch.eye(5))
+        net.add_new_submodule("fc2")
+        net.add_new_submodule("fc3")
+
+        for p in net.parameters():
+            self.assertTrue(p.device.type == "cpu")
+
+        net.cuda()
+
+        for p in net.parameters():
+            self.assertTrue(p.device.type == "cuda")
+
     def test_returns_shared_library_path_when_is_python_module_is_true(self):
         source = """
         #include <torch/script.h>
index 0283db2..de31d73 100644 (file)
@@ -570,6 +570,28 @@ class TestNN(NNTestCase):
         input = torch.randn(2, 3, dtype=torch.float)
         self.assertEqual(m(input).size(), (2, 5))
 
+    def test_share_memory(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.p = nn.Parameter(torch.eye(5))
+                self.par = nn.ParameterList()
+                self.par.append(nn.Parameter(torch.randn(10)))
+
+            def forward(inp):
+                return inp.clone()
+
+        net = Net()
+        for p in net.parameters():
+            self.assertFalse(p.storage().is_shared())
+        for b in net.buffers():
+            self.assertFalse(b.storage().is_shared())
+        net.share_memory()
+        for p in net.parameters():
+            self.assertTrue(p.storage().is_shared())
+        for b in net.buffers():
+            self.assertTrue(b.storage().is_shared())
+
     def test_hooks(self):
         module = nn.Sigmoid()
         input = torch.ones(5, 5, requires_grad=True)
index 854d488..194c17b 100644 (file)
@@ -65,6 +65,19 @@ class ModuleWrapper(nn.Module):
             if not attr.startswith("_"):
                 setattr(self, attr, getattr(self.cpp_module, attr))
 
+    def _apply(self, fn):
+        for param in self.parameters():
+            # Tensors stored in modules are graph leaves, and we don't
+            # want to create copy nodes, so we have to unpack the data.
+            param.data = fn(param.data)
+            if param._grad is not None:
+                param._grad.data = fn(param._grad.data)
+
+        for buf in self.buffers():
+            buf.data = fn(buf.data)
+
+        return self
+
     @property
     def training(self):
         return self.cpp_module.training
index 832dff0..4de506d 100644 (file)
@@ -20,18 +20,6 @@ def _addindent(s_, numSpaces):
     return s
 
 
-def _if_float_tensor(fn):
-    '''
-    Calls `fn` on a value `t` only if `t` is a float tensor, or not a tensor (in
-    which case it's a module, as part of a recursive call to apply()).
-    '''
-    def apply(t):
-        if not isinstance(t, torch.Tensor) or t.is_floating_point():
-            return fn(t)
-        return t
-    return apply
-
-
 class Module(object):
     r"""Base class for all neural network modules.
 
@@ -196,7 +184,7 @@ class Module(object):
 
     def _apply(self, fn):
         for module in self.children():
-            fn(module)
+            module._apply(fn)
 
         for param in self._parameters.values():
             if param is not None:
@@ -296,7 +284,7 @@ class Module(object):
         Returns:
             Module: self
         """
-        return self._apply(_if_float_tensor(lambda t: t.float()))
+        return self._apply(lambda t: t.float() if t.is_floating_point() else t)
 
     def double(self):
         r"""Casts all floating point parameters and buffers to ``double`` datatype.
@@ -304,7 +292,7 @@ class Module(object):
         Returns:
             Module: self
         """
-        return self._apply(_if_float_tensor(lambda t: t.double()))
+        return self._apply(lambda t: t.double() if t.is_floating_point() else t)
 
     def half(self):
         r"""Casts all floating point parameters and buffers to ``half`` datatype.
@@ -312,7 +300,7 @@ class Module(object):
         Returns:
             Module: self
         """
-        return self._apply(_if_float_tensor(lambda t: t.half()))
+        return self._apply(lambda t: t.half() if t.is_floating_point() else t)
 
     def to(self, *args, **kwargs):
         r"""Moves and/or casts the parameters and buffers.
@@ -388,9 +376,7 @@ class Module(object):
                                 'dtypes, but got desired dtype={}'.format(dtype))
 
         def convert(t):
-            if isinstance(t, torch.Tensor):
-                return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
-            return t.to(device, dtype, non_blocking)
+            return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
 
         return self._apply(convert)