[PyTorch] move from input ivalues in ByteCodeDeserializer (#64029)
authorScott Wolchok <swolchok@fb.com>
Thu, 9 Sep 2021 01:30:14 +0000 (18:30 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 01:32:48 +0000 (18:32 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64029

This should save us a separate pass over the data structure to destroy it.
ghstack-source-id: 137566821

Test Plan:
Pixel3
before:
https://www.internalfb.com/intern/aibench/details/503337445067962
after:
https://our.intern.facebook.com/intern/aibench/details/320277034999340

overall mean time decreased from 373 ms to 358 ms. In flame graph, we
can see that some time spent destroying a vector of IValues was moved
into parseMethods, and the new parseMethods time is less than the old
time plus the recursive destruction time.

Reviewed By: dhruvbird

Differential Revision: D30559530

fbshipit-source-id: d080295a846745ea03ac50f08f4f6c95f4eaf3d8

torch/csrc/jit/mobile/import.cpp

index 99be225..e438bb7 100644 (file)
@@ -85,18 +85,18 @@ using caffe2::serialize::ReadAdapterInterface;
 
 OpCode parseOpCode(const char* str);
 
-const IValue& expect_field(
-    const IValue& tup,
+IValue expect_field(
+    std::vector<IValue>& elements,
     const std::string& expected_name,
     size_t entry) {
-  auto row = tup.toTuple()->elements().at(entry).toTuple();
+  auto row = std::move(elements.at(entry)).toTuple();
   TORCH_INTERNAL_ASSERT(
       row->elements().at(0).toStringRef() == expected_name,
       "Expected ",
       expected_name,
       " found ",
       row->elements().at(0).toStringRef());
-  return row->elements().at(1);
+  return std::move(row->elements().at(1));
 }
 
 std::string operator_str(
@@ -224,8 +224,8 @@ class BytecodeDeserializer final {
  private:
   TypePtr resolveTypeName(const c10::QualifiedName& qn);
   void parseMethods(
-      const std::vector<IValue>& vals,
-      const c10::optional<std::vector<IValue>>& debug_handles,
+      std::vector<IValue>&& vals,
+      c10::optional<std::vector<IValue>>&& debug_handles,
       mobile::CompilationUnit& mcu);
   c10::IValue readArchive(
       const std::string& archive_name,
@@ -299,8 +299,8 @@ TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
 }
 
 void BytecodeDeserializer::parseMethods(
-    const std::vector<IValue>& vals,
-    const c10::optional<std::vector<IValue>>& debug_handles,
+    std::vector<IValue>&& vals,
+    c10::optional<std::vector<IValue>>&& debug_handles,
     mobile::CompilationUnit& mcu) {
   TORCH_CHECK(vals.size() > 0, "Bytecode has no elements. ");
   // Initialized with the version number when kProducedBytecodeVersion was
@@ -336,62 +336,69 @@ void BytecodeDeserializer::parseMethods(
 
   // Process all methods in this mobile module.
   for (const auto i : c10::irange(method_i_start, vals.size())) {
-    const auto& element = vals[i];
-    const auto& m_tuple = element.toTuple()->elements();
+    auto element = std::move(vals[i]);
+    auto m_tuple = std::move(*element.toTuple()).elements();
     const std::string& function_name = m_tuple[0].toStringRef();
-    const IValue& codeTable = m_tuple[1];
-    const IValue* schemaTable = // older files do not store function schema
+    auto codeTableElements =
+        std::move(*std::move(m_tuple[1]).toTuple()).elements();
+    IValue* schemaTable = // older files do not store function schema
         (model_version > 0x4L || (model_version == 0x4L && m_tuple.size() >= 3))
         ? &m_tuple[2]
         : nullptr;
     auto function =
         std::make_unique<mobile::Function>(c10::QualifiedName(function_name));
 
-    const auto& ins_list =
-        expect_field(codeTable, "instructions", BYTECODE_INDEX_INSTRUCTION)
-            .toTuple()
-            ->elements();
-    const auto& ops_list =
-        expect_field(codeTable, "operators", BYTECODE_INDEX_OPERATOR)
-            .toTuple()
-            ->elements();
-    const auto& consts_list =
-        expect_field(codeTable, "constants", BYTECODE_INDEX_CONSTANT)
-            .toTuple()
-            ->elements();
-    const auto& types_list =
-        expect_field(codeTable, "types", BYTECODE_INDEX_TYPE)
-            .toTuple()
-            ->elements();
-    const auto& register_size =
-        expect_field(codeTable, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
+    std::vector<IValue> ins_list =
+        std::move(
+            *expect_field(
+                 codeTableElements, "instructions", BYTECODE_INDEX_INSTRUCTION)
+                 .toTuple())
+            .elements();
+    std::vector<IValue> ops_list =
+        std::move(*expect_field(
+                       codeTableElements, "operators", BYTECODE_INDEX_OPERATOR)
+                       .toTuple())
+            .elements();
+    std::vector<IValue> consts_list =
+        std::move(*expect_field(
+                       codeTableElements, "constants", BYTECODE_INDEX_CONSTANT)
+                       .toTuple())
+            .elements();
+    std::vector<IValue> types_list =
+        std::move(*expect_field(codeTableElements, "types", BYTECODE_INDEX_TYPE)
+                       .toTuple())
+            .elements();
+    int64_t register_size =
+        expect_field(
+            codeTableElements, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
             .toInt();
 
     c10::List<int64_t> debug_handles_list;
     if (debug_handles) {
-      const auto& debug_handles_element = (*debug_handles)[i];
-      const auto& debug_handles_m_tuple =
-          debug_handles_element.toTuple()->elements();
+      auto debug_handles_m_tuple =
+          std::move(*std::move((*debug_handles)[i]).toTuple()).elements();
       const std::string& debug_info_function_name =
           debug_handles_m_tuple[0].toStringRef();
       TORCH_CHECK(
           debug_info_function_name == function_name,
           "The function names in the bytecode table and the debug info table do not match.");
-      const IValue& debug_handles_table = debug_handles_m_tuple[1];
-      debug_handles_list = (expect_field(
-                                debug_handles_table,
-                                "function_debug_handles",
-                                BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
-                                .toTuple()
-                                ->elements())[0]
-                               .toIntList();
+      IValue& debug_handles_table = debug_handles_m_tuple[1];
+      debug_handles_list =
+          (expect_field(
+               std::move(debug_handles_table).toTuple()->elements(),
+               "function_debug_handles",
+               BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
+               .toTuple()
+               ->elements())[0]
+              .toIntList();
       TORCH_CHECK(
           debug_handles_list.size() == ins_list.size(),
           "The numbers of instructions and debug handles strings do not match.");
     }
 
     for (const auto j : c10::irange(ins_list.size())) {
-      const auto& ins_item = ins_list[j].toTuple()->elements();
+      std::vector<IValue> ins_item =
+          std::move(*std::move(ins_list[j]).toTuple()).elements();
       TORCH_CHECK(
           ins_item.size() == 3,
           "There should be three parts in an instruction. The function name is ",
@@ -439,35 +446,52 @@ void BytecodeDeserializer::parseMethods(
 
     // function schema
     if (schemaTable) { // (schema is optional for back compat)
-      auto parseArgList = [this](const std::vector<IValue>& argTables) {
+      auto parseArgList = [this](std::vector<IValue>&& argTables) {
         std::vector<c10::Argument> args;
-        for (auto&& argTable : argTables) {
+        for (auto&& argTable : std::move(argTables)) {
+          auto argTableElements =
+              std::move(*std::move(argTable).toTuple()).elements();
           auto name =
-              expect_field(argTable, "name", BYTECODE_INDEX_ARGUMENT_NAME)
+              expect_field(
+                  argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
                   .toStringRef();
-          const auto& type = resolveTypeName(
-              (expect_field(argTable, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
+          c10::TypePtr type = resolveTypeName(
+              (expect_field(
+                   argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
                   .toStringRef());
-          const IValue& default_value = expect_field(
-              argTable, "default_value", BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
-          args.emplace_back(name, type, c10::nullopt /*N*/, default_value);
+          IValue default_value = expect_field(
+              argTableElements,
+              "default_value",
+              BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
+          args.emplace_back(
+              name,
+              std::move(type),
+              c10::nullopt /*N*/,
+              std::move(default_value));
         }
         return args;
       };
-      const auto& arg_list =
-          expect_field(
-              *schemaTable, "arguments", BYTECODE_INDEX_SCHEMA_ARGUMENTS)
-              .toTuple()
-              ->elements();
-      const auto& ret_list =
-          expect_field(*schemaTable, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
-              .toTuple()
-              ->elements();
+      auto schemaTableElements =
+          std::move(*std::move(*schemaTable).toTuple()).elements();
+      std::vector<IValue> arg_list =
+          std::move(*expect_field(
+                         schemaTableElements,
+                         "arguments",
+                         BYTECODE_INDEX_SCHEMA_ARGUMENTS)
+                         .toTuple())
+              .elements();
+      std::vector<IValue> ret_list =
+          std::move(*expect_field(
+                         schemaTableElements,
+                         "returns",
+                         BYTECODE_INDEX_SCHEMA_RETURNS)
+                         .toTuple())
+              .elements();
       c10::FunctionSchema schema(
           function_name,
           "" /*overload_name*/,
-          parseArgList(arg_list),
-          parseArgList(ret_list),
+          parseArgList(std::move(arg_list)),
+          parseArgList(std::move(ret_list)),
           false /*is_varargs*/,
           false /*is_varret*/);
       function->setSchema(std::move(schema));
@@ -523,7 +547,7 @@ mobile::Module BytecodeDeserializer::deserialize(
         readArchive("mobile_debug_handles", mcu).toTuple()->elements();
     has_debug_handles = true;
   }
-  parseMethods(bvals, debug_handles, *mcu);
+  parseMethods(std::move(bvals), std::move(debug_handles), *mcu);
   auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu);
   m.setHasDebugHandles(has_debug_handles);
 #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE)