From 0d0d2f2ac596cda6bb785bbfe49447bbbe545d73 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Wed, 8 Sep 2021 18:30:14 -0700 Subject: [PATCH] [PyTorch] move from input ivalues in ByteCodeDeserializer (#64029) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64029 This should save us a separate pass over the data structure to destroy it. ghstack-source-id: 137566821 Test Plan: Pixel3 before: https://www.internalfb.com/intern/aibench/details/503337445067962 after: https://our.intern.facebook.com/intern/aibench/details/320277034999340 overall mean time decreased from 373 ms to 358 ms. In flame graph, we can see that some time spent destroying a vector of IValues was moved into parseMethods, and the new parseMethods time is less than the old time plus the recursive destruction time. Reviewed By: dhruvbird Differential Revision: D30559530 fbshipit-source-id: d080295a846745ea03ac50f08f4f6c95f4eaf3d8 --- torch/csrc/jit/mobile/import.cpp | 148 +++++++++++++++++++++++---------------- 1 file changed, 86 insertions(+), 62 deletions(-) diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp index 99be225..e438bb7 100644 --- a/torch/csrc/jit/mobile/import.cpp +++ b/torch/csrc/jit/mobile/import.cpp @@ -85,18 +85,18 @@ using caffe2::serialize::ReadAdapterInterface; OpCode parseOpCode(const char* str); -const IValue& expect_field( - const IValue& tup, +IValue expect_field( + std::vector& elements, const std::string& expected_name, size_t entry) { - auto row = tup.toTuple()->elements().at(entry).toTuple(); + auto row = std::move(elements.at(entry)).toTuple(); TORCH_INTERNAL_ASSERT( row->elements().at(0).toStringRef() == expected_name, "Expected ", expected_name, " found ", row->elements().at(0).toStringRef()); - return row->elements().at(1); + return std::move(row->elements().at(1)); } std::string operator_str( @@ -224,8 +224,8 @@ class BytecodeDeserializer final { private: TypePtr resolveTypeName(const c10::QualifiedName& qn); void parseMethods( - const std::vector& vals, - const c10::optional>& debug_handles, + std::vector&& vals, + c10::optional>&& debug_handles, mobile::CompilationUnit& mcu); c10::IValue readArchive( const std::string& archive_name, @@ -299,8 +299,8 @@ TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) { } void BytecodeDeserializer::parseMethods( - const std::vector& vals, - const c10::optional>& debug_handles, + std::vector&& vals, + c10::optional>&& debug_handles, mobile::CompilationUnit& mcu) { TORCH_CHECK(vals.size() > 0, "Bytecode has no elements. "); // Initialized with the version number when kProducedBytecodeVersion was @@ -336,62 +336,69 @@ void BytecodeDeserializer::parseMethods( // Process all methods in this mobile module. for (const auto i : c10::irange(method_i_start, vals.size())) { - const auto& element = vals[i]; - const auto& m_tuple = element.toTuple()->elements(); + auto element = std::move(vals[i]); + auto m_tuple = std::move(*element.toTuple()).elements(); const std::string& function_name = m_tuple[0].toStringRef(); - const IValue& codeTable = m_tuple[1]; - const IValue* schemaTable = // older files do not store function schema + auto codeTableElements = + std::move(*std::move(m_tuple[1]).toTuple()).elements(); + IValue* schemaTable = // older files do not store function schema (model_version > 0x4L || (model_version == 0x4L && m_tuple.size() >= 3)) ? &m_tuple[2] : nullptr; auto function = std::make_unique(c10::QualifiedName(function_name)); - const auto& ins_list = - expect_field(codeTable, "instructions", BYTECODE_INDEX_INSTRUCTION) - .toTuple() - ->elements(); - const auto& ops_list = - expect_field(codeTable, "operators", BYTECODE_INDEX_OPERATOR) - .toTuple() - ->elements(); - const auto& consts_list = - expect_field(codeTable, "constants", BYTECODE_INDEX_CONSTANT) - .toTuple() - ->elements(); - const auto& types_list = - expect_field(codeTable, "types", BYTECODE_INDEX_TYPE) - .toTuple() - ->elements(); - const auto& register_size = - expect_field(codeTable, "register_size", BYTECODE_INDEX_REGISTER_SIZE) + std::vector ins_list = + std::move( + *expect_field( + codeTableElements, "instructions", BYTECODE_INDEX_INSTRUCTION) + .toTuple()) + .elements(); + std::vector ops_list = + std::move(*expect_field( + codeTableElements, "operators", BYTECODE_INDEX_OPERATOR) + .toTuple()) + .elements(); + std::vector consts_list = + std::move(*expect_field( + codeTableElements, "constants", BYTECODE_INDEX_CONSTANT) + .toTuple()) + .elements(); + std::vector types_list = + std::move(*expect_field(codeTableElements, "types", BYTECODE_INDEX_TYPE) + .toTuple()) + .elements(); + int64_t register_size = + expect_field( + codeTableElements, "register_size", BYTECODE_INDEX_REGISTER_SIZE) .toInt(); c10::List debug_handles_list; if (debug_handles) { - const auto& debug_handles_element = (*debug_handles)[i]; - const auto& debug_handles_m_tuple = - debug_handles_element.toTuple()->elements(); + auto debug_handles_m_tuple = + std::move(*std::move((*debug_handles)[i]).toTuple()).elements(); const std::string& debug_info_function_name = debug_handles_m_tuple[0].toStringRef(); TORCH_CHECK( debug_info_function_name == function_name, "The function names in the bytecode table and the debug info table do not match."); - const IValue& debug_handles_table = debug_handles_m_tuple[1]; - debug_handles_list = (expect_field( - debug_handles_table, - "function_debug_handles", - BYTECODE_INDEX_MODULE_DEBUG_HANDLES) - .toTuple() - ->elements())[0] - .toIntList(); + IValue& debug_handles_table = debug_handles_m_tuple[1]; + debug_handles_list = + (expect_field( + std::move(debug_handles_table).toTuple()->elements(), + "function_debug_handles", + BYTECODE_INDEX_MODULE_DEBUG_HANDLES) + .toTuple() + ->elements())[0] + .toIntList(); TORCH_CHECK( debug_handles_list.size() == ins_list.size(), "The numbers of instructions and debug handles strings do not match."); } for (const auto j : c10::irange(ins_list.size())) { - const auto& ins_item = ins_list[j].toTuple()->elements(); + std::vector ins_item = + std::move(*std::move(ins_list[j]).toTuple()).elements(); TORCH_CHECK( ins_item.size() == 3, "There should be three parts in an instruction. The function name is ", @@ -439,35 +446,52 @@ void BytecodeDeserializer::parseMethods( // function schema if (schemaTable) { // (schema is optional for back compat) - auto parseArgList = [this](const std::vector& argTables) { + auto parseArgList = [this](std::vector&& argTables) { std::vector args; - for (auto&& argTable : argTables) { + for (auto&& argTable : std::move(argTables)) { + auto argTableElements = + std::move(*std::move(argTable).toTuple()).elements(); auto name = - expect_field(argTable, "name", BYTECODE_INDEX_ARGUMENT_NAME) + expect_field( + argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME) .toStringRef(); - const auto& type = resolveTypeName( - (expect_field(argTable, "type", BYTECODE_INDEX_ARGUMENT_TYPE)) + c10::TypePtr type = resolveTypeName( + (expect_field( + argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE)) .toStringRef()); - const IValue& default_value = expect_field( - argTable, "default_value", BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE); - args.emplace_back(name, type, c10::nullopt /*N*/, default_value); + IValue default_value = expect_field( + argTableElements, + "default_value", + BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE); + args.emplace_back( + name, + std::move(type), + c10::nullopt /*N*/, + std::move(default_value)); } return args; }; - const auto& arg_list = - expect_field( - *schemaTable, "arguments", BYTECODE_INDEX_SCHEMA_ARGUMENTS) - .toTuple() - ->elements(); - const auto& ret_list = - expect_field(*schemaTable, "returns", BYTECODE_INDEX_SCHEMA_RETURNS) - .toTuple() - ->elements(); + auto schemaTableElements = + std::move(*std::move(*schemaTable).toTuple()).elements(); + std::vector arg_list = + std::move(*expect_field( + schemaTableElements, + "arguments", + BYTECODE_INDEX_SCHEMA_ARGUMENTS) + .toTuple()) + .elements(); + std::vector ret_list = + std::move(*expect_field( + schemaTableElements, + "returns", + BYTECODE_INDEX_SCHEMA_RETURNS) + .toTuple()) + .elements(); c10::FunctionSchema schema( function_name, "" /*overload_name*/, - parseArgList(arg_list), - parseArgList(ret_list), + parseArgList(std::move(arg_list)), + parseArgList(std::move(ret_list)), false /*is_varargs*/, false /*is_varret*/); function->setSchema(std::move(schema)); @@ -523,7 +547,7 @@ mobile::Module BytecodeDeserializer::deserialize( readArchive("mobile_debug_handles", mcu).toTuple()->elements(); has_debug_handles = true; } - parseMethods(bvals, debug_handles, *mcu); + parseMethods(std::move(bvals), std::move(debug_handles), *mcu); auto m = mobile::Module(readArchive("data", mcu).toObject(), mcu); m.setHasDebugHandles(has_debug_handles); #if defined(SYMBOLICATE_MOBILE_DEBUG_HANDLE) -- 2.7.4