[RFC] Modularize functions of parsing bytecode (#61862)
author Martin Yuan <myuan@fb.com>
Sun, 12 Sep 2021 05:22:28 +0000 (22:22 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Sun, 12 Sep 2021 05:24:05 +0000 (22:24 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61862

Modularize the functions that parse bytecode tables so that they can be reused outside the mobile lite interpreter.
* The decoupled functions are reused by the current lite interpreter loader.
* The bytecode can be serialized to and deserialized from other formats.
* The decoupled functions have minimal dependencies on other PyTorch components (a usage sketch follows below).
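
As an illustration, here is a minimal sketch of how a custom deserializer could drive the decoupled parser. The buildFunction wrapper and its table arguments are hypothetical; the parse* entry points and the mobile::Function API are the ones added in torch/csrc/jit/mobile/parse_bytecode.h:

  #include <torch/csrc/jit/mobile/function.h>
  #include <torch/csrc/jit/mobile/parse_bytecode.h>

  #include <memory>
  #include <string>
  #include <vector>

  using c10::IValue;
  namespace mobile = torch::jit::mobile;

  // Hypothetical driver: the four tables can come from any custom format
  // (e.g. a non-pickle serialization), decoded into IValue lists upstream.
  std::unique_ptr<mobile::Function> buildFunction(
      const std::string& name,
      const std::vector<IValue>& instructions,
      const std::vector<IValue>& constants,
      const std::vector<IValue>& types,
      size_t register_size) {
    auto function =
        std::make_unique<mobile::Function>(c10::QualifiedName(name));
    std::vector<IValue> debug_handles; // empty: no debug-info table
    mobile::parseInstructions(
        name, instructions, debug_handles, function.get());
    mobile::parseConstants(constants, function.get());
    mobile::parseTypes(types, function.get());
    mobile::parseRegisterSize(register_size, function.get());
    return function; // ready to run via function->run(stack)
  }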

Next:
Build a driver binary that includes the parser and interpreter but carries only the necessary dependencies on other PyTorch components.
ghstack-source-id: 137867287

Test Plan:
As an example, a simple bytecode program is parsed into a mobile function and run directly in the added unit test, `RunTimeTest.ParseBytecode`. It contains basic control flow (if, else) and basic data orchestration (list construction).
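
Assuming the usual C++ test binary for test/cpp/jit (the binary name and path may vary by build configuration), the new test can be run directly with a gtest filter:

  ./build/bin/test_jit --gtest_filter=RunTimeTest.ParseBytecode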
CI

Reviewed By: larryliu0820

Differential Revision: D29798382

Pulled By: iseeyuan

fbshipit-source-id: 1c173a5f5d37097e3a97baec3f3e48e1eea1400f

12 files changed:
caffe2/CMakeLists.txt
test/cpp/jit/test_lite_interpreter.cpp
tools/build_variables.bzl
torch/csrc/jit/mobile/function.cpp
torch/csrc/jit/mobile/function.h
torch/csrc/jit/mobile/import.cpp
torch/csrc/jit/mobile/parse_bytecode.cpp [new file with mode: 0644]
torch/csrc/jit/mobile/parse_bytecode.h [new file with mode: 0644]
torch/csrc/jit/serialization/export.h
torch/csrc/jit/serialization/export_module.cpp
torch/csrc/jit/serialization/import_export_functions.h
torch/csrc/jit/serialization/import_source.cpp

diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 8b403a7..3c2fb83 100644
@@ -536,10 +536,11 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
        ${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/import.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/import_data.cpp
+       ${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
-       ${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
+       ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp
diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp
index b362c8a..583d899 100644
@@ -11,6 +11,7 @@
 #include <torch/csrc/jit/mobile/import.h>
 #include <torch/csrc/jit/mobile/model_compatibility.h>
 #include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
 #include <torch/csrc/jit/mobile/runtime_compatibility.h>
 #include <torch/csrc/jit/serialization/export.h>
 #include <torch/csrc/jit/serialization/import.h>
@@ -886,6 +887,75 @@ TEST(LiteInterpreterTest, DefaultArgsConv) {
   AT_ASSERT(output.equal(outputref));
 }
 
+TEST(RunTimeTest, ParseBytecode) {
+  // A simple example showing bytecode that can be used independently of the
+  // PyTorch TorchScript serialization (unpickler, etc.) and operator
+  // library. It has basic control flow (if, else) and basic data
+  // orchestration (list construction). The original PyTorch program:
+
+  //  class Module(torch.nn.Module):
+  //
+  //    def __init__(self):
+  //      super().__init__()
+  //
+  //    def forward(self, x: int, h: int, xfirst: bool):
+  //      if xfirst:
+  //        return [x, h]
+  //      else:
+  //        return [h, x]
+
+  // 1. Prepare the bytecode. In practice it could come from a custom
+  // deserializer.
+  std::vector<IValue> instructions{
+      to_tuple({"STOREN", 1, 4}), // store 4 inputs to registers 1..4
+      to_tuple({"DROPR", 1, 0}), // drop register 1 (self), unused here
+      to_tuple({"MOVE", 4, 0}), // move register 4 (xfirst) to the stack
+      to_tuple({"JF", 5, 0}), // pop; if false, jump 5 instructions ahead
+      to_tuple({"LOAD", 2, 0}), // push register 2 (x)
+      to_tuple({"LOAD", 3, 0}), // push register 3 (h)
+      to_tuple({"LIST_CONSTRUCT", 0, 2}), // build [x, h] with type 0
+      to_tuple({"JMP", 4, 0}), // jump 4 instructions ahead, to STORE
+      to_tuple({"LOAD", 3, 0}), // push register 3 (h)
+      to_tuple({"LOAD", 2, 0}), // push register 2 (x)
+      to_tuple({"LIST_CONSTRUCT", 1, 2}), // build [h, x] with type 1
+      to_tuple({"STORE", 5, 0}), // pop the result list to register 5
+      to_tuple({"DROPR", 3, 0}), // drop register 3
+      to_tuple({"DROPR", 2, 0}), // drop register 2
+      to_tuple({"MOVE", 5, 0}), // move register 5 (the result) to the stack
+      to_tuple({"RET", 0, 0}), // return the top of the stack
+  };
+  std::vector<IValue> operators; // empty for this example
+  std::vector<IValue> constants; // empty for this example
+
+  std::vector<IValue> types{"List[int]", "List[int]"};
+  // 2. Parse the function
+  std::string function_name("test_function");
+  auto function = std::unique_ptr<mobile::Function>(
+      new mobile::Function(c10::QualifiedName(function_name)));
+  std::vector<IValue> debug_handles_m_tuple;
+  parseInstructions(
+      function_name, instructions, debug_handles_m_tuple, function.get());
+  parseTypes(types, function.get());
+  const size_t rsize = 5;
+  parseRegisterSize(rsize, function.get());
+
+  // 3. Prepare the inputs and run the function.
+  // Note that the first input is reserved for the Module object.
+  // Since this is a function test and no Module object is required,
+  // a dummy IValue (0) is added here.
+  std::vector<IValue> inputs{0, 1, 2, true};
+  function->run(inputs);
+  auto output = inputs[0].toList();
+  ASSERT_EQ(output[0], 1);
+  ASSERT_EQ(output[1], 2);
+
+  std::vector<IValue> inputs1{0, 1, 2, false};
+  function->run(inputs1);
+  auto output1 = inputs1[0].toList();
+  ASSERT_EQ(output1[0], 2);
+  ASSERT_EQ(output1[1], 1);
+}
+
 namespace {
 void testLiteModuleCompareResultTensors(
     Module& m,
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index ee3ae5a..a139515 100644
@@ -434,6 +434,7 @@ torch_mobile_core = [
     "torch/csrc/jit/mobile/model_compatibility.cpp",
     "torch/csrc/jit/mobile/module.cpp",
     "torch/csrc/jit/mobile/observer.cpp",
+    "torch/csrc/jit/mobile/parse_bytecode.cpp",
     "torch/csrc/jit/runtime/register_prim_ops.cpp",
     "torch/csrc/jit/runtime/register_special_ops.cpp",
 ]
@@ -474,6 +475,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
     "torch/csrc/jit/mobile/model_compatibility.cpp",
     "torch/csrc/jit/mobile/module.cpp",
     "torch/csrc/jit/mobile/observer.cpp",
+    "torch/csrc/jit/mobile/parse_bytecode.cpp",
     "torch/csrc/jit/mobile/train/export_data.cpp",
     "torch/csrc/jit/mobile/train/optim/sgd.cpp",
     "torch/csrc/jit/mobile/train/random.cpp",
diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp
index 434bb8d..aff7d69 100644
@@ -4,7 +4,6 @@
 #include <torch/csrc/jit/mobile/interpreter.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/runtime/operator.h>
-#include <torch/custom_class_detail.h>
 
 namespace torch {
 namespace jit {
diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h
index e194569..3057429 100644
@@ -41,8 +41,8 @@ class Function {
   using OperatorCacheType =
       std::unordered_map<c10::OperatorName, OperatorFunctionWithSchema>;
 
-  Function(c10::QualifiedName name);
-  bool run(Stack& stack) const;
+  TORCH_API Function(c10::QualifiedName name);
+  TORCH_API bool run(Stack& stack) const;
   c10::IValue operator()(Stack& stack) const;
   const std::string& name() const;
   TORCH_API const c10::QualifiedName& qualname() const;
diff --git a/torch/csrc/jit/mobile/import.cpp b/torch/csrc/jit/mobile/import.cpp
index e438bb7..9a32c51 100644
@@ -1,4 +1,5 @@
 #include <torch/csrc/jit/mobile/import.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
 
 #include <ATen/core/ivalue.h>
 #include <c10/util/ScopeExit.h>
@@ -9,6 +10,7 @@
 #include <torch/csrc/jit/mobile/observer.h>
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
 #include <torch/csrc/jit/serialization/import_read.h>
 #include <torch/custom_class.h>
 
@@ -85,20 +87,6 @@ using caffe2::serialize::ReadAdapterInterface;
 
 OpCode parseOpCode(const char* str);
 
-IValue expect_field(
-    std::vector<IValue>& elements,
-    const std::string& expected_name,
-    size_t entry) {
-  auto row = std::move(elements.at(entry)).toTuple();
-  TORCH_INTERNAL_ASSERT(
-      row->elements().at(0).toStringRef() == expected_name,
-      "Expected ",
-      expected_name,
-      " found ",
-      row->elements().at(0).toStringRef());
-  return std::move(row->elements().at(1));
-}
-
 std::string operator_str(
     const std::string& name,
     const std::string& overloadname) {
@@ -230,6 +218,11 @@ class BytecodeDeserializer final {
   c10::IValue readArchive(
       const std::string& archive_name,
       std::shared_ptr<mobile::CompilationUnit> mcu);
+  void parseFunctionSchema(
+      const std::string& function_name,
+      IValue* schemaTable,
+      const int64_t& model_version,
+      mobile::Function* function);
   /**
    * Loads operators by looking them up in the Dispatcher and returns
    * the set of operator names (with overload) that are not supported
@@ -241,11 +234,6 @@ class BytecodeDeserializer final {
    * even if the key is the same. You need to call has_same_arg_num()
    * on the value to ensure that the number of arguments are the same.
    */
-  std::unordered_set<std::string> load_and_find_unsupported_operator_names(
-      const std::vector<IValue>& ops_list,
-      mobile::Function* function,
-      int64_t model_version,
-      mobile::Function::OperatorCacheType& operator_cache) const;
   std::shared_ptr<CompilationUnit> compilation_unit_;
   std::unordered_set<std::string> imported_libs_;
   std::unique_ptr<PyTorchStreamReader> reader_{};
@@ -260,12 +248,11 @@ BytecodeDeserializer::BytecodeDeserializer(
       reader_(std::move(reader)),
       module_load_options_(module_load_options) {}
 
-std::unordered_set<std::string> BytecodeDeserializer::
-    load_and_find_unsupported_operator_names(
-        const std::vector<IValue>& ops_list,
-        mobile::Function* function,
-        int64_t model_version,
-        mobile::Function::OperatorCacheType& operator_cache) const {
+std::unordered_set<std::string> load_and_find_unsupported_operator_names(
+    const std::vector<IValue>& ops_list,
+    mobile::Function* function,
+    int64_t model_version,
+    mobile::Function::OperatorCacheType& operator_cache) {
   std::unordered_set<std::string> unsupported_op_names;
   // ops_list is the list of operator names that were read in from
  // bytecode.pkl for the method that is currently being processed.
@@ -298,6 +285,81 @@ TypePtr BytecodeDeserializer::resolveTypeName(const c10::QualifiedName& qn) {
   return resolveTypeNameMobile(qn, compilation_unit_);
 }
 
+// Parsing the function schema requires compilation_unit_, so keep this
+// method in BytecodeDeserializer. It may be refactored later to be
+// independent of the specific BytecodeDeserializer, like the other
+// table-parsing functions.
+void BytecodeDeserializer::parseFunctionSchema(
+    const std::string& function_name,
+    IValue* schemaTable,
+    const int64_t& model_version,
+    mobile::Function* function) {
+  // function schema
+  if (schemaTable) { // (schema is optional for back compat)
+    auto parseArgList = [this](std::vector<IValue>&& argTables) {
+      std::vector<c10::Argument> args;
+      for (auto&& argTable : std::move(argTables)) {
+        auto argTableElements =
+            std::move(*std::move(argTable).toTuple()).elements();
+        auto name =
+            expect_field(argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
+                .toStringRef();
+        c10::TypePtr type = resolveTypeName(
+            (expect_field(
+                 argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
+                .toStringRef());
+        IValue default_value = expect_field(
+            argTableElements,
+            "default_value",
+            BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
+        args.emplace_back(
+            name,
+            std::move(type),
+            c10::nullopt /*N*/,
+            std::move(default_value));
+      }
+      return args;
+    };
+    auto schemaTableElements =
+        std::move(*std::move(*schemaTable).toTuple()).elements();
+    std::vector<IValue> arg_list =
+        std::move(*expect_field(
+                       schemaTableElements,
+                       "arguments",
+                       BYTECODE_INDEX_SCHEMA_ARGUMENTS)
+                       .toTuple())
+            .elements();
+    std::vector<IValue> ret_list =
+        std::move(
+            *expect_field(
+                 schemaTableElements, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
+                 .toTuple())
+            .elements();
+    c10::FunctionSchema schema(
+        function_name,
+        "" /*overload_name*/,
+        parseArgList(std::move(arg_list)),
+        parseArgList(std::move(ret_list)),
+        false /*is_varargs*/,
+        false /*is_varret*/);
+    function->setSchema(std::move(schema));
+  }
+}
+
+void parseOperators(
+    const std::vector<IValue>& ops_list,
+    const int64_t& model_version,
+    const uint64_t& module_load_options,
+    mobile::Function* function,
+    mobile::Function::OperatorCacheType& operator_cache) {
+  std::unordered_set<std::string> unsupported_op_names =
+      load_and_find_unsupported_operator_names(
+          ops_list, function, model_version, operator_cache);
+  if ((module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK) &&
+      !unsupported_op_names.empty()) {
+    print_unsupported_ops_and_throw(unsupported_op_names);
+  }
+}
+
 void BytecodeDeserializer::parseMethods(
     std::vector<IValue>&& vals,
     c10::optional<std::vector<IValue>>&& debug_handles,
@@ -373,129 +435,30 @@ void BytecodeDeserializer::parseMethods(
             codeTableElements, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
             .toInt();
 
-    c10::List<int64_t> debug_handles_list;
+    std::vector<IValue> debug_handles_m_tuple;
     if (debug_handles) {
-      auto debug_handles_m_tuple =
+      debug_handles_m_tuple =
           std::move(*std::move((*debug_handles)[i]).toTuple()).elements();
-      const std::string& debug_info_function_name =
-          debug_handles_m_tuple[0].toStringRef();
-      TORCH_CHECK(
-          debug_info_function_name == function_name,
-          "The function names in the bytecode table and the debug info table do not match.");
-      IValue& debug_handles_table = debug_handles_m_tuple[1];
-      debug_handles_list =
-          (expect_field(
-               std::move(debug_handles_table).toTuple()->elements(),
-               "function_debug_handles",
-               BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
-               .toTuple()
-               ->elements())[0]
-              .toIntList();
-      TORCH_CHECK(
-          debug_handles_list.size() == ins_list.size(),
-          "The numbers of instructions and debug handles strings do not match.");
     }
 
-    for (const auto j : c10::irange(ins_list.size())) {
-      std::vector<IValue> ins_item =
-          std::move(*std::move(ins_list[j]).toTuple()).elements();
-      TORCH_CHECK(
-          ins_item.size() == 3,
-          "There should be three parts in an instruction. The function name is ",
-          function_name);
-      OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
-      int X = ins_item[1].toInt();
-      int N = ins_item[2].toInt();
-      if (debug_handles) {
-        int64_t debug_handle = debug_handles_list[j];
-        function->append_instruction(op_code, X, N, debug_handle);
-      } else {
-        function->append_instruction(op_code, X, N);
-      }
-    }
+    parseInstructions(
+        function_name, ins_list, debug_handles_m_tuple, function.get());
 
-    std::unordered_set<std::string> unsupported_op_names =
-        load_and_find_unsupported_operator_names(
-            ops_list, function.get(), model_version, operator_cache);
-    if ((module_load_options_ & MobileModuleLoadOptions::OPERATOR_CHECK) &&
-        !unsupported_op_names.empty()) {
-      print_unsupported_ops_and_throw(unsupported_op_names);
-    }
+    parseOperators(
+        ops_list,
+        model_version,
+        module_load_options_,
+        function.get(),
+        operator_cache);
 
-    for (const auto& constant : consts_list) {
-      function->append_constant(constant);
-    }
+    parseConstants(consts_list, function.get());
 
-    static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
-    for (const auto& t : types_list) {
-      c10::QualifiedName qn(t.toStringRef());
-      if (classPrefix.isPrefixOf(qn)) {
-        auto classType = getCustomClass(qn.qualifiedName());
-        TORCH_CHECK(
-            classType,
-            "The implementation of class ",
-            qn.qualifiedName(),
-            " cannot be found.");
-        function->append_type(classType);
-      } else {
-        function->append_type(c10::parseType(t.toStringRef()));
-      }
-    }
+    parseTypes(types_list, function.get());
 
     function->set_register_size(register_size);
 
-    // function schema
-    if (schemaTable) { // (schema is optional for back compat)
-      auto parseArgList = [this](std::vector<IValue>&& argTables) {
-        std::vector<c10::Argument> args;
-        for (auto&& argTable : std::move(argTables)) {
-          auto argTableElements =
-              std::move(*std::move(argTable).toTuple()).elements();
-          auto name =
-              expect_field(
-                  argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
-                  .toStringRef();
-          c10::TypePtr type = resolveTypeName(
-              (expect_field(
-                   argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
-                  .toStringRef());
-          IValue default_value = expect_field(
-              argTableElements,
-              "default_value",
-              BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
-          args.emplace_back(
-              name,
-              std::move(type),
-              c10::nullopt /*N*/,
-              std::move(default_value));
-        }
-        return args;
-      };
-      auto schemaTableElements =
-          std::move(*std::move(*schemaTable).toTuple()).elements();
-      std::vector<IValue> arg_list =
-          std::move(*expect_field(
-                         schemaTableElements,
-                         "arguments",
-                         BYTECODE_INDEX_SCHEMA_ARGUMENTS)
-                         .toTuple())
-              .elements();
-      std::vector<IValue> ret_list =
-          std::move(*expect_field(
-                         schemaTableElements,
-                         "returns",
-                         BYTECODE_INDEX_SCHEMA_RETURNS)
-                         .toTuple())
-              .elements();
-      c10::FunctionSchema schema(
-          function_name,
-          "" /*overload_name*/,
-          parseArgList(std::move(arg_list)),
-          parseArgList(std::move(ret_list)),
-          false /*is_varargs*/,
-          false /*is_varret*/);
-      function->setSchema(std::move(schema));
-    }
+    parseFunctionSchema(
+        function_name, schemaTable, model_version, function.get());
 
     mcu.register_function(std::move(function));
   }
diff --git a/torch/csrc/jit/mobile/parse_bytecode.cpp b/torch/csrc/jit/mobile/parse_bytecode.cpp
new file mode 100644
index 0000000..f43f657
--- /dev/null
+++ b/torch/csrc/jit/mobile/parse_bytecode.cpp
@@ -0,0 +1,110 @@
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
+#include <torch/csrc/jit/mobile/type_parser.h>
+#include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
+#include <torch/custom_class_detail.h>
+
+namespace torch {
+namespace jit {
+OpCode parseOpCode(const char* str);
+using c10::IValue;
+
+IValue expect_field(
+    std::vector<IValue>& elements,
+    const std::string& expected_name,
+    size_t entry) {
+  auto row = std::move(elements.at(entry)).toTuple();
+  TORCH_INTERNAL_ASSERT(
+      row->elements().at(0).toStringRef() == expected_name,
+      "Expected ",
+      expected_name,
+      " found ",
+      row->elements().at(0).toStringRef());
+  return std::move(row->elements().at(1));
+}
+
+namespace mobile {
+
+namespace {} // namespace
+
+void parseInstructions(
+    const std::string& function_name,
+    const std::vector<IValue>& ins_list,
+    std::vector<IValue>& debug_handles_m_tuple,
+    mobile::Function* function) {
+  c10::List<int64_t> debug_handles_list;
+  if (!debug_handles_m_tuple.empty()) {
+    const std::string& debug_info_function_name =
+        debug_handles_m_tuple[0].toStringRef();
+    TORCH_CHECK(
+        debug_info_function_name == function_name,
+        "The function names in the bytecode table and the debug info table do not match.");
+    IValue& debug_handles_table = debug_handles_m_tuple[1];
+    debug_handles_list =
+        (expect_field(
+             std::move(debug_handles_table).toTuple()->elements(),
+             "function_debug_handles",
+             BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
+             .toTuple()
+             ->elements())[0]
+            .toIntList();
+    TORCH_CHECK(
+        debug_handles_list.size() == ins_list.size(),
+        "The numbers of instructions and debug handles strings do not match.");
+  }
+
+  for (const auto j : c10::irange(ins_list.size())) {
+    std::vector<IValue> ins_item =
+        std::move(*std::move(ins_list[j]).toTuple()).elements();
+    TORCH_CHECK(
+        ins_item.size() == 3,
+        "There should be three parts in an instruction. The function name is ",
+        function_name);
+    OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
+    int X = ins_item[1].toInt();
+    int N = ins_item[2].toInt();
+    if (!debug_handles_list.empty()) {
+      int64_t debug_handle = debug_handles_list[j];
+      function->append_instruction(op_code, X, N, debug_handle);
+    } else {
+      function->append_instruction(op_code, X, N);
+    }
+  }
+}
+
+void parseConstants(
+    const std::vector<IValue>& consts_list,
+    mobile::Function* function) {
+  for (const auto& constant : consts_list) {
+    function->append_constant(constant);
+  }
+}
+
+void parseTypes(
+    const std::vector<IValue>& types_list,
+    mobile::Function* function) {
+  static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
+  for (const auto& t : types_list) {
+    c10::QualifiedName qn(t.toStringRef());
+    if (classPrefix.isPrefixOf(qn)) {
+      auto classType = getCustomClass(qn.qualifiedName());
+      TORCH_CHECK(
+          classType,
+          "The implementation of class ",
+          qn.qualifiedName(),
+          " cannot be found.");
+      function->append_type(classType);
+    } else {
+      function->append_type(c10::parseType(t.toStringRef()));
+    }
+  }
+}
+
+void parseRegisterSize(size_t rsize, mobile::Function* function) {
+  function->set_register_size(rsize);
+}
+
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/mobile/parse_bytecode.h b/torch/csrc/jit/mobile/parse_bytecode.h
new file mode 100644
index 0000000..b87542a
--- /dev/null
+++ b/torch/csrc/jit/mobile/parse_bytecode.h
@@ -0,0 +1,22 @@
+#pragma once
+#include <torch/csrc/jit/mobile/function.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+using c10::IValue;
+TORCH_API void parseInstructions(
+    const std::string& function_name,
+    const std::vector<IValue>& ins_list,
+    std::vector<IValue>& debug_handles_m_tuple,
+    mobile::Function* function);
+TORCH_API void parseConstants(
+    const std::vector<IValue>& consts_list,
+    mobile::Function* function);
+TORCH_API void parseTypes(
+    const std::vector<IValue>& types_list,
+    mobile::Function* function);
+TORCH_API void parseRegisterSize(size_t rsize, mobile::Function* function);
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h
index dbe0ea2..5545f14 100644
@@ -208,5 +208,8 @@ struct TORCH_API BytecodeEmitDefaultInputsGuard {
   bool prev_mode;
 };
 
+TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
+TORCH_API IValue
+Table(const std::vector<std::pair<std::string, IValue>>& entries);
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index a3cae77..8af15a3 100644
@@ -15,6 +15,7 @@
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
 #include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
 #include <torch/csrc/jit/serialization/import_export_helpers.h>
 #include <torch/csrc/jit/serialization/pickle.h>
 #include <torch/csrc/jit/serialization/python_print.h>
@@ -36,25 +37,24 @@ namespace jit {
 
 char const* toString(OpCode op);
 
-namespace {
-
-ExportModuleExtraFilesHook& GetExtraFilesHook() {
-  static ExportModuleExtraFilesHook func = nullptr;
-  return func;
-}
-
-static IValue Tup(std::vector<IValue> ivalues) {
+IValue to_tuple(std::vector<IValue> ivalues) {
   return c10::ivalue::Tuple::create(std::move(ivalues));
 }
 
-static IValue Table(
-    const std::vector<std::pair<std::string, IValue>>& entries) {
+IValue Table(const std::vector<std::pair<std::string, IValue>>& entries) {
   std::vector<IValue> ivalue_entries;
   ivalue_entries.reserve(entries.size());
   for (const auto& e : entries) {
-    ivalue_entries.push_back(Tup({e.first, e.second}));
+    ivalue_entries.push_back(to_tuple({e.first, e.second}));
   }
-  return Tup(std::move(ivalue_entries));
+  return to_tuple(std::move(ivalue_entries));
+}
+
+namespace {
+
+ExportModuleExtraFilesHook& GetExtraFilesHook() {
+  static ExportModuleExtraFilesHook func = nullptr;
+  return func;
 }
 
 std::pair<IValue, IValue> getFunctionTuple(
@@ -150,7 +150,7 @@ std::pair<IValue, IValue> getFunctionTuple(
   std::vector<IValue> instructions;
   instructions.reserve(instructions_copy.size());
   for (Instruction ins : instructions_copy) {
-    instructions.emplace_back(Tup({toString(ins.op), ins.X, ins.N}));
+    instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
   }
 
   // operators
@@ -169,10 +169,10 @@ std::pair<IValue, IValue> getFunctionTuple(
       num_args = it->second;
     }
     if (BytecodeEmitDefaultValueForUnspecifiedArgMode::is_enabled()) {
-      operators.emplace_back(Tup({opname.name, opname.overload_name}));
+      operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
     } else {
       operators.emplace_back(
-          Tup({opname.name, opname.overload_name, num_args}));
+          to_tuple({opname.name, opname.overload_name, num_args}));
     }
   }
 
@@ -208,10 +208,10 @@ std::pair<IValue, IValue> getFunctionTuple(
   auto register_size = static_cast<int>(code->register_size());
 
   auto codeTable = Table(
-      {{"instructions", Tup(instructions)},
-       {"operators", Tup(operators)},
-       {"constants", Tup(constants)},
-       {"types", Tup(types)},
+      {{"instructions", to_tuple(instructions)},
+       {"operators", to_tuple(operators)},
+       {"constants", to_tuple(constants)},
+       {"types", to_tuple(types)},
        {"register_size", register_size}});
 
   // schema
@@ -258,7 +258,7 @@ std::pair<IValue, IValue> getFunctionTuple(
           {"default_value", arg.default_value()},
       }));
     }
-    return Tup(argTables);
+    return to_tuple(argTables);
   };
   auto schemaTable = Table({
       {"arguments", makeArgTuple(schema.arguments())},
@@ -266,7 +266,7 @@ std::pair<IValue, IValue> getFunctionTuple(
   });
 
   // function tuple
-  auto bytecode_vals = Tup({qn, codeTable, schemaTable});
+  auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
 
   c10::optional<IValue> debug_info_vals;
   // module debug info
@@ -278,7 +278,7 @@ std::pair<IValue, IValue> getFunctionTuple(
   IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles);
   auto function_debug_info =
       Table({{"function_debug_handles", module_debug_tuple}});
-  debug_info_vals = Tup({qn, function_debug_info});
+  debug_info_vals = to_tuple({qn, function_debug_info});
   return std::make_pair(bytecode_vals, debug_info_vals);
 }
 
@@ -648,7 +648,7 @@ void ScriptModuleSerializer::writeByteCode(
       debug_info_elements,
       debug_info_recorder,
       type_name_uniquer_);
-  auto telements = Tup(std::move(elements));
+  auto telements = to_tuple(std::move(elements));
   writeArchive(
       telements,
       /*archive_name=*/"bytecode",
@@ -656,7 +656,7 @@ void ScriptModuleSerializer::writeByteCode(
       /*tensor_dir=*/"constants/",
       /*use_storage_context=*/true);
 
-  auto debug_info_telements = Tup(std::move(debug_info_elements));
+  auto debug_info_telements = to_tuple(std::move(debug_info_elements));
 
   // At the moment keeping this feature experimental
   // since we have not evaluated how this affect model size
diff --git a/torch/csrc/jit/serialization/import_export_functions.h b/torch/csrc/jit/serialization/import_export_functions.h
index 4cb48d6..c512c34 100644
@@ -1,12 +1,14 @@
 #pragma once
+#include <ATen/core/ivalue.h>
 
 // Functions that are used in both import and export processes
 namespace torch {
 namespace jit {
-void moduleMethodsTuple(
-    const Module& module,
-    std::vector<c10::IValue>& elements);
-IValue expect_field(IValue tup, const std::string& expected_name, size_t entry);
+using c10::IValue;
+IValue expect_field(
+    std::vector<IValue>& elements,
+    const std::string& expected_name,
+    size_t entry);
 std::string operator_str(
     const std::string& name,
     const std::string& overloadname);
diff --git a/torch/csrc/jit/serialization/import_source.cpp b/torch/csrc/jit/serialization/import_source.cpp
index 918b0d4..05a5389 100644
@@ -5,7 +5,6 @@
 #include <torch/csrc/jit/frontend/parser.h>
 #include <torch/csrc/jit/frontend/resolver.h>
 #include <torch/csrc/jit/frontend/script_type_parser.h>
-#include <torch/csrc/jit/serialization/export.h>
 #include <torch/custom_class.h>
 
 #include <regex>