improve script/no script save error (#15321)
authorZachary DeVito <zdevito@fb.com>
Tue, 18 Dec 2018 05:11:30 +0000 (21:11 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 18 Dec 2018 05:13:58 +0000 (21:13 -0800)
Summary:
Improves the error message for #15116
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15321

Differential Revision: D13499379

Pulled By: zdevito

fbshipit-source-id: b8dc0a83efabff74199f4aab2ee98aa41c42608b

test/test_jit.py
torch/jit/__init__.py

index 86d3410..8c1acf0 100644 (file)
@@ -30,6 +30,7 @@ import shutil
 import warnings
 import math
 import types
+import pickle
 
 from common_methods_invocations import method_tests as autograd_method_tests
 from common_methods_invocations import create_input, unpack_variables, \
@@ -497,6 +498,13 @@ class JitTestCase(TestCase):
         return results
 
 
+# has to be at top level or Pickle complains
+class FooToPickle(torch.nn.Module):
+    def __init__(self):
+        super(FooToPickle, self).__init__()
+        self.bar = torch.jit.ScriptModule()
+
+
 class TestJit(JitTestCase):
 
     @unittest.skip("Requires a lot of RAM")
@@ -541,6 +549,10 @@ class TestJit(JitTestCase):
         self.assertFalse(m2.p0.is_cuda)
         self.assertFalse(m2.b0.is_cuda)
 
+    def test_model_save_error(self):
+        with self.assertRaisesRegex(pickle.PickleError, "not supported"):
+            torch.save(FooToPickle(), "will_fail")
+
     @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
     def test_restore_device_cuda(self):
         class MyModule(torch.jit.ScriptModule):
index 03dc011..dc1499f 100644 (file)
@@ -30,6 +30,7 @@ import numbers
 import collections
 import re
 import inspect
+import pickle
 if sys.version_info[0] > 2:
     import pathlib
 
@@ -1154,6 +1155,12 @@ if _enabled:
             self._copy_into(module_lookup, [])
             return m
 
+        def __getstate__(self):
+            raise pickle.PickleError(
+                "ScriptModules cannot be saved using torch.save. " +
+                "Mixed serialization of script and non-script modules is not supported. " +
+                "For purely script modules use my_script_module.save(<filename>) instead.")
+
     class WeakScriptModuleProxy(ScriptModule):
         def __init__(self, original, stubs):
             # Guards behavior of __setattr__ and __getattr__ so ScriptModule
@@ -1210,7 +1217,6 @@ if _enabled:
                 raise AttributeError("Cannot set new attribute '{}' on "
                                      "weak script module once it has been "
                                      "created".format(attr))
-
 else:
     ScriptModule = torch.nn.Module