From 2e029db2f9a70959c0ba831d9d13e344ddab3c57 Mon Sep 17 00:00:00 2001
From: Soumith Chintala
Date: Mon, 1 Apr 2019 17:07:27 -0700
Subject: [PATCH] fixes multiprocessing serialization for integer nn.Parameter
 (#18639)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/17345

Pull Request resolved: https://github.com/pytorch/pytorch/pull/18639

Differential Revision: D14711565

Pulled By: soumith

fbshipit-source-id: 0063ed138a215b95d6571dcd68b18569714abe19
---
 test/test_multiprocessing.py        | 14 ++++++++++++++
 torch/multiprocessing/reductions.py |  8 ++++++--
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py
index ce764b0..c28165c 100644
--- a/test/test_multiprocessing.py
+++ b/test/test_multiprocessing.py
@@ -120,6 +120,10 @@ def requires_grad_variable_sharing(queue, ready):
     queue.put(var.requires_grad)
 
 
+def integer_parameter_serialization(iparam):
+    iparam + 1
+
+
 def autograd_sharing(queue, ready, master_modified, device, is_parameter):
     var = queue.get()
     ready.set()
@@ -751,6 +755,16 @@ class TestMultiprocessing(TestCase):
         param = Parameter(torch.arange(1., 26, device='cuda').view(5, 5))
         self._test_autograd_sharing(param, mp.get_context('spawn'), is_parameter=True)
 
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    def test_integer_parameter_serialization(self):
+        iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
+
+        ctx = mp.get_context('spawn')
+        p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
+        p.start()
+        p.join()
+
     def test_empty_shared(self):
         t = torch.Tensor()
         t.share_memory_()
diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py
index f93054c..6cfdba0 100644
--- a/torch/multiprocessing/reductions.py
+++ b/torch/multiprocessing/reductions.py
@@ -79,8 +79,12 @@ def rebuild_tensor(cls, storage, metadata):
     storage_offset, size, stride, requires_grad = metadata
     t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
     if cls == torch.nn.parameter.Parameter:
-        t = torch.nn.parameter.Parameter(t)
-        t.requires_grad = requires_grad
+        # we have to pass requires_grad into constructor, rather than set it as an
+        # attribute later, because it's an important check for Integer Tensors to
+        # have requires_grad=False (or else they raise an error)
+        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+    else:
+        t.requires_grad = requires_grad
     return t
 
 
-- 
2.7.4
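
For context, a minimal sketch (not part of the patch) of the failure mode the
constructor change guards against. The old rebuild path constructed
Parameter(t) with the default requires_grad=True and only assigned the
deserialized flag afterwards; for integer tensors the construction itself
raises, because only floating-point tensors may require gradients. The exact
error text below is illustrative and may differ across PyTorch versions:

    import torch

    t = torch.tensor(0, dtype=torch.int64)

    # Old rebuild path: Parameter(t) defaults to requires_grad=True, which is
    # rejected for integer dtypes before the saved flag can be restored.
    try:
        torch.nn.parameter.Parameter(t)
    except RuntimeError as e:
        print(e)  # e.g. "Only Tensors of floating point dtype can require gradients"

    # Fixed rebuild path: forward the deserialized flag into the constructor,
    # so an integer Parameter pickled with requires_grad=False rebuilds cleanly.
    p = torch.nn.parameter.Parameter(t, requires_grad=False)
    print(p.requires_grad)  # False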