From 24db1667daeb84a9eb351ec6b6264de22213a910 Mon Sep 17 00:00:00 2001 From: David Riazati Date: Fri, 29 Mar 2019 19:06:06 -0700 Subject: [PATCH] Attribute serialization improvements (#18188) Summary: * adds attributes to `ScriptModule.__getattr__` so they can be accessed in Python after re-importing * full support for all the possible values for an `int64_t` * this necessitated a bunch more `pushWhatever` functions, so re-introduced a templated version to cut down on duplicate code * tests to validate references / value sharing works * adds `torch.jit.Unpickler` which people can use to de-serialize the pickle files into Python / have a quick reference on how to do this without PyTorch Pull Request resolved: https://github.com/pytorch/pytorch/pull/18188 Differential Revision: D14527490 Pulled By: driazati fbshipit-source-id: efd15579cc04aa2e28c4b2c9490d82d849dee559 --- test/test_jit.py | 70 ++++++++++++++++- torch/_six.py | 7 ++ torch/csrc/jit/passes/python_print.cpp | 6 +- torch/csrc/jit/pickler.cpp | 134 +++++++++++++++++---------------- torch/csrc/jit/pickler.h | 14 +++- torch/csrc/jit/script/init.cpp | 6 ++ torch/csrc/jit/script/module.h | 3 + torch/jit/__init__.py | 3 + torch/jit/_pickle.py | 26 +++++++ 9 files changed, 194 insertions(+), 75 deletions(-) create mode 100644 torch/jit/_pickle.py diff --git a/test/test_jit.py b/test/test_jit.py index 4599c96..3f33683 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -16,7 +16,7 @@ from torch.nn import Module from torch.autograd.function import traceable from torch.testing import assert_allclose from torch.onnx import OperatorExportTypes -from torch._six import inf, PY2, builtins +from torch._six import inf, PY2, builtins, StringIO from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \ freeze_rng_state, set_rng_seed, slowTest @@ -37,7 +37,10 @@ import warnings import math import types import pickle +import pickletools import copy +import zipfile + from common_methods_invocations import method_tests as autograd_method_tests from common_methods_invocations import create_input, unpack_variables, \ @@ -10488,8 +10491,6 @@ a") @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") def test_attribute_unpickling(self): - import zipfile - class M(torch.jit.ScriptModule): def __init__(self): super(M, self).__init__() @@ -10557,6 +10558,69 @@ a") imported_m = self.getExportImportCopy(m) self.assertEqual(m(), imported_m()) + def test_serialization_big_ints(self): + class M(torch.jit.ScriptModule): + def __init__(self): + super(M, self).__init__() + self.int32_max = torch.jit.Attribute(2**31 - 1, int) + self.int32_min = torch.jit.Attribute(-2**31, int) + self.uint32_max = torch.jit.Attribute(2**32, int) + + self.int64_max = torch.jit.Attribute(2**63 - 1, int) + self.int64_min = torch.jit.Attribute(-2**63, int) + + self.tensor = torch.nn.Parameter(torch.ones(2, 2)) + + @torch.jit.script_method + def forward(self, x): + # type: (int) -> (int) + return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min) + + m = M() + imported = self.getExportImportCopy(m) + self.assertEqual(m(10), imported(10)) + + self.assertEqual(m.int32_max, imported.int32_max) + self.assertEqual(m.int32_min, imported.int32_min) + self.assertEqual(m.uint32_max, imported.uint32_max) + self.assertEqual(m.int64_max, imported.int64_max) + self.assertEqual(m.int64_min, imported.int64_min) + + def 
test_serialization_sharing(self): + class M(torch.jit.ScriptModule): + def __init__(self): + super(M, self).__init__() + self.list = torch.jit.Attribute([], List[str]) + + @torch.jit.script_method + def forward(self, key): + # type: (str) -> List[str] + self.list.append(key) + self.list.append(key) + self.list.append(key) + return self.list + + # the text of the string should only appear once in the pickling + m = M() + s1 = "a long string" + s2 = "a different, even longer string" + self.assertEqual(m(s1), [s1] * 3) + self.assertEqual(m(s2), [s1] * 3 + [s2] * 3) + with TemporaryFileName() as fname: + m.save(fname) + archive_name = os.path.basename(os.path.normpath(fname)) + archive = zipfile.ZipFile(fname, 'r') + pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl')) + + out = StringIO() + pickletools.dis(pickled_data, out=out) + disassembled = out.getvalue() + + FileCheck().check_count(s1, 1, exactly=True) \ + .check_count("BINGET", 2, exactly=True) \ + .check_count(s2, 1, exactly=True) \ + .check_count("BINGET", 2, exactly=True).run(out.getvalue()) + def test_optional_tuple(self): def fn(x=None): # type: (Optional[Tuple[int, int]]) -> Tuple[int, int] diff --git a/torch/_six.py b/torch/_six.py index 5eeb18a..b062114 100644 --- a/torch/_six.py +++ b/torch/_six.py @@ -137,6 +137,13 @@ if PY2: elif PY3: import builtins +if PY2: + import StringIO + StringIO = StringIO.StringIO +elif PY3: + import io + StringIO = io.StringIO + # The codes below is not copied from the six package, so the copyright # declaration at the beginning does not apply. diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 3e37054..d515014 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -827,9 +827,9 @@ struct PythonPrintPass { if (enforce_importable_) { throw script::ErrorReport(node->getSourceLocation()) << "could not export python function call " << value->name() - << ". Remove calls to Python functions before export." - << "Did you forget add @script annotation? " - << "If this is a modulelist, add it to __constants__."; + << ". Remove calls to Python functions before export. " + << "Did you forget add @script or @script_method annotation? 
" + << "If this is a nn.ModuleList, add it to __constants__."; } stmt << "^" << value->name(); diff --git a/torch/csrc/jit/pickler.cpp b/torch/csrc/jit/pickler.cpp index 3da3132..df95bc4 100644 --- a/torch/csrc/jit/pickler.cpp +++ b/torch/csrc/jit/pickler.cpp @@ -37,18 +37,18 @@ const std::vector& Pickler::stack() { } void Pickler::start() { - pushOpCode(OpCode::PROTO); - pushUint8(2); + push(OpCode::PROTO); + push(2); // All attributes get pushed into a list and their indices saved in the // module def - pushOpCode(OpCode::EMPTY_LIST); - pushOpCode(OpCode::MARK); + push(OpCode::EMPTY_LIST); + push(OpCode::MARK); } void Pickler::finish() { - pushOpCode(OpCode::APPENDS); - pushOpCode(OpCode::STOP); + push(OpCode::APPENDS); + push(OpCode::STOP); } void Pickler::addIValue(const IValue& ivalue) { @@ -70,17 +70,12 @@ void Pickler::addIValue(const IValue& ivalue) { } else if (ivalue.isDouble()) { pushDouble(ivalue); } else if (ivalue.isInt()) { - // TODO: use BININT1/BININT2/LONG if possible/necessary - AT_ASSERT( - ivalue.toInt() <= std::numeric_limits::max() && - ivalue.toInt() >= std::numeric_limits::min()); - pushOpCode(OpCode::BININT); - pushInt32(ivalue.toInt()); + pushInt(ivalue); } else if (ivalue.isBool()) { if (ivalue.toBool()) { - pushOpCode(OpCode::NEWTRUE); + push(OpCode::NEWTRUE); } else { - pushOpCode(OpCode::NEWFALSE); + push(OpCode::NEWFALSE); } } else if (ivalue.isString()) { pushMemoizedString(ivalue); @@ -89,7 +84,7 @@ void Pickler::addIValue(const IValue& ivalue) { } else if (ivalue.isGenericDict()) { pushDict(ivalue); } else if (ivalue.isNone()) { - pushOpCode(OpCode::NONE); + push(OpCode::NONE); } else if (ivalue.isIntList()) { pushIntList(ivalue); } else { @@ -113,22 +108,41 @@ const void* Pickler::getPointer(const IValue& ivalue) { return nullptr; } +void Pickler::pushInt(const IValue& ivalue) { + auto n = ivalue.toInt(); + if (n >= std::numeric_limits::min() && + n <= std::numeric_limits::max()) { + push(OpCode::BININT1); + push(n); + } else if ( + n >= std::numeric_limits::min() && + n <= std::numeric_limits::max()) { + push(OpCode::BININT); + push(n); + } else { + // Push 8 byte integer + push(OpCode::LONG1); + push(8); + push(n); + } +} + void Pickler::pushBinGet(uint32_t memo_id) { if (memo_id <= std::numeric_limits::max()) { - pushOpCode(OpCode::BINGET); - pushUint8(memo_id); + push(OpCode::BINGET); + push(memo_id); } else { // Memoized too many items, issue a LONG_BINGET instead - pushOpCode(OpCode::LONG_BINGET); - pushUint32(memo_id); + push(OpCode::LONG_BINGET); + push(memo_id); } } void Pickler::pushMemoizedString(const IValue& ivalue) { const auto& string = ivalue.toStringRef(); - pushOpCode(OpCode::BINUNICODE); - pushUint32(string.size()); + push(OpCode::BINUNICODE); + push(string.size()); pushString(string); pushMemoization(ivalue); } @@ -142,7 +156,7 @@ void Pickler::pushClass(PicklerClass cls) { // Write it to the tensor table auto memo_entry = memo_.find(&name); if (memo_entry == memo_.end()) { - pushOpCode(OpCode::GLOBAL); + push(OpCode::GLOBAL); // Module name + "\n" pushString(getModuleName()); // Class name + "\n" @@ -152,8 +166,8 @@ void Pickler::pushClass(PicklerClass cls) { pushBinGet(memo_entry->second); } - pushOpCode(OpCode::EMPTY_TUPLE); - pushOpCode(OpCode::NEWOBJ); + push(OpCode::EMPTY_TUPLE); + push(OpCode::NEWOBJ); } void Pickler::pushTensor(const IValue& ivalue) { @@ -161,25 +175,25 @@ void Pickler::pushTensor(const IValue& ivalue) { tensor_table_->push_back(ivalue.toTensor()); auto tensor_id = tensor_table_->size() - 1; - 
-  pushUint32(tensor_id);
+  push<OpCode>(OpCode::BININT);
+  push<uint32_t>(tensor_id);
 
-  pushOpCode(OpCode::BUILD);
+  push<OpCode>(OpCode::BUILD);
 }
 
 void Pickler::pushIntList(const IValue& ivalue) {
   pushClass(PicklerClass::INTLIST);
-  pushOpCode(OpCode::EMPTY_LIST);
+  push<OpCode>(OpCode::EMPTY_LIST);
   pushMemoization(ivalue);
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
 
   for (const auto& item : ivalue.toIntListRef()) {
     addIValue(item);
   }
 
-  pushOpCode(OpCode::APPENDS);
-  pushOpCode(OpCode::BUILD);
+  push<OpCode>(OpCode::APPENDS);
+  push<OpCode>(OpCode::BUILD);
 }
 
 void Pickler::pushDouble(const IValue& ivalue) {
@@ -187,9 +201,9 @@ void Pickler::pushDouble(const IValue& ivalue) {
   AT_ASSERT(sizeof(double) == 8);
   char* bytes = reinterpret_cast<char*>(&value);
 
-  pushOpCode(OpCode::BINFLOAT);
+  push<OpCode>(OpCode::BINFLOAT);
   for (size_t i = 0; i < 8; ++i) {
-    pushUint8(bytes[8 - i - 1]);
+    push<uint8_t>(bytes[8 - i - 1]);
   }
 }
 
@@ -213,10 +227,10 @@ struct IValuePairComparator {
 
 void Pickler::pushDict(const IValue& ivalue) {
   auto dict = ivalue.toGenericDictRef();
-  pushOpCode(OpCode::EMPTY_DICT);
+  push<OpCode>(OpCode::EMPTY_DICT);
   pushMemoization(ivalue);
 
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
 
   // Sort the dict for deterministic keys
   std::vector<std::pair<IValue, IValue>> dict_items(dict.begin(), dict.end());
@@ -227,18 +241,18 @@ void Pickler::pushDict(const IValue& ivalue) {
     addIValue(pair.second);
   }
 
-  pushOpCode(OpCode::SETITEMS);
+  push<OpCode>(OpCode::SETITEMS);
 }
 
 void Pickler::pushMemoization(const void* item) {
   AT_ASSERT(item != nullptr);
   if (memo_id <= std::numeric_limits<uint8_t>::max()) {
-    pushOpCode(OpCode::BINPUT);
-    pushUint8(memo_id);
+    push<OpCode>(OpCode::BINPUT);
+    push<uint8_t>(memo_id);
   } else {
     // Memoized too many items, issue a LONG_BINPUT instead
-    pushOpCode(OpCode::LONG_BINPUT);
-    pushUint32(memo_id);
+    push<OpCode>(OpCode::LONG_BINPUT);
+    push<uint32_t>(memo_id);
   }
   memo_[item] = memo_id;
   AT_ASSERT(memo_id <= std::numeric_limits<uint32_t>::max());
@@ -251,51 +265,31 @@ void Pickler::pushMemoization(const IValue& ivalue) {
 
 void Pickler::pushList(const IValue& ivalue) {
   auto list = ivalue.toGenericListRef();
-  pushOpCode(OpCode::EMPTY_LIST);
+  push<OpCode>(OpCode::EMPTY_LIST);
   pushMemoization(ivalue);
 
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
 
   for (const auto& item : list) {
     addIValue(item);
   }
 
-  pushOpCode(OpCode::APPENDS);
+  push<OpCode>(OpCode::APPENDS);
 }
 
 void Pickler::pushTuple(const IValue& ivalue) {
   // TODO: Small tuple unrolling (e.g. TUPLE3)
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
   auto tuple = ivalue.toTuple()->elements();
   for (const auto& item : tuple) {
     addIValue(item);
   }
-  pushOpCode(OpCode::TUPLE);
+  push<OpCode>(OpCode::TUPLE);
 
   pushMemoization(ivalue);
 }
 
-void Pickler::pushUint8(uint8_t value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(uint8_t));
-}
-
-void Pickler::pushOpCode(OpCode value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(OpCode));
-}
-
-void Pickler::pushUint32(uint32_t value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(uint32_t));
-}
-
-void Pickler::pushInt32(int32_t value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(int32_t));
-}
-
 std::vector<IValue> Unpickler::parse_ivalue_list() {
   run();
   AT_ASSERT(stack_.size() == 1);
@@ -367,10 +361,20 @@ OpCode Unpickler::readInstruction() {
       // Mark location of the container ivalue in the stack
       marks_.push_back(stack_.size());
     } break;
+    case OpCode::BININT1: {
+      int8_t value = read<int8_t>();
+      stack_.emplace_back(int64_t(value));
+    } break;
     case OpCode::BININT: {
       int32_t value = read<int32_t>();
       stack_.emplace_back(int64_t(value));
     } break;
+    case OpCode::LONG1: {
+      // Only read LONG1s with 8 as the length
+      uint8_t length = read<uint8_t>();
+      AT_ASSERT(length == 8);
+      stack_.emplace_back(int64_t(read<int64_t>()));
+    } break;
     case OpCode::BINUNICODE: {
       uint32_t length = read<uint32_t>();
       const char* characters = reinterpret_cast<const char*>(bytes_);
diff --git a/torch/csrc/jit/pickler.h b/torch/csrc/jit/pickler.h
index 2439aa6..ab35a70 100644
--- a/torch/csrc/jit/pickler.h
+++ b/torch/csrc/jit/pickler.h
@@ -112,12 +112,18 @@ class Pickler {
   void pushTuple(const IValue& ivalue);
   void pushDict(const IValue& ivalue);
   void pushClass(PicklerClass cls);
+  void pushInt(const IValue& ivalue);
   const void* getPointer(const IValue& ivalue);
 
-  void pushUint8(uint8_t value);
-  void pushOpCode(OpCode value);
-  void pushUint32(uint32_t value);
-  void pushInt32(int32_t value);
+  // These convert values to bytes and add them to the stack (NB: since T is to
+  // the left of a '::', its type cannot be deduced by the compiler so one must
+  // explicitly instantiate the template, i.e. push<int>(int) works, push(int)
+  // does not)
+  template <typename T>
+  void push(typename std::common_type<T>::type value) {
+    const char* begin = reinterpret_cast<const char*>(&value);
+    stack_.insert(stack_.end(), begin, begin + sizeof(T));
+  }
 
   // Stack of opcodes/data
   std::vector<char> stack_;
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 8175f35..de678e6 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -761,6 +761,7 @@ void initJitScriptBindings(PyObject* module) {
       .def("_set_parameter", &Module::set_parameter)
       .def("_get_parameter", &Module::get_parameter)
       .def("_get_buffer", &Module::get_buffer)
+      .def("_get_attribute", &Module::get_attribute)
       .def("_get_module", &Module::get_module)
       .def(
           "_get_modules",
@@ -801,6 +802,11 @@ void initJitScriptBindings(PyObject* module) {
             return result;
           })
       .def(
+          "_has_attribute",
+          [](Module& self, const std::string& name) -> bool {
+            return self.find_attribute(name);
+          })
+      .def(
           "_has_parameter",
           [](Module& self, const std::string& name) -> bool {
             return self.find_parameter(name);
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index 6d4d746..d4417a7 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -480,6 +480,9 @@ struct Module {
   autograd::Variable get_buffer(const std::string& name) const {
     return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
   }
+  IValue get_attribute(const std::string& name) const {
+    return *attributes.find(name)->slot();
+  }
 
   // each module owns its method. The reference returned here
   // is guarenteed to stay valid until this module has been destroyed
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index ec895cf..2671432 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -10,6 +10,7 @@ import torch._jit_internal as _jit_internal
 from torch._six import raise_from, with_metaclass, get_function_from_type, \
     string_classes
 from torch._jit_internal import ignore
+from torch.jit._pickle import Unpickler
 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
     _list_with_default
 import torch.testing
@@ -1131,6 +1132,8 @@ if _enabled:
                 return self._get_method(attr)
             if attr == 'graph' and self._has_method('forward'):
                 return self.__getattr__('forward').graph
+            if self._has_attribute(attr):
+                return self._get_attribute(attr)
             return Module.__getattr__(self, attr)
 
         def __setattr__(self, attr, value):
diff --git a/torch/jit/_pickle.py b/torch/jit/_pickle.py
new file mode 100644
index 0000000..24e7bf3
--- /dev/null
+++ b/torch/jit/_pickle.py
@@ -0,0 +1,26 @@
+import torch
+import functools
+import pickle
+
+
+class TensorID(object):
+    def __setstate__(self, id):
+        self.id = id
+
+
+class IntList(object):
+    def __setstate__(self, data):
+        self.data = data
+
+
+class Unpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if not module == '__main__':
+            return None
+
+        if name == 'TensorID':
+            return TensorID
+        elif name == 'IntList':
+            return IntList
+        elif name == 'LiteralTensor':
+            return LiteralTensor
-- 
2.7.4
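
As a quick reference for the de-serialization use case described in the summary, below is a minimal sketch of reading the attribute pickle out of a saved ScriptModule with the new torch.jit._pickle.Unpickler. The file name 'model.pt' and the final print are only assumed examples, not part of the patch; the archive layout ('<archive basename>/attributes.pkl') mirrors the tests added above.

    # Minimal sketch, assuming a ScriptModule was saved to 'model.pt' (hypothetical name).
    import io
    import os
    import zipfile

    from torch.jit._pickle import Unpickler  # class added by this patch

    path = 'model.pt'
    archive_name = os.path.basename(os.path.normpath(path))
    with zipfile.ZipFile(path, 'r') as archive:
        # attributes.pkl sits under the archive's root directory, as in the tests above
        data = archive.read(os.path.join(archive_name, 'attributes.pkl'))

    # find_class maps the pickled TensorID / IntList globals back to the helper classes,
    # so the attribute list unpickles into plain Python objects (tensors stay as IDs
    # referencing the tensor table, which is stored separately in the archive).
    attributes = Unpickler(io.BytesIO(data)).load()
    print(attributes)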