import warnings
import math
import types
+import pickle
from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
return results
+# has to be at top level or Pickle complains
+class FooToPickle(torch.nn.Module):
+ def __init__(self):
+ super(FooToPickle, self).__init__()
+ self.bar = torch.jit.ScriptModule()
+
+
class TestJit(JitTestCase):
@unittest.skip("Requires a lot of RAM")
self.assertFalse(m2.p0.is_cuda)
self.assertFalse(m2.b0.is_cuda)
+ def test_model_save_error(self):
+ with self.assertRaisesRegex(pickle.PickleError, "not supported"):
+ torch.save(FooToPickle(), "will_fail")
+
@unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
def test_restore_device_cuda(self):
class MyModule(torch.jit.ScriptModule):
import collections
import re
import inspect
+import pickle
if sys.version_info[0] > 2:
import pathlib
self._copy_into(module_lookup, [])
return m
+ def __getstate__(self):
+ raise pickle.PickleError(
+ "ScriptModules cannot be saved using torch.save. " +
+ "Mixed serialization of script and non-script modules is not supported. " +
+ "For purely script modules use my_script_module.save(<filename>) instead.")
+
class WeakScriptModuleProxy(ScriptModule):
def __init__(self, original, stubs):
# Guards behavior of __setattr__ and __getattr__ so ScriptModule
raise AttributeError("Cannot set new attribute '{}' on "
"weak script module once it has been "
"created".format(attr))
-
else:
ScriptModule = torch.nn.Module