From 3a98462f2c0887d3d268250d3e650f8684c65a92 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Mon, 17 Dec 2018 21:11:30 -0800 Subject: [PATCH] improve script/no script save error (#15321) Summary: Improves the error message for #15116 Pull Request resolved: https://github.com/pytorch/pytorch/pull/15321 Differential Revision: D13499379 Pulled By: zdevito fbshipit-source-id: b8dc0a83efabff74199f4aab2ee98aa41c42608b --- test/test_jit.py | 12 ++++++++++++ torch/jit/__init__.py | 8 +++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/test/test_jit.py b/test/test_jit.py index 86d3410..8c1acf0 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -30,6 +30,7 @@ import shutil import warnings import math import types +import pickle from common_methods_invocations import method_tests as autograd_method_tests from common_methods_invocations import create_input, unpack_variables, \ @@ -497,6 +498,13 @@ class JitTestCase(TestCase): return results +# has to be at top level or Pickle complains +class FooToPickle(torch.nn.Module): + def __init__(self): + super(FooToPickle, self).__init__() + self.bar = torch.jit.ScriptModule() + + class TestJit(JitTestCase): @unittest.skip("Requires a lot of RAM") @@ -541,6 +549,10 @@ class TestJit(JitTestCase): self.assertFalse(m2.p0.is_cuda) self.assertFalse(m2.b0.is_cuda) + def test_model_save_error(self): + with self.assertRaisesRegex(pickle.PickleError, "not supported"): + torch.save(FooToPickle(), "will_fail") + @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") def test_restore_device_cuda(self): class MyModule(torch.jit.ScriptModule): diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 03dc011..dc1499f 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -30,6 +30,7 @@ import numbers import collections import re import inspect +import pickle if sys.version_info[0] > 2: import pathlib @@ -1154,6 +1155,12 @@ if _enabled: self._copy_into(module_lookup, []) return m + def __getstate__(self): + raise pickle.PickleError( + "ScriptModules cannot be saved using torch.save. " + + "Mixed serialization of script and non-script modules is not supported. " + + "For purely script modules use my_script_module.save() instead.") + class WeakScriptModuleProxy(ScriptModule): def __init__(self, original, stubs): # Guards behavior of __setattr__ and __getattr__ so ScriptModule @@ -1210,7 +1217,6 @@ if _enabled: raise AttributeError("Cannot set new attribute '{}' on " "weak script module once it has been " "created".format(attr)) - else: ScriptModule = torch.nn.Module -- 2.7.4