[PyTorch Edge][Model Loading] Operator Call De-dup at TorchScript Serialization Level...
authorSalil Desai <salilsdesai@fb.com>
Tue, 14 Sep 2021 19:09:45 +0000 (12:09 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Tue, 14 Sep 2021 19:11:32 +0000 (12:11 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64268

If the same pair of operator name and num inputs have been used to add an instruction to the operator table previously (and the operator's schema is not vararg), use the same index as that instruction rather than creating a new one.
ghstack-source-id: 138014905

Test Plan: Phabricator tests, and test performance changes in next diff

Reviewed By: iseeyuan, tugsbayasgalan

Differential Revision: D30615434

fbshipit-source-id: f442f557f12412693a73004ce44733ccef063b82

torch/csrc/jit/runtime/interpreter/code_impl.h
torch/csrc/jit/serialization/export_module.cpp

index 15ba0ce..cef28ce 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <memory>
+#include <unordered_map>
 #include <vector>
 
 #include <c10/util/irange.h>
@@ -69,6 +70,13 @@ struct CodeImpl {
 
   std::vector<IValue> constant_table_;
   std::vector<Operation> operator_table_;
+  // map<(op name, num inputs), index in operator table>, to avoid duplicates,
+  // not including vararg operators
+  std::unordered_map<
+      std::pair<std::string, int>,
+      int,
+      std::function<size_t(const std::pair<std::string, int>& p)>>
+      operator_table_inv_;
   std::vector<Function*> function_table_;
   std::vector<std::unique_ptr<GraphFunction>> forked_functions_;
   std::vector<TypePtr> type_table_;
@@ -126,7 +134,13 @@ struct CodeImpl {
       std::string function_name,
       size_t remaining_bailout_depth,
       bool emit_instructions = true)
-      : function_name_(std::move(function_name)),
+      : operator_table_inv_(
+            0,
+            [](const std::pair<std::string, int>& p) {
+              return std::hash<std::string>()(p.first) ^
+                  std::hash<int>()(p.second);
+            }),
+        function_name_(std::move(function_name)),
         preprocess_(*graph),
         current_node_(preprocess_.graph->return_node()),
         remaining_bailout_depth_(remaining_bailout_depth) {
@@ -303,12 +317,20 @@ struct CodeImpl {
   virtual void emitOperator(Node* node) {
     emitLoadInputs(node->inputs());
     const Operator& op = node->getOperator();
-    if (op.hasOperation() && op.schema().is_vararg()) {
-      insertInstruction(OPN, operator_table_.size(), node->inputs().size());
+    int num_inputs = node->inputs().size();
+    bool is_vararg = op.schema().is_vararg();
+
+    int operation_index = add_to_operator_table(
+        op.getOperation(node),
+        c10::toString(op.schema().operator_name()),
+        num_inputs,
+        is_vararg);
+
+    if (op.hasOperation() && is_vararg) {
+      insertInstruction(OPN, operation_index, num_inputs);
     } else {
-      insertInstruction(OP, operator_table_.size());
+      insertInstruction(OP, operation_index);
     }
-    operator_table_.emplace_back(op.getOperation(node));
   }
 
   void emitWait(Node* node) {
@@ -714,6 +736,31 @@ struct CodeImpl {
       dump(out, i);
     }
   }
+
+  /**
+   * Add an operation to operator_table_ if not a duplicate and return its index
+   */
+  int add_to_operator_table(
+      const Operation& oper,
+      const std::string& op_name,
+      const int num_inputs,
+      const bool is_vararg) {
+    int size = operator_table_.size();
+
+    if (!is_vararg) {
+      std::pair<std::string, int> key(op_name, num_inputs);
+      auto found = operator_table_inv_.find(key);
+
+      if (found != operator_table_inv_.end()) {
+        return found->second;
+      }
+
+      operator_table_inv_.emplace(key, size);
+    }
+
+    operator_table_.emplace_back(oper);
+    return size;
+  }
 };
 
 struct MobileCodeImpl : CodeImpl {
@@ -775,12 +822,20 @@ struct MobileCodeImpl : CodeImpl {
       CodeImpl::emitOperator(node);
     } else {
       const Operator& op = node->getOperator();
-      if (op.hasOperation() && op.schema().is_vararg()) {
+      std::string unique_op_name = c10::toString(op.schema().operator_name());
+      int num_inputs = node->inputs().size();
+      bool is_vararg = op.schema().is_vararg();
+
+      if (op.hasOperation() && is_vararg) {
         emitLoadInputs(node->inputs());
-        insertInstruction(OPN, operator_table_.size(), node->inputs().size());
+        int operation_index = add_to_operator_table(
+            op.getOperation(node),
+            unique_op_name,
+            num_inputs,
+            /* is_vararg */ true);
+        insertInstruction(OPN, operation_index, num_inputs);
       } else {
-        auto unique_op_name = c10::toString(op.schema().operator_name());
-        auto num_include = node->inputs().size();
+        auto num_include = num_inputs;
         auto it = op_to_num_specified_args_.find(unique_op_name);
         if (it != op_to_num_specified_args_.end()) {
           num_include = it->second;
@@ -796,9 +851,10 @@ struct MobileCodeImpl : CodeImpl {
         } else {
           emitLoadInputs(node->inputs(), num_include);
         }
-        insertInstruction(OP, operator_table_.size());
+        int operation_index = add_to_operator_table(
+            op.getOperation(node), unique_op_name, num_inputs, is_vararg);
+        insertInstruction(OP, operation_index);
       }
-      operator_table_.emplace_back(op.getOperation(node));
     }
   }
 
index 8af15a3..d06be70 100644 (file)
@@ -79,11 +79,14 @@ std::pair<IValue, IValue> getFunctionTuple(
   std::vector<c10::OperatorName> opnames;
   std::vector<std::string> method_names;
   std::vector<int64_t> op_debug_handles;
+  int next_new_op_index = 0;
   for (size_t i = 0; i < instructions_copy.size(); ++i) {
     Instruction ins = instructions_copy[i];
-    if (ins.op == OP || ins.op == OPN) {
+    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
+      // Found a new op (assumes new operators ordered by ascending ins.X)
       auto node = code->instructions_source()[i];
       opnames.emplace_back(node->schema().operator_name());
+      next_new_op_index++;
     }
     // CALL nodes at this point represent built-in (i.e. non-Graph)
     // functions that were not inlined. Here we convert the CALL