From: Salil Desai
Date: Tue, 14 Sep 2021 19:09:45 +0000 (-0700)
Subject: [PyTorch Edge][Model Loading] Operator Call De-dup at TorchScript Serialization Level...
X-Git-Tag: accepted/tizen/8.0/unified/20231005.095509~220
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=86e6bed0d447d0156b4a6745a7f0b90496d71e9a;p=platform%2Fupstream%2Fpytorch.git

[PyTorch Edge][Model Loading] Operator Call De-dup at TorchScript Serialization Level [1/2] (#64268)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64268

If the same pair of operator name and num inputs has been used to add an instruction to the operator table previously (and the operator's schema is not vararg), use the same index as that instruction rather than creating a new one.

ghstack-source-id: 138014905

Test Plan: Phabricator tests, and test performance changes in next diff

Reviewed By: iseeyuan, tugsbayasgalan

Differential Revision: D30615434

fbshipit-source-id: f442f557f12412693a73004ce44733ccef063b82
---

diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h
index 15ba0ce..cef28ce 100644
--- a/torch/csrc/jit/runtime/interpreter/code_impl.h
+++ b/torch/csrc/jit/runtime/interpreter/code_impl.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include
+#include
 #include
 #include
 
@@ -69,6 +70,13 @@ struct CodeImpl {
 
   std::vector constant_table_;
   std::vector operator_table_;
+  // map<(op name, num inputs), index in operator table>, to avoid duplicates,
+  // not including vararg operators
+  std::unordered_map<
+      std::pair<std::string, int>,
+      int,
+      std::function<size_t(const std::pair<std::string, int>& p)>>
+      operator_table_inv_;
   std::vector function_table_;
   std::vector> forked_functions_;
   std::vector type_table_;
@@ -126,7 +134,13 @@ struct CodeImpl {
       std::string function_name,
       size_t remaining_bailout_depth,
       bool emit_instructions = true)
-      : function_name_(std::move(function_name)),
+      : operator_table_inv_(
+            0,
+            [](const std::pair<std::string, int>& p) {
+              return std::hash<std::string>()(p.first) ^
+                  std::hash<int>()(p.second);
+            }),
+        function_name_(std::move(function_name)),
         preprocess_(*graph),
         current_node_(preprocess_.graph->return_node()),
         remaining_bailout_depth_(remaining_bailout_depth) {
@@ -303,12 +317,20 @@ struct CodeImpl {
   virtual void emitOperator(Node* node) {
     emitLoadInputs(node->inputs());
     const Operator& op = node->getOperator();
-    if (op.hasOperation() && op.schema().is_vararg()) {
-      insertInstruction(OPN, operator_table_.size(), node->inputs().size());
+    int num_inputs = node->inputs().size();
+    bool is_vararg = op.schema().is_vararg();
+
+    int operation_index = add_to_operator_table(
+        op.getOperation(node),
+        c10::toString(op.schema().operator_name()),
+        num_inputs,
+        is_vararg);
+
+    if (op.hasOperation() && is_vararg) {
+      insertInstruction(OPN, operation_index, num_inputs);
     } else {
-      insertInstruction(OP, operator_table_.size());
+      insertInstruction(OP, operation_index);
     }
-    operator_table_.emplace_back(op.getOperation(node));
   }
 
   void emitWait(Node* node) {
@@ -714,6 +736,31 @@ struct CodeImpl {
       dump(out, i);
     }
   }
+
+  /**
+   * Add an operation to operator_table_ if not a duplicate and return its index
+   */
+  int add_to_operator_table(
+      const Operation& oper,
+      const std::string& op_name,
+      const int num_inputs,
+      const bool is_vararg) {
+    int size = operator_table_.size();
+
+    if (!is_vararg) {
+      std::pair<std::string, int> key(op_name, num_inputs);
+      auto found = operator_table_inv_.find(key);
+
+      if (found != operator_table_inv_.end()) {
+        return found->second;
+      }
+
+      operator_table_inv_.emplace(key, size);
+    }
+
+    operator_table_.emplace_back(oper);
+    return size;
+  }
 };
 
 struct MobileCodeImpl : CodeImpl {
@@ -775,12 +822,20 @@ struct MobileCodeImpl : CodeImpl {
       CodeImpl::emitOperator(node);
     } else {
       const Operator& op = node->getOperator();
-      if (op.hasOperation() && op.schema().is_vararg()) {
+      std::string unique_op_name = c10::toString(op.schema().operator_name());
+      int num_inputs = node->inputs().size();
+      bool is_vararg = op.schema().is_vararg();
+
+      if (op.hasOperation() && is_vararg) {
         emitLoadInputs(node->inputs());
-        insertInstruction(OPN, operator_table_.size(), node->inputs().size());
+        int operation_index = add_to_operator_table(
+            op.getOperation(node),
+            unique_op_name,
+            num_inputs,
+            /* is_vararg */ true);
+        insertInstruction(OPN, operation_index, num_inputs);
       } else {
-        auto unique_op_name = c10::toString(op.schema().operator_name());
-        auto num_include = node->inputs().size();
+        auto num_include = num_inputs;
         auto it = op_to_num_specified_args_.find(unique_op_name);
         if (it != op_to_num_specified_args_.end()) {
           num_include = it->second;
@@ -796,9 +851,10 @@ struct MobileCodeImpl : CodeImpl {
         } else {
          emitLoadInputs(node->inputs(), num_include);
         }
-        insertInstruction(OP, operator_table_.size());
+        int operation_index = add_to_operator_table(
+            op.getOperation(node), unique_op_name, num_inputs, is_vararg);
+        insertInstruction(OP, operation_index);
       }
-      operator_table_.emplace_back(op.getOperation(node));
     }
   }
 
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index 8af15a3..d06be70 100644
--- a/torch/csrc/jit/serialization/export_module.cpp
+++ b/torch/csrc/jit/serialization/export_module.cpp
@@ -79,11 +79,14 @@ std::pair getFunctionTuple(
   std::vector opnames;
   std::vector method_names;
   std::vector op_debug_handles;
+  int next_new_op_index = 0;
   for (size_t i = 0; i < instructions_copy.size(); ++i) {
     Instruction ins = instructions_copy[i];
-    if (ins.op == OP || ins.op == OPN) {
+    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
+      // Found a new op (assumes new operators ordered by ascending ins.X)
       auto node = code->instructions_source()[i];
       opnames.emplace_back(node->schema().operator_name());
+      next_new_op_index++;
     }
     // CALL nodes at this point represent built-in (i.e. non-Graph)
     // functions that were not inlined. Here we convert the CALL
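
For readers skimming the patch, the following is a minimal, self-contained C++ sketch of the de-duplication scheme the summary describes: a map keyed on the (operator name, num inputs) pair lets repeated non-vararg operator calls reuse a single operator-table index, while vararg operators always get a fresh entry. This sketch is illustrative only and is not part of the patch; the names DedupTable and add_operator, and the example operator strings, are hypothetical stand-ins for CodeImpl's operator_table_, operator_table_inv_, and add_to_operator_table.

// Illustrative sketch only -- not the patch itself.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct DedupTable {
  // One entry per emitted operator (stand-in for operator_table_).
  std::vector<std::string> operators;

  // (op name, num inputs) -> index of an existing entry, so duplicate
  // non-vararg calls can reuse it (stand-in for operator_table_inv_).
  std::unordered_map<
      std::pair<std::string, int>,
      int,
      std::function<size_t(const std::pair<std::string, int>&)>>
      index;

  DedupTable()
      : index(0, [](const std::pair<std::string, int>& p) {
          return std::hash<std::string>()(p.first) ^
              std::hash<int>()(p.second);
        }) {}

  // Mirrors add_to_operator_table: return the existing index for a
  // duplicate (name, num inputs) pair; vararg ops always get a new slot.
  int add_operator(const std::string& name, int num_inputs, bool is_vararg) {
    int next = static_cast<int>(operators.size());
    if (!is_vararg) {
      std::pair<std::string, int> key(name, num_inputs);
      auto found = index.find(key);
      if (found != index.end()) {
        return found->second; // duplicate call: reuse the earlier entry
      }
      index.emplace(key, next);
    }
    operators.emplace_back(name);
    return next;
  }
};

int main() {
  DedupTable table;
  std::cout << table.add_operator("aten::add", 3, false) << "\n";   // 0
  std::cout << table.add_operator("aten::add", 3, false) << "\n";   // 0 (reused)
  std::cout << table.add_operator("aten::add", 2, false) << "\n";   // 1 (different arity)
  std::cout << table.add_operator("aten::format", 2, true) << "\n"; // 2 (vararg: no de-dup)
  std::cout << "table size: " << table.operators.size() << "\n";    // 3
}

Compiled and run, the sketch prints 0, 0, 1, 2 and a table size of 3: the repeated non-vararg call reuses index 0 instead of appending a new table entry. That reuse is what the serialization-side change in export_module.cpp relies on, appending an operator name only when ins.X equals the next not-yet-seen table index.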