From aec9fdf0a41fc304e1bd5424b9baf83101834083 Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Mon, 17 Dec 2018 16:08:05 -0800
Subject: [PATCH] Fix _apply in nn.Module (#15305)

Summary:
Fixes an issue that arose from https://github.com/pytorch/pytorch/pull/13481 where `.share_memory()` couldn't be called. Effectively undoes all changes to `nn.Module` from that PR and solves the relevant problem in a different way (the goal was to be able to call `._apply()` on the Python wrapper for a C++ module). soumith

Pull Request resolved: https://github.com/pytorch/pytorch/pull/15305

Differential Revision: D13493937

Pulled By: goldsborough

fbshipit-source-id: 4cb8687f90fc8709a536c5e7eacd0dc8edf6f750
---
 test/test_cpp_extensions.py | 18 +++++++++++++++++-
 test/test_nn.py             | 22 ++++++++++++++++++++++
 torch/nn/cpp.py             | 13 +++++++++++++
 torch/nn/modules/module.py  | 24 +++++-------------------
 4 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index 42dec14..73e60a0 100755
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -453,7 +453,7 @@ class TestCppExtension(common.TestCase):
         sequential.to(old_dtype)
         self.assertEqual(sequential[2].parameters()[0].dtype, old_dtype)

-        # Make sure we can access these method recursively.
+        # Make sure we can access these methods recursively.
         self.assertEqual(len(list(sequential.parameters())), len(net.parameters()) * 2 + 1)
         self.assertEqual(len(list(sequential.named_parameters())), len(net.named_parameters()) * 2 + 1)
         self.assertEqual(len(list(sequential.buffers())), len(net.buffers()) * 2)
@@ -556,6 +556,22 @@ class TestCppExtension(common.TestCase):
             self.assertTrue(p.device.index == 0)
             self.assertEqual(cpu_parameters[i], p)

+        net.cpu()
+        net.add_new_parameter("a", torch.eye(5))
+        net.add_new_parameter("b", torch.eye(5))
+        net.add_new_buffer("c", torch.eye(5))
+        net.add_new_buffer("d", torch.eye(5))
+        net.add_new_submodule("fc2")
+        net.add_new_submodule("fc3")
+
+        for p in net.parameters():
+            self.assertTrue(p.device.type == "cpu")
+
+        net.cuda()
+
+        for p in net.parameters():
+            self.assertTrue(p.device.type == "cuda")
+
     def test_returns_shared_library_path_when_is_python_module_is_true(self):
         source = """
         #include
diff --git a/test/test_nn.py b/test/test_nn.py
index 0283db2..de31d73 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -570,6 +570,28 @@ class TestNN(NNTestCase):
         input = torch.randn(2, 3, dtype=torch.float)
         self.assertEqual(m(input).size(), (2, 5))

+    def test_share_memory(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.p = nn.Parameter(torch.eye(5))
+                self.par = nn.ParameterList()
+                self.par.append(nn.Parameter(torch.randn(10)))
+
+            def forward(inp):
+                return inp.clone()
+
+        net = Net()
+        for p in net.parameters():
+            self.assertFalse(p.storage().is_shared())
+        for b in net.buffers():
+            self.assertFalse(b.storage().is_shared())
+        net.share_memory()
+        for p in net.parameters():
+            self.assertTrue(p.storage().is_shared())
+        for b in net.buffers():
+            self.assertTrue(b.storage().is_shared())
+
     def test_hooks(self):
         module = nn.Sigmoid()
         input = torch.ones(5, 5, requires_grad=True)
diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py
index 854d488..194c17b 100644
--- a/torch/nn/cpp.py
+++ b/torch/nn/cpp.py
@@ -65,6 +65,19 @@ class ModuleWrapper(nn.Module):
             if not attr.startswith("_"):
                 setattr(self, attr, getattr(self.cpp_module, attr))

+    def _apply(self, fn):
+        for param in self.parameters():
+            # Tensors stored in modules are graph leaves, and we don't
+            # want to create copy nodes, so we have to unpack the data.
+            param.data = fn(param.data)
+            if param._grad is not None:
+                param._grad.data = fn(param._grad.data)
+
+        for buf in self.buffers():
+            buf.data = fn(buf.data)
+
+        return self
+
     @property
     def training(self):
         return self.cpp_module.training
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index 832dff0..4de506d 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -20,18 +20,6 @@ def _addindent(s_, numSpaces):
     return s


-def _if_float_tensor(fn):
-    '''
-    Calls `fn` on a value `t` only if `t` is a float tensor, or not a tensor (in
-    which case it's a module, as part of a recursive call to apply()).
-    '''
-    def apply(t):
-        if not isinstance(t, torch.Tensor) or t.is_floating_point():
-            return fn(t)
-        return t
-    return apply
-
-
 class Module(object):
     r"""Base class for all neural network modules.

@@ -196,7 +184,7 @@ class Module(object):

     def _apply(self, fn):
         for module in self.children():
-            fn(module)
+            module._apply(fn)

         for param in self._parameters.values():
             if param is not None:
@@ -296,7 +284,7 @@ class Module(object):
         Returns:
             Module: self
         """
-        return self._apply(_if_float_tensor(lambda t: t.float()))
+        return self._apply(lambda t: t.float() if t.is_floating_point() else t)

     def double(self):
         r"""Casts all floating point parameters and buffers to ``double`` datatype.
@@ -304,7 +292,7 @@ class Module(object):
         Returns:
             Module: self
         """
-        return self._apply(_if_float_tensor(lambda t: t.double()))
+        return self._apply(lambda t: t.double() if t.is_floating_point() else t)

     def half(self):
         r"""Casts all floating point parameters and buffers to ``half`` datatype.
@@ -312,7 +300,7 @@ class Module(object):
         Returns:
             Module: self
         """
-        return self._apply(_if_float_tensor(lambda t: t.half()))
+        return self._apply(lambda t: t.half() if t.is_floating_point() else t)

     def to(self, *args, **kwargs):
         r"""Moves and/or casts the parameters and buffers.
@@ -388,9 +376,7 @@ class Module(object):
                                 'dtypes, but got desired dtype={}'.format(dtype))

         def convert(t):
-            if isinstance(t, torch.Tensor):
-                return t.to(device, dtype if t.is_floating_point() else None, non_blocking)
-            return t.to(device, dtype, non_blocking)
+            return t.to(device, dtype if t.is_floating_point() else None, non_blocking)

         return self._apply(convert)

-- 
2.7.4
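
Below is a minimal, hypothetical sketch (not part of the patch) of the behavior the fix restores: because `Module._apply` once again recurses via `module._apply(fn)`, an `_apply` override on a child module, such as the `ModuleWrapper._apply` added in torch/nn/cpp.py above, is honored, and methods built on `_apply` like `share_memory()`, `float()`, and `cuda()` reach every parameter and buffer. The `Verbose` class is an illustrative stand-in, not PyTorch code.

import torch
import torch.nn as nn


class Verbose(nn.Module):
    """Hypothetical module used only to show that _apply recursion reaches overrides."""

    def __init__(self):
        super(Verbose, self).__init__()
        self.fc = nn.Linear(3, 3)

    def _apply(self, fn):
        # Module._apply now calls module._apply(fn) on each child, so this
        # override runs whenever the parent is cast, moved, or shared.
        print("Verbose._apply called")
        return super(Verbose, self)._apply(fn)


net = nn.Sequential(Verbose(), nn.Linear(3, 1))

# share_memory() is built on _apply, so it now reaches the parameters of the
# nested Verbose module as well (the new test_share_memory in test_nn.py
# exercises this code path on a plain nn.Module).
net.share_memory()
assert all(p.storage().is_shared() for p in net.parameters())

# float()/double()/half() cast only floating point tensors, mirroring the
# lambda-based casts in the patch.
net.float()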