Remove outdated warning about RecursiveScriptModule not being copiable (#64085)
authorgmagogsfm <gmagogsfm@gmail.com>
Wed, 1 Sep 2021 04:27:46 +0000 (21:27 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 04:31:32 +0000 (21:31 -0700)
Summary:
RecursiveScriptModule has its customized `__copy__` and `__deepcopy__` defined. The warning/error  that says it is not copiable is outdated

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64085

Reviewed By: rohan-varma

Differential Revision: D30598623

Pulled By: gmagogsfm

fbshipit-source-id: 0701d8617f42d818bc7b88244caee4cd47fbe976

test/test_jit.py
torch/distributed/nn/api/remote_module.py
torch/jit/_script.py
torch/testing/_internal/distributed/nn/api/remote_module_test.py

index d1a170d..e94ed8d 100644 (file)
@@ -391,11 +391,6 @@ class TestJit(JitTestCase):
         self.assertFalse(m2.p0.is_cuda)
         self.assertFalse(m2.b0.is_cuda)
 
-    def test_model_save_error(self):
-        with TemporaryFileName() as fname:
-            with self.assertRaisesRegex(pickle.PickleError, "not supported"):
-                torch.save(FooToPickle(), fname)
-
     @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
     def test_restore_device_cuda(self):
         class MyModule(torch.jit.ScriptModule):
index ef26db6..fb3b160 100644 (file)
@@ -288,11 +288,13 @@ class _RemoteModule(nn.Module):
         """
         return self.module_rref
 
+    @torch.jit.export
     def __getstate__(self):
         raise RuntimeError(
             "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC"
         )
 
+    @torch.jit.export
     def __setstate__(self, state):
         raise RuntimeError(
             "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC"
index 09801ba..de32e1a 100644 (file)
@@ -785,13 +785,6 @@ if _enabled:
                 # It's fairly trivial to save enough info to warn in this case.
                 return super(RecursiveScriptModule, self).__setattr__(attr, value)
 
-        def __getstate__(self):
-            raise pickle.PickleError(
-                "ScriptModules cannot be deepcopied using copy.deepcopy or saved using torch.save. "
-                + "Mixed serialization of script and non-script modules is not supported. "
-                + "For purely script modules use my_script_module.save(<filename>) instead."
-            )
-
         def __copy__(self):
             return torch.jit._recursive.wrap_cpp_module(copy.copy(self._c))
 
index fb1d5fb..9970063 100644 (file)
@@ -1,6 +1,5 @@
 #!/usr/bin/python3
 import enum
-import pickle
 from typing import Tuple
 
 import torch
@@ -467,7 +466,7 @@ class RemoteModuleTest(CommonRemoteModuleTest):
             dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
         ):
             with TemporaryFileName() as fname:
-                with self.assertRaises(pickle.PickleError):
+                with self.assertRaisesRegex(torch.jit.Error, "can only be pickled when using RPC"):
                     torch.save(remote_module, fname)