Refactor pickler (#19035)
authorDavid Riazati <davidriazati@fb.com>
Wed, 10 Apr 2019 18:20:44 +0000 (11:20 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 10 Apr 2019 18:26:07 +0000 (11:26 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19035
ghimport-source-id: 553977b9963d4877e5066a61702f887e81706598

Differential Revision: D14839341

Pulled By: driazati

fbshipit-source-id: d6e4f21b2df28e2a0a21b26bf08d9905599119ad

torch/csrc/jit/pickler.cpp
torch/csrc/jit/pickler.h

index 3da6b7e..e679f2b 100644 (file)
@@ -55,8 +55,8 @@ void Pickler::addIValue(const IValue& ivalue) {
   // Check if reference ivalue has been saved before
   const void* ivalue_ptr = getPointer(ivalue);
   if (ivalue_ptr) {
-    auto memo_entry = memo_.find(ivalue_ptr);
-    if (memo_entry != memo_.end()) {
+    auto memo_entry = memo_map_.find(ivalue_ptr);
+    if (memo_entry != memo_map_.end()) {
       // This value has already been pushed, just do a BINGET
       pushBinGet(memo_entry->second);
       return;
@@ -154,8 +154,8 @@ void Pickler::pushString(const std::string& string) {
 void Pickler::pushClass(PicklerClass cls) {
   const auto& name = getClassName(cls);
   // Write it to the tensor table
-  auto memo_entry = memo_.find(&name);
-  if (memo_entry == memo_.end()) {
+  auto memo_entry = memo_map_.find(&name);
+  if (memo_entry == memo_map_.end()) {
     push<OpCode>(OpCode::GLOBAL);
     // Module name + "\n"
     pushString(getModuleName());
@@ -254,7 +254,7 @@ void Pickler::pushMemoization(const void* item) {
     push<OpCode>(OpCode::LONG_BINPUT);
     push<uint32_t>(memo_id);
   }
-  memo_[item] = memo_id;
+  memo_map_[item] = memo_id;
   AT_ASSERT(memo_id <= std::numeric_limits<uint32_t>::max());
   ++memo_id;
 }
@@ -359,10 +359,10 @@ OpCode Unpickler::readInstruction() {
     } break;
     case OpCode::BINPUT: {
       size_t memo_id = read<uint8_t>();
-      if (memo_.size() <= memo_id) {
-        memo_.reserve(1 + 2 * memo_id);
+      if (memo_table_.size() <= memo_id) {
+        memo_table_.reserve(1 + 2 * memo_id);
       }
-      memo_.push_back(stack_.back());
+      memo_table_.push_back(stack_.back());
     } break;
     case OpCode::MARK: {
       // Mark location of the container ivalue in the stack
@@ -416,7 +416,7 @@ OpCode Unpickler::readInstruction() {
       stack_.resize(start);
     } break;
     case OpCode::BINGET: {
-      stack_.push_back(memo_.at(read<uint8_t>()));
+      stack_.push_back(memo_table_.at(read<uint8_t>()));
     } break;
     case OpCode::STOP:
       break;
@@ -476,6 +476,7 @@ void Unpickler::readList() {
 // Read a newline terminated string
 std::string Unpickler::readString() {
   const char* chars = reinterpret_cast<const char*>(bytes_);
+  const char* char_end_ptr = reinterpret_cast<const char*>(end_ptr_);
   size_t n = 0;
   while (true) {
     char c = chars[n];
@@ -488,6 +489,9 @@ std::string Unpickler::readString() {
 
     // Increment after to exclude newline from string
     ++n;
+    AT_CHECK(
+        chars + n < char_end_ptr,
+        "Unpickler overran buffer while reading a string (expected a newline)");
   }
 
   // Increment by string length + newline char
index 22061c9..8e76d5b 100644 (file)
@@ -130,7 +130,7 @@ class Pickler {
 
   // Memoization of IValues that have been written (index in table is used for
   // BINPUT opcodes) to enable shared references
-  std::unordered_map<const void*, uint32_t> memo_;
+  std::unordered_map<const void*, uint32_t> memo_map_;
 
   // External table of tensors to serialize
   std::vector<at::Tensor>* tensor_table_;
@@ -177,7 +177,7 @@ class Unpickler {
   void readList();
 
   std::vector<IValue> stack_;
-  std::vector<IValue> memo_;
+  std::vector<IValue> memo_table_;
   std::vector<size_t> marks_;
   const uint8_t* bytes_;
   const uint8_t* end_ptr_;