Attribute serialization improvements (#18188)
authorDavid Riazati <davidriazati@fb.com>
Sat, 30 Mar 2019 02:06:06 +0000 (19:06 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 02:10:12 +0000 (19:10 -0700)
Summary:
* adds attributes to `ScriptModule.__getattr__` so they can be accessed in Python after re-importing
* full support for all the possible values for an `int64_t`
    * this necessitated a bunch more `pushWhatever` functions, so re-introduced a templated version to cut down on duplicate code
* tests to validate references / value sharing works
* adds `torch.jit.Unpickler` which people can use to de-serialize the pickle files into Python / have a quick reference on how to do this without PyTorch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18188

Differential Revision: D14527490

Pulled By: driazati

fbshipit-source-id: efd15579cc04aa2e28c4b2c9490d82d849dee559

test/test_jit.py
torch/_six.py
torch/csrc/jit/passes/python_print.cpp
torch/csrc/jit/pickler.cpp
torch/csrc/jit/pickler.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.h
torch/jit/__init__.py
torch/jit/_pickle.py [new file with mode: 0644]

index 4599c96..3f33683 100644 (file)
@@ -16,7 +16,7 @@ from torch.nn import Module
 from torch.autograd.function import traceable
 from torch.testing import assert_allclose
 from torch.onnx import OperatorExportTypes
-from torch._six import inf, PY2, builtins
+from torch._six import inf, PY2, builtins, StringIO
 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
     skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
     freeze_rng_state, set_rng_seed, slowTest
@@ -37,7 +37,10 @@ import warnings
 import math
 import types
 import pickle
+import pickletools
 import copy
+import zipfile
+
 
 from common_methods_invocations import method_tests as autograd_method_tests
 from common_methods_invocations import create_input, unpack_variables, \
@@ -10488,8 +10491,6 @@ a")
 
     @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
     def test_attribute_unpickling(self):
-        import zipfile
-
         class M(torch.jit.ScriptModule):
             def __init__(self):
                 super(M, self).__init__()
@@ -10557,6 +10558,69 @@ a")
         imported_m = self.getExportImportCopy(m)
         self.assertEqual(m(), imported_m())
 
+    def test_serialization_big_ints(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M, self).__init__()
+                self.int32_max = torch.jit.Attribute(2**31 - 1, int)
+                self.int32_min = torch.jit.Attribute(-2**31, int)
+                self.uint32_max = torch.jit.Attribute(2**32, int)
+
+                self.int64_max = torch.jit.Attribute(2**63 - 1, int)
+                self.int64_min = torch.jit.Attribute(-2**63, int)
+
+                self.tensor = torch.nn.Parameter(torch.ones(2, 2))
+
+            @torch.jit.script_method
+            def forward(self, x):
+                # type: (int) -> (int)
+                return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min)
+
+        m = M()
+        imported = self.getExportImportCopy(m)
+        self.assertEqual(m(10), imported(10))
+
+        self.assertEqual(m.int32_max, imported.int32_max)
+        self.assertEqual(m.int32_min, imported.int32_min)
+        self.assertEqual(m.uint32_max, imported.uint32_max)
+        self.assertEqual(m.int64_max, imported.int64_max)
+        self.assertEqual(m.int64_min, imported.int64_min)
+
+    def test_serialization_sharing(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self):
+                super(M, self).__init__()
+                self.list = torch.jit.Attribute([], List[str])
+
+            @torch.jit.script_method
+            def forward(self, key):
+                # type: (str) -> List[str]
+                self.list.append(key)
+                self.list.append(key)
+                self.list.append(key)
+                return self.list
+
+        # the text of the string should only appear once in the pickling
+        m = M()
+        s1 = "a long string"
+        s2 = "a different, even longer string"
+        self.assertEqual(m(s1), [s1] * 3)
+        self.assertEqual(m(s2), [s1] * 3 + [s2] * 3)
+        with TemporaryFileName() as fname:
+            m.save(fname)
+            archive_name = os.path.basename(os.path.normpath(fname))
+            archive = zipfile.ZipFile(fname, 'r')
+            pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
+
+            out = StringIO()
+            pickletools.dis(pickled_data, out=out)
+            disassembled = out.getvalue()
+
+            FileCheck().check_count(s1, 1, exactly=True) \
+                .check_count("BINGET", 2, exactly=True) \
+                .check_count(s2, 1, exactly=True) \
+                .check_count("BINGET", 2, exactly=True).run(out.getvalue())
+
     def test_optional_tuple(self):
         def fn(x=None):
             # type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
index 5eeb18a..b062114 100644 (file)
@@ -137,6 +137,13 @@ if PY2:
 elif PY3:
     import builtins
 
+if PY2:
+    import StringIO
+    StringIO = StringIO.StringIO
+elif PY3:
+    import io
+    StringIO = io.StringIO
+
 
 # The codes below is not copied from the six package, so the copyright
 # declaration at the beginning does not apply.
index 3e37054..d515014 100644 (file)
@@ -827,9 +827,9 @@ struct PythonPrintPass {
         if (enforce_importable_) {
           throw script::ErrorReport(node->getSourceLocation())
               << "could not export python function call " << value->name()
-              << ". Remove calls to Python functions before export."
-              << "Did you forget add @script annotation? "
-              << "If this is a modulelist, add it to __constants__.";
+              << ". Remove calls to Python functions before export. "
+              << "Did you forget add @script or @script_method annotation? "
+              << "If this is a nn.ModuleList, add it to __constants__.";
         }
 
         stmt << "^" << value->name();
index 3da3132..df95bc4 100644 (file)
@@ -37,18 +37,18 @@ const std::vector<char>& Pickler::stack() {
 }
 
 void Pickler::start() {
-  pushOpCode(OpCode::PROTO);
-  pushUint8(2);
+  push<OpCode>(OpCode::PROTO);
+  push<uint8_t>(2);
 
   // All attributes get pushed into a list and their indices saved in the
   // module def
-  pushOpCode(OpCode::EMPTY_LIST);
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::EMPTY_LIST);
+  push<OpCode>(OpCode::MARK);
 }
 
 void Pickler::finish() {
-  pushOpCode(OpCode::APPENDS);
-  pushOpCode(OpCode::STOP);
+  push<OpCode>(OpCode::APPENDS);
+  push<OpCode>(OpCode::STOP);
 }
 
 void Pickler::addIValue(const IValue& ivalue) {
@@ -70,17 +70,12 @@ void Pickler::addIValue(const IValue& ivalue) {
   } else if (ivalue.isDouble()) {
     pushDouble(ivalue);
   } else if (ivalue.isInt()) {
-    // TODO: use BININT1/BININT2/LONG if possible/necessary
-    AT_ASSERT(
-        ivalue.toInt() <= std::numeric_limits<int32_t>::max() &&
-        ivalue.toInt() >= std::numeric_limits<int32_t>::min());
-    pushOpCode(OpCode::BININT);
-    pushInt32(ivalue.toInt());
+    pushInt(ivalue);
   } else if (ivalue.isBool()) {
     if (ivalue.toBool()) {
-      pushOpCode(OpCode::NEWTRUE);
+      push<OpCode>(OpCode::NEWTRUE);
     } else {
-      pushOpCode(OpCode::NEWFALSE);
+      push<OpCode>(OpCode::NEWFALSE);
     }
   } else if (ivalue.isString()) {
     pushMemoizedString(ivalue);
@@ -89,7 +84,7 @@ void Pickler::addIValue(const IValue& ivalue) {
   } else if (ivalue.isGenericDict()) {
     pushDict(ivalue);
   } else if (ivalue.isNone()) {
-    pushOpCode(OpCode::NONE);
+    push<OpCode>(OpCode::NONE);
   } else if (ivalue.isIntList()) {
     pushIntList(ivalue);
   } else {
@@ -113,22 +108,41 @@ const void* Pickler::getPointer(const IValue& ivalue) {
   return nullptr;
 }
 
+void Pickler::pushInt(const IValue& ivalue) {
+  auto n = ivalue.toInt();
+  if (n >= std::numeric_limits<int8_t>::min() &&
+      n <= std::numeric_limits<int8_t>::max()) {
+    push<OpCode>(OpCode::BININT1);
+    push<int8_t>(n);
+  } else if (
+      n >= std::numeric_limits<int32_t>::min() &&
+      n <= std::numeric_limits<int32_t>::max()) {
+    push<OpCode>(OpCode::BININT);
+    push<int32_t>(n);
+  } else {
+    // Push 8 byte integer
+    push<OpCode>(OpCode::LONG1);
+    push<uint8_t>(8);
+    push<int64_t>(n);
+  }
+}
+
 void Pickler::pushBinGet(uint32_t memo_id) {
   if (memo_id <= std::numeric_limits<uint8_t>::max()) {
-    pushOpCode(OpCode::BINGET);
-    pushUint8(memo_id);
+    push<OpCode>(OpCode::BINGET);
+    push<uint8_t>(memo_id);
   } else {
     // Memoized too many items, issue a LONG_BINGET instead
-    pushOpCode(OpCode::LONG_BINGET);
-    pushUint32(memo_id);
+    push<OpCode>(OpCode::LONG_BINGET);
+    push<uint32_t>(memo_id);
   }
 }
 
 void Pickler::pushMemoizedString(const IValue& ivalue) {
   const auto& string = ivalue.toStringRef();
 
-  pushOpCode(OpCode::BINUNICODE);
-  pushUint32(string.size());
+  push<OpCode>(OpCode::BINUNICODE);
+  push<uint32_t>(string.size());
   pushString(string);
   pushMemoization(ivalue);
 }
@@ -142,7 +156,7 @@ void Pickler::pushClass(PicklerClass cls) {
   // Write it to the tensor table
   auto memo_entry = memo_.find(&name);
   if (memo_entry == memo_.end()) {
-    pushOpCode(OpCode::GLOBAL);
+    push<OpCode>(OpCode::GLOBAL);
     // Module name + "\n"
     pushString(getModuleName());
     // Class name + "\n"
@@ -152,8 +166,8 @@ void Pickler::pushClass(PicklerClass cls) {
     pushBinGet(memo_entry->second);
   }
 
-  pushOpCode(OpCode::EMPTY_TUPLE);
-  pushOpCode(OpCode::NEWOBJ);
+  push<OpCode>(OpCode::EMPTY_TUPLE);
+  push<OpCode>(OpCode::NEWOBJ);
 }
 
 void Pickler::pushTensor(const IValue& ivalue) {
@@ -161,25 +175,25 @@ void Pickler::pushTensor(const IValue& ivalue) {
 
   tensor_table_->push_back(ivalue.toTensor());
   auto tensor_id = tensor_table_->size() - 1;
-  pushOpCode(OpCode::BININT);
-  pushUint32(tensor_id);
+  push<OpCode>(OpCode::BININT);
+  push<uint32_t>(tensor_id);
 
-  pushOpCode(OpCode::BUILD);
+  push<OpCode>(OpCode::BUILD);
 }
 
 void Pickler::pushIntList(const IValue& ivalue) {
   pushClass(PicklerClass::INTLIST);
 
-  pushOpCode(OpCode::EMPTY_LIST);
+  push<OpCode>(OpCode::EMPTY_LIST);
   pushMemoization(ivalue);
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
 
   for (const auto& item : ivalue.toIntListRef()) {
     addIValue(item);
   }
 
-  pushOpCode(OpCode::APPENDS);
-  pushOpCode(OpCode::BUILD);
+  push<OpCode>(OpCode::APPENDS);
+  push<OpCode>(OpCode::BUILD);
 }
 
 void Pickler::pushDouble(const IValue& ivalue) {
@@ -187,9 +201,9 @@ void Pickler::pushDouble(const IValue& ivalue) {
   AT_ASSERT(sizeof(double) == 8);
   char* bytes = reinterpret_cast<char*>(&value);
 
-  pushOpCode(OpCode::BINFLOAT);
+  push<OpCode>(OpCode::BINFLOAT);
   for (size_t i = 0; i < 8; ++i) {
-    pushUint8(bytes[8 - i - 1]);
+    push<uint8_t>(bytes[8 - i - 1]);
   }
 }
 
@@ -213,10 +227,10 @@ struct IValuePairComparator {
 void Pickler::pushDict(const IValue& ivalue) {
   auto dict = ivalue.toGenericDictRef();
 
-  pushOpCode(OpCode::EMPTY_DICT);
+  push<OpCode>(OpCode::EMPTY_DICT);
   pushMemoization(ivalue);
 
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
 
   // Sort the dict for deterministic keys
   std::vector<std::pair<IValue, IValue>> dict_items(dict.begin(), dict.end());
@@ -227,18 +241,18 @@ void Pickler::pushDict(const IValue& ivalue) {
     addIValue(pair.second);
   }
 
-  pushOpCode(OpCode::SETITEMS);
+  push<OpCode>(OpCode::SETITEMS);
 }
 
 void Pickler::pushMemoization(const void* item) {
   AT_ASSERT(item != nullptr);
   if (memo_id <= std::numeric_limits<uint8_t>::max()) {
-    pushOpCode(OpCode::BINPUT);
-    pushUint8(memo_id);
+    push<OpCode>(OpCode::BINPUT);
+    push<uint8_t>(memo_id);
   } else {
     // Memoized too many items, issue a LONG_BINPUT instead
-    pushOpCode(OpCode::LONG_BINPUT);
-    pushUint32(memo_id);
+    push<OpCode>(OpCode::LONG_BINPUT);
+    push<uint32_t>(memo_id);
   }
   memo_[item] = memo_id;
   AT_ASSERT(memo_id <= std::numeric_limits<uint32_t>::max());
@@ -251,51 +265,31 @@ void Pickler::pushMemoization(const IValue& ivalue) {
 
 void Pickler::pushList(const IValue& ivalue) {
   auto list = ivalue.toGenericListRef();
-  pushOpCode(OpCode::EMPTY_LIST);
+  push<OpCode>(OpCode::EMPTY_LIST);
   pushMemoization(ivalue);
 
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
 
   for (const auto& item : list) {
     addIValue(item);
   }
 
-  pushOpCode(OpCode::APPENDS);
+  push<OpCode>(OpCode::APPENDS);
 }
 
 void Pickler::pushTuple(const IValue& ivalue) {
   // TODO: Small tuple unrolling (e.g. TUPLE3)
-  pushOpCode(OpCode::MARK);
+  push<OpCode>(OpCode::MARK);
   auto tuple = ivalue.toTuple()->elements();
 
   for (const auto& item : tuple) {
     addIValue(item);
   }
 
-  pushOpCode(OpCode::TUPLE);
+  push<OpCode>(OpCode::TUPLE);
   pushMemoization(ivalue);
 }
 
-void Pickler::pushUint8(uint8_t value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(uint8_t));
-}
-
-void Pickler::pushOpCode(OpCode value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(OpCode));
-}
-
-void Pickler::pushUint32(uint32_t value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(uint32_t));
-}
-
-void Pickler::pushInt32(int32_t value) {
-  const char* begin = reinterpret_cast<const char*>(&value);
-  stack_.insert(stack_.end(), begin, begin + sizeof(int32_t));
-}
-
 std::vector<IValue> Unpickler::parse_ivalue_list() {
   run();
   AT_ASSERT(stack_.size() == 1);
@@ -367,10 +361,20 @@ OpCode Unpickler::readInstruction() {
       // Mark location of the container ivalue in the stack
       marks_.push_back(stack_.size());
     } break;
+    case OpCode::BININT1: {
+      int8_t value = read<int8_t>();
+      stack_.emplace_back(int64_t(value));
+    } break;
     case OpCode::BININT: {
       int32_t value = read<int32_t>();
       stack_.emplace_back(int64_t(value));
     } break;
+    case OpCode::LONG1: {
+      // Only read LONG1s with 8 as the length
+      uint8_t length = read<uint8_t>();
+      AT_ASSERT(length == 8);
+      stack_.emplace_back(int64_t(read<int64_t>()));
+    } break;
     case OpCode::BINUNICODE: {
       uint32_t length = read<uint32_t>();
       const char* characters = reinterpret_cast<const char*>(bytes_);
index 2439aa6..ab35a70 100644 (file)
@@ -112,12 +112,18 @@ class Pickler {
   void pushTuple(const IValue& ivalue);
   void pushDict(const IValue& ivalue);
   void pushClass(PicklerClass cls);
+  void pushInt(const IValue& ivalue);
   const void* getPointer(const IValue& ivalue);
 
-  void pushUint8(uint8_t value);
-  void pushOpCode(OpCode value);
-  void pushUint32(uint32_t value);
-  void pushInt32(int32_t value);
+  // These convert values to bytes and add them to the stack (NB: since T is to
+  // the left of a '::', its type cannot be deduced by the compiler so one must
+  // explicitly instantiate the template, i.e. push<int>(int) works, push(int)
+  // does not)
+  template<typename T>
+  void push(typename std::common_type<T>::type value) {
+    const char* begin = reinterpret_cast<const char*>(&value);
+    stack_.insert(stack_.end(), begin, begin + sizeof(T));
+  }
 
   // Stack of opcodes/data
   std::vector<char> stack_;
index 8175f35..de678e6 100644 (file)
@@ -761,6 +761,7 @@ void initJitScriptBindings(PyObject* module) {
       .def("_set_parameter", &Module::set_parameter)
       .def("_get_parameter", &Module::get_parameter)
       .def("_get_buffer", &Module::get_buffer)
+      .def("_get_attribute", &Module::get_attribute)
       .def("_get_module", &Module::get_module)
       .def(
           "_get_modules",
@@ -801,6 +802,11 @@ void initJitScriptBindings(PyObject* module) {
             return result;
           })
       .def(
+          "_has_attribute",
+          [](Module& self, const std::string& name) -> bool {
+            return self.find_attribute(name);
+          })
+      .def(
           "_has_parameter",
           [](Module& self, const std::string& name) -> bool {
             return self.find_parameter(name);
index 6d4d746..d4417a7 100644 (file)
@@ -480,6 +480,9 @@ struct Module {
   autograd::Variable get_buffer(const std::string& name) const {
     return autograd::as_variable_ref(attributes.find(name)->slot()->toTensor());
   }
+  IValue get_attribute(const std::string& name) const {
+    return *attributes.find(name)->slot();
+  }
 
   // each module owns its method. The reference returned here
   // is guarenteed to stay valid until this module has been destroyed
index ec895cf..2671432 100644 (file)
@@ -10,6 +10,7 @@ import torch._jit_internal as _jit_internal
 from torch._six import raise_from, with_metaclass, get_function_from_type, \
     string_classes
 from torch._jit_internal import ignore
+from torch.jit._pickle import Unpickler
 from ..nn.modules.utils import _single, _pair, _triple, _quadruple, \
     _list_with_default
 import torch.testing
@@ -1131,6 +1132,8 @@ if _enabled:
                     return self._get_method(attr)
             if attr == 'graph' and self._has_method('forward'):
                 return self.__getattr__('forward').graph
+            if self._has_attribute(attr):
+                return self._get_attribute(attr)
             return Module.__getattr__(self, attr)
 
         def __setattr__(self, attr, value):
diff --git a/torch/jit/_pickle.py b/torch/jit/_pickle.py
new file mode 100644 (file)
index 0000000..24e7bf3
--- /dev/null
@@ -0,0 +1,26 @@
+import torch
+import functools
+import pickle
+
+
+class TensorID(object):
+    def __setstate__(self, id):
+        self.id = id
+
+
+class IntList(object):
+    def __setstate__(self, data):
+        self.data = data
+
+
+class Unpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if not module == '__main__':
+            return None
+
+        if name == 'TensorID':
+            return TensorID
+        elif name == 'IntList':
+            return IntList
+        elif name == 'LiteralTensor':
+            return LiteralTensor