[PyTorch] Add OpCode cache in ByteCodeDeserializer (#64110)
authorScott Wolchok <swolchok@fb.com>
Tue, 14 Sep 2021 21:18:55 +0000 (14:18 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 21:22:10 +0000 (14:22 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64110

As the code comment says, we can exploit pickler string interning to accelerate OpCode parsing. No more strcmp!
ghstack-source-id: 137978946

Test Plan:
Pixel 3 before: https://www.internalfb.com/intern/aibench/details/591414145082422
Pixel 3 after: https://www.internalfb.com/intern/aibench/details/484557404703261

new mean is 292 ms, down from 302 ms.

Reviewed By: dhruvbird

Differential Revision: D30615052

fbshipit-source-id: 9707625e778388a7920ab72704d71ad57ddaac17

torch/csrc/jit/mobile/parse_bytecode.cpp

index f43f657..1cb7bfe 100644 (file)
@@ -1,6 +1,7 @@
 #include <ATen/core/ivalue.h>
 #include <torch/csrc/jit/mobile/parse_bytecode.h>
 #include <torch/csrc/jit/mobile/type_parser.h>
+#include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/import_export_constants.h>
 #include <torch/csrc/jit/serialization/import_export_functions.h>
 #include <torch/custom_class_detail.h>
@@ -26,7 +27,44 @@ IValue expect_field(
 
 namespace mobile {
 
-namespace {} // namespace
+namespace {
+#define COUNT_OPCODE(_, _a) 1 +
+constexpr size_t numOpcodes = FORALL_OPCODES(COUNT_OPCODE) 0;
+#undef COUNT_OPCODE
+
+// Pickled strings are memoized, so we can cache a mapping from
+// pointers to parsed OpCodes to speed up parsing.
+class OpCodeCache {
+ private:
+  // We store as void* to emphasize that we care only about the
+  // address and should not be dereferencing these pointers.
+  std::array<const void*, numOpcodes> keys_{};
+  std::array<OpCode, numOpcodes> values_{};
+  size_t usedEntries_ = 0;
+
+ public:
+  OpCodeCache() {
+    memset(keys_.data(), 0, keys_.size() * sizeof(keys_[0]));
+  }
+
+  OpCode parse(const c10::ivalue::ConstantString& s) {
+    const auto endIt = keys_.begin() + usedEntries_;
+    auto it = std::find_if(
+        keys_.begin(), endIt, [&s](const void* k) { return k == &s; });
+    if (it == endIt) {
+      OpCode result = parseOpCode(s.string().c_str());
+      if (usedEntries_ < numOpcodes) {
+        keys_[usedEntries_] = &s;
+        values_[usedEntries_++] = result;
+      }
+      return result;
+    }
+    // NOTE: I tried implementing the transpose heuristic here to
+    // speed up the search, but it removed the benefit of this cache.
+    return values_[it - keys_.begin()];
+  }
+};
+} // namespace
 
 void parseInstructions(
     const std::string& function_name,
@@ -54,6 +92,11 @@ void parseInstructions(
         "The numbers of instructions and debug handles strings do not match.");
   }
 
+  // NOTE: this won't perform particularly well if the ins_list IValue
+  // didn't come from unpickler and thus have its strings
+  // interned. Consider adding a flag to bypass the cache if that
+  // becomes an important use case.
+  OpCodeCache opCodeCache;
   for (const auto j : c10::irange(ins_list.size())) {
     std::vector<IValue> ins_item =
         std::move(*std::move(ins_list[j]).toTuple()).elements();
@@ -61,7 +104,7 @@ void parseInstructions(
         ins_item.size() == 3,
         "There should be three parts in an instruction. The function name is ",
         function_name);
-    OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
+    OpCode op_code = opCodeCache.parse(*ins_item[0].toString());
     int X = ins_item[1].toInt();
     int N = ins_item[2].toInt();
     if (!debug_handles_list.empty()) {