fixes multiprocessing serialization for integer nn.Parameter (#18639)
author: Soumith Chintala <soumith@gmail.com>
Tue, 2 Apr 2019 00:07:27 +0000 (17:07 -0700)
committer: Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 2 Apr 2019 00:15:42 +0000 (17:15 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/17345
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18639

Differential Revision: D14711565

Pulled By: soumith

fbshipit-source-id: 0063ed138a215b95d6571dcd68b18569714abe19

test/test_multiprocessing.py
torch/multiprocessing/reductions.py

index ce764b0..c28165c 100644 (file)
@@ -120,6 +120,10 @@ def requires_grad_variable_sharing(queue, ready):
     queue.put(var.requires_grad)
 
 
+def integer_parameter_serialization(iparam):
+    iparam + 1
+
+
 def autograd_sharing(queue, ready, master_modified, device, is_parameter):
     var = queue.get()
     ready.set()
@@ -751,6 +755,16 @@ class TestMultiprocessing(TestCase):
         param = Parameter(torch.arange(1., 26, device='cuda').view(5, 5))
         self._test_autograd_sharing(param, mp.get_context('spawn'), is_parameter=True)
 
+    @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
+                     don't support multiprocessing with spawn start method")
+    def test_integer_parameter_serialization(self):
+        iparam = torch.nn.Parameter(torch.tensor(0, dtype=torch.int64), requires_grad=False)
+
+        ctx = mp.get_context('spawn')
+        p = ctx.Process(target=integer_parameter_serialization, args=(iparam,))
+        p.start()
+        p.join()
+
     def test_empty_shared(self):
         t = torch.Tensor()
         t.share_memory_()
index f93054c..6cfdba0 100644 (file)
@@ -79,8 +79,12 @@ def rebuild_tensor(cls, storage, metadata):
     storage_offset, size, stride, requires_grad = metadata
     t = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
     if cls == torch.nn.parameter.Parameter:
-        t = torch.nn.parameter.Parameter(t)
-    t.requires_grad = requires_grad
+        # we have to pass requires_grad into constructor, rather than set it as an
+        # attribute later, because it's an important check for Integer Tensors to
+        # have requires_grad=False (or else they raise an error)
+        t = torch.nn.parameter.Parameter(t, requires_grad=requires_grad)
+    else:
+        t.requires_grad = requires_grad
     return t