${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/import.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/import_data.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
- ${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/model_compatibility.h>
#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/runtime_compatibility.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
AT_ASSERT(output.equal(outputref));
}
+TEST(RunTimeTest, ParseBytecode) {
+ // A minimal example showing bytecode that can be executed independently of
+ // the TorchScript serialization stack (unpickler, etc.) and of the operator
+ // library. It exercises basic control flow (if/else) and basic data
+ // orchestration (list construction). The original PyTorch program:
+
+ // class Module(torch.nn.Module):
+ //
+ // def __init__(self):
+ // super().__init__()
+ //
+ // def forward(self, x: int, h: int, xfirst: bool):
+ // if xfirst:
+ // return [x, h]
+ // else:
+ // return [h, x]
+
+ // 1. Prepare the bytecode. In practice it could come from a customized
+ // deserializer.
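+ // Each instruction is an (opcode, X, N) triple. Opcode semantics,
+ // summarized from torch/csrc/jit/runtime/instruction.h:
+ //   STOREN X, N - pop N values from the stack into registers [X, X+N)
+ //   DROPR X     - clear register X
+ //   LOAD X      - push a copy of register X onto the stack
+ //   MOVE X      - move register X onto the stack, clearing the register
+ //   STORE X     - pop the top of the stack into register X
+ //   JF X        - pop a bool; jump relative X if false, else fall through
+ //   JMP X       - unconditional relative jump of X
+ //   LIST_CONSTRUCT X, N - pop N values and push a list of type types[X]
+ //   RET         - leave the return value(s) on the stack and exit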
+ std::vector<IValue> instructions{
+ to_tuple({"STOREN", 1, 4}),
+ to_tuple({"DROPR", 1, 0}),
+ to_tuple({"MOVE", 4, 0}),
+ to_tuple({"JF", 5, 0}),
+ to_tuple({"LOAD", 2, 0}),
+ to_tuple({"LOAD", 3, 0}),
+ to_tuple({"LIST_CONSTRUCT", 0, 2}),
+ to_tuple({"JMP", 4, 0}),
+ to_tuple({"LOAD", 3, 0}),
+ to_tuple({"LOAD", 2, 0}),
+ to_tuple({"LIST_CONSTRUCT", 1, 2}),
+ to_tuple({"STORE", 5, 0}),
+ to_tuple({"DROPR", 3, 0}),
+ to_tuple({"DROPR", 2, 0}),
+ to_tuple({"MOVE", 5, 0}),
+ to_tuple({"RET", 0, 0}),
+ };
+ std::vector<IValue> operators; // empty for this example
+ std::vector<IValue> constants; // empty for this example
+
+ std::vector<IValue> types{"List[int]", "List[int]"};
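+ // The X operand of each LIST_CONSTRUCT above is an index into this type
+ // table (entries 0 and 1, both List[int]).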
+ // 2. Parse the function
+ std::string function_name("test_function");
+ auto function = std::unique_ptr<mobile::Function>(
+ new mobile::Function(c10::QualifiedName(function_name)));
+ std::vector<IValue> debug_handles_m_tuple;
+ parseInstructions(
+ function_name, instructions, debug_handles_m_tuple, function.get());
+ parseTypes(types, function.get());
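+ // Five registers: registers 1-4 hold the inputs (STOREN 1, 4) and
+ // register 5 holds the result (STORE 5).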
+ const size_t rsize = 5;
+ parseRegisterSize(rsize, function.get());
+
+ // 3. Prepare the inputs and run the function.
+ // The first input slot is reserved for the Module object. Since this test
+ // runs a bare function and no Module is required, a dummy IValue (0) is
+ // passed instead.
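+ // run() uses `inputs` as the interpreter stack; after RET the stack holds
+ // only the return value, which is read back from inputs[0].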
+ std::vector<IValue> inputs{0, 1, 2, true};
+ function->run(inputs);
+ auto output = inputs[0].toList();
+ ASSERT_EQ(output[0], 1);
+ ASSERT_EQ(output[1], 2);
+
+ std::vector<IValue> inputs1{0, 1, 2, false};
+ function->run(inputs1);
+ auto output1 = inputs1[0].toList();
+ ASSERT_EQ(output1[0], 2);
+ ASSERT_EQ(output1[1], 1);
+}
+
namespace {
void testLiteModuleCompareResultTensors(
Module& m,
"torch/csrc/jit/mobile/model_compatibility.cpp",
"torch/csrc/jit/mobile/module.cpp",
"torch/csrc/jit/mobile/observer.cpp",
+ "torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/runtime/register_prim_ops.cpp",
"torch/csrc/jit/runtime/register_special_ops.cpp",
]
"torch/csrc/jit/mobile/model_compatibility.cpp",
"torch/csrc/jit/mobile/module.cpp",
"torch/csrc/jit/mobile/observer.cpp",
+ "torch/csrc/jit/mobile/parse_bytecode.cpp",
"torch/csrc/jit/mobile/train/export_data.cpp",
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
"torch/csrc/jit/mobile/train/random.cpp",
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>
-#include <torch/custom_class_detail.h>
namespace torch {
namespace jit {
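+// Caches operator resolution for a whole model, so an operator shared by
+// several methods is looked up in the Dispatcher only once.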
using OperatorCacheType =
std::unordered_map<c10::OperatorName, OperatorFunctionWithSchema>;
- Function(c10::QualifiedName name);
- bool run(Stack& stack) const;
+ TORCH_API Function(c10::QualifiedName name);
+ TORCH_API bool run(Stack& stack) const;
c10::IValue operator()(Stack& stack) const;
const std::string& name() const;
TORCH_API const c10::QualifiedName& qualname() const;
#include <torch/csrc/jit/mobile/import.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <ATen/core/ivalue.h>
#include <c10/util/ScopeExit.h>
#include <torch/csrc/jit/mobile/observer.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_read.h>
#include <torch/custom_class.h>
OpCode parseOpCode(const char* str);
-IValue expect_field(
- std::vector<IValue>& elements,
- const std::string& expected_name,
- size_t entry) {
- auto row = std::move(elements.at(entry)).toTuple();
- TORCH_INTERNAL_ASSERT(
- row->elements().at(0).toStringRef() == expected_name,
- "Expected ",
- expected_name,
- " found ",
- row->elements().at(0).toStringRef());
- return std::move(row->elements().at(1));
-}
-
std::string operator_str(
const std::string& name,
const std::string& overloadname) {
c10::IValue readArchive(
const std::string& archive_name,
std::shared_ptr<mobile::CompilationUnit> mcu);
+ void parseFunctionSchema(
+ const std::string& function_name,
+ IValue* schemaTable,
+ const int64_t& model_version,
+ mobile::Function* function);
/**
 * Loads operators by looking them up in the Dispatcher and returns
 * the set of operator names (with overload) that are not supported
 * by the current runtime.
 *
 * Accepts an operator_cache, which allows caching operator lookups
 * for the entire model. It is keyed on c10::OperatorName. The value
 * may not be what you are looking for even if the key is the same.
 * You need to call has_same_arg_num() on the value to ensure that
 * the number of arguments is the same.
 */
- std::unordered_set<std::string> load_and_find_unsupported_operator_names(
- const std::vector<IValue>& ops_list,
- mobile::Function* function,
- int64_t model_version,
- mobile::Function::OperatorCacheType& operator_cache) const;
std::shared_ptr<CompilationUnit> compilation_unit_;
std::unordered_set<std::string> imported_libs_;
std::unique_ptr<PyTorchStreamReader> reader_{};
reader_(std::move(reader)),
module_load_options_(module_load_options) {}
-std::unordered_set<std::string> BytecodeDeserializer::
- load_and_find_unsupported_operator_names(
- const std::vector<IValue>& ops_list,
- mobile::Function* function,
- int64_t model_version,
- mobile::Function::OperatorCacheType& operator_cache) const {
+std::unordered_set<std::string> load_and_find_unsupported_operator_names(
+ const std::vector<IValue>& ops_list,
+ mobile::Function* function,
+ int64_t model_version,
+ mobile::Function::OperatorCacheType& operator_cache) {
std::unordered_set<std::string> unsupported_op_names;
// ops_list is the list of operator names that were read in from
// bytecode.pkl for the method that is currently being processed.
return resolveTypeNameMobile(qn, compilation_unit_);
}
+// Parsing the function schema requires compilation_unit_, so this stays a
+// member of BytecodeDeserializer. It may be refactored later to become
+// independent of BytecodeDeserializer, like the other table parsers.
+void BytecodeDeserializer::parseFunctionSchema(
+ const std::string& function_name,
+ IValue* schemaTable,
+ const int64_t& model_version,
+ mobile::Function* function) {
+ // function schema
+ if (schemaTable) { // (schema is optional for back compat)
+ auto parseArgList = [this](std::vector<IValue>&& argTables) {
+ std::vector<c10::Argument> args;
+ for (auto&& argTable : std::move(argTables)) {
+ auto argTableElements =
+ std::move(*std::move(argTable).toTuple()).elements();
+ auto name =
+ expect_field(argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
+ .toStringRef();
+ c10::TypePtr type = resolveTypeName(
+ (expect_field(
+ argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
+ .toStringRef());
+ IValue default_value = expect_field(
+ argTableElements,
+ "default_value",
+ BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
+ args.emplace_back(
+ name,
+ std::move(type),
+ c10::nullopt /*N*/,
+ std::move(default_value));
+ }
+ return args;
+ };
+ auto schemaTableElements =
+ std::move(*std::move(*schemaTable).toTuple()).elements();
+ std::vector<IValue> arg_list =
+ std::move(*expect_field(
+ schemaTableElements,
+ "arguments",
+ BYTECODE_INDEX_SCHEMA_ARGUMENTS)
+ .toTuple())
+ .elements();
+ std::vector<IValue> ret_list =
+ std::move(
+ *expect_field(
+ schemaTableElements, "returns", BYTECODE_INDEX_SCHEMA_RETURNS)
+ .toTuple())
+ .elements();
+ c10::FunctionSchema schema(
+ function_name,
+ "" /*overload_name*/,
+ parseArgList(std::move(arg_list)),
+ parseArgList(std::move(ret_list)),
+ false /*is_varargs*/,
+ false /*is_varret*/);
+ function->setSchema(std::move(schema));
+ }
+}
+
+void parseOperators(
+ const std::vector<IValue>& ops_list,
+ const int64_t& model_version,
+ const uint64_t& module_load_options,
+ mobile::Function* function,
+ mobile::Function::OperatorCacheType& operator_cache) {
+ std::unordered_set<std::string> unsupported_op_names =
+ load_and_find_unsupported_operator_names(
+ ops_list, function, model_version, operator_cache);
+ if ((module_load_options & MobileModuleLoadOptions::OPERATOR_CHECK) &&
+ !unsupported_op_names.empty()) {
+ print_unsupported_ops_and_throw(unsupported_op_names);
+ }
+}
+
void BytecodeDeserializer::parseMethods(
std::vector<IValue>&& vals,
c10::optional<std::vector<IValue>>&& debug_handles,
codeTableElements, "register_size", BYTECODE_INDEX_REGISTER_SIZE)
.toInt();
- c10::List<int64_t> debug_handles_list;
+ std::vector<IValue> debug_handles_m_tuple;
if (debug_handles) {
- auto debug_handles_m_tuple =
+ debug_handles_m_tuple =
std::move(*std::move((*debug_handles)[i]).toTuple()).elements();
- const std::string& debug_info_function_name =
- debug_handles_m_tuple[0].toStringRef();
- TORCH_CHECK(
- debug_info_function_name == function_name,
- "The function names in the bytecode table and the debug info table do not match.");
- IValue& debug_handles_table = debug_handles_m_tuple[1];
- debug_handles_list =
- (expect_field(
- std::move(debug_handles_table).toTuple()->elements(),
- "function_debug_handles",
- BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
- .toTuple()
- ->elements())[0]
- .toIntList();
- TORCH_CHECK(
- debug_handles_list.size() == ins_list.size(),
- "The numbers of instructions and debug handles strings do not match.");
}
- for (const auto j : c10::irange(ins_list.size())) {
- std::vector<IValue> ins_item =
- std::move(*std::move(ins_list[j]).toTuple()).elements();
- TORCH_CHECK(
- ins_item.size() == 3,
- "There should be three parts in an instruction. The function name is ",
- function_name);
- OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
- int X = ins_item[1].toInt();
- int N = ins_item[2].toInt();
- if (debug_handles) {
- int64_t debug_handle = debug_handles_list[j];
- function->append_instruction(op_code, X, N, debug_handle);
- } else {
- function->append_instruction(op_code, X, N);
- }
- }
+ parseInstructions(
+ function_name, ins_list, debug_handles_m_tuple, function.get());
- std::unordered_set<std::string> unsupported_op_names =
- load_and_find_unsupported_operator_names(
- ops_list, function.get(), model_version, operator_cache);
- if ((module_load_options_ & MobileModuleLoadOptions::OPERATOR_CHECK) &&
- !unsupported_op_names.empty()) {
- print_unsupported_ops_and_throw(unsupported_op_names);
- }
+ parseOperators(
+ ops_list,
+ model_version,
+ module_load_options_,
+ function.get(),
+ operator_cache);
- for (const auto& constant : consts_list) {
- function->append_constant(constant);
- }
+ parseConstants(consts_list, function.get());
- static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
- for (const auto& t : types_list) {
- c10::QualifiedName qn(t.toStringRef());
- if (classPrefix.isPrefixOf(qn)) {
- auto classType = getCustomClass(qn.qualifiedName());
- TORCH_CHECK(
- classType,
- "The implementation of class ",
- qn.qualifiedName(),
- " cannot be found.");
- function->append_type(classType);
- } else {
- function->append_type(c10::parseType(t.toStringRef()));
- }
- }
+ parseTypes(types_list, function.get());
function->set_register_size(register_size);
- // function schema
- if (schemaTable) { // (schema is optional for back compat)
- auto parseArgList = [this](std::vector<IValue>&& argTables) {
- std::vector<c10::Argument> args;
- for (auto&& argTable : std::move(argTables)) {
- auto argTableElements =
- std::move(*std::move(argTable).toTuple()).elements();
- auto name =
- expect_field(
- argTableElements, "name", BYTECODE_INDEX_ARGUMENT_NAME)
- .toStringRef();
- c10::TypePtr type = resolveTypeName(
- (expect_field(
- argTableElements, "type", BYTECODE_INDEX_ARGUMENT_TYPE))
- .toStringRef());
- IValue default_value = expect_field(
- argTableElements,
- "default_value",
- BYTECODE_INDEX_ARGUMENT_DEFAULT_VALUE);
- args.emplace_back(
- name,
- std::move(type),
- c10::nullopt /*N*/,
- std::move(default_value));
- }
- return args;
- };
- auto schemaTableElements =
- std::move(*std::move(*schemaTable).toTuple()).elements();
- std::vector<IValue> arg_list =
- std::move(*expect_field(
- schemaTableElements,
- "arguments",
- BYTECODE_INDEX_SCHEMA_ARGUMENTS)
- .toTuple())
- .elements();
- std::vector<IValue> ret_list =
- std::move(*expect_field(
- schemaTableElements,
- "returns",
- BYTECODE_INDEX_SCHEMA_RETURNS)
- .toTuple())
- .elements();
- c10::FunctionSchema schema(
- function_name,
- "" /*overload_name*/,
- parseArgList(std::move(arg_list)),
- parseArgList(std::move(ret_list)),
- false /*is_varargs*/,
- false /*is_varret*/);
- function->setSchema(std::move(schema));
- }
+ parseFunctionSchema(
+ function_name, schemaTable, model_version, function.get());
mcu.register_function(std::move(function));
}
--- /dev/null
+#include <ATen/core/ivalue.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
+#include <torch/csrc/jit/mobile/type_parser.h>
+#include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
+#include <torch/custom_class_detail.h>
+
+namespace torch {
+namespace jit {
+OpCode parseOpCode(const char* str);
+using c10::IValue;
+
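+// Moved here from import.cpp: reads elements[entry], which must be a
+// (name, value) tuple, checks that the name matches expected_name, and
+// returns the value.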
+IValue expect_field(
+ std::vector<IValue>& elements,
+ const std::string& expected_name,
+ size_t entry) {
+ auto row = std::move(elements.at(entry)).toTuple();
+ TORCH_INTERNAL_ASSERT(
+ row->elements().at(0).toStringRef() == expected_name,
+ "Expected ",
+ expected_name,
+ " found ",
+ row->elements().at(0).toStringRef());
+ return std::move(row->elements().at(1));
+}
+
+namespace mobile {
+
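+// debug_handles_m_tuple is the (function_name, debug_handles_table) pair
+// from the debug-info archive; it may be empty when the model carries no
+// debug handles (as in the ParseBytecode unit test).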
+void parseInstructions(
+ const std::string& function_name,
+ const std::vector<IValue>& ins_list,
+ std::vector<IValue>& debug_handles_m_tuple,
+ mobile::Function* function) {
+ c10::List<int64_t> debug_handles_list;
+ if (!debug_handles_m_tuple.empty()) {
+ const std::string& debug_info_function_name =
+ debug_handles_m_tuple[0].toStringRef();
+ TORCH_CHECK(
+ debug_info_function_name == function_name,
+ "The function names in the bytecode table and the debug info table do not match.");
+ IValue& debug_handles_table = debug_handles_m_tuple[1];
+ debug_handles_list =
+ (expect_field(
+ std::move(debug_handles_table).toTuple()->elements(),
+ "function_debug_handles",
+ BYTECODE_INDEX_MODULE_DEBUG_HANDLES)
+ .toTuple()
+ ->elements())[0]
+ .toIntList();
+ TORCH_CHECK(
+ debug_handles_list.size() == ins_list.size(),
+ "The numbers of instructions and debug handles strings do not match.");
+ }
+
+ for (const auto j : c10::irange(ins_list.size())) {
+ std::vector<IValue> ins_item =
+ std::move(*std::move(ins_list[j]).toTuple()).elements();
+ TORCH_CHECK(
+ ins_item.size() == 3,
+ "There should be three parts in an instruction. The function name is ",
+ function_name);
+ OpCode op_code = parseOpCode(ins_item[0].toString()->string().c_str());
+ int X = ins_item[1].toInt();
+ int N = ins_item[2].toInt();
+ if (!debug_handles_list.empty()) {
+ int64_t debug_handle = debug_handles_list[j];
+ function->append_instruction(op_code, X, N, debug_handle);
+ } else {
+ function->append_instruction(op_code, X, N);
+ }
+ }
+}
+
+void parseConstants(
+ const std::vector<IValue>& consts_list,
+ mobile::Function* function) {
+ for (const auto& constant : consts_list) {
+ function->append_constant(constant);
+ }
+}
+
+void parseTypes(
+ const std::vector<IValue>& types_list,
+ mobile::Function* function) {
+ static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
+ for (const auto& t : types_list) {
+ c10::QualifiedName qn(t.toStringRef());
+ if (classPrefix.isPrefixOf(qn)) {
+ auto classType = getCustomClass(qn.qualifiedName());
+ TORCH_CHECK(
+ classType,
+ "The implementation of class ",
+ qn.qualifiedName(),
+ " cannot be found.");
+ function->append_type(classType);
+ } else {
+ function->append_type(c10::parseType(t.toStringRef()));
+ }
+ }
+}
+
+void parseRegisterSize(size_t rsize, mobile::Function* function) {
+ function->set_register_size(rsize);
+}
+
+} // namespace mobile
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+#include <torch/csrc/jit/mobile/function.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+using c10::IValue;
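+// Parsers for the individual tables of a mobile bytecode method
+// (instructions, constants, types, register size). They depend only on
+// mobile::Function, so a customized deserializer can drive them without
+// going through the full TorchScript import path.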
+TORCH_API void parseInstructions(
+ const std::string& function_name,
+ const std::vector<IValue>& ins_list,
+ std::vector<IValue>& debug_handles_m_tuple,
+ mobile::Function* function);
+TORCH_API void parseConstants(
+ const std::vector<IValue>& consts_list,
+ mobile::Function* function);
+TORCH_API void parseTypes(
+ const std::vector<IValue>& types_list,
+ mobile::Function* function);
+TORCH_API void parseRegisterSize(size_t rsize, mobile::Function* function);
+} // namespace mobile
+} // namespace jit
+} // namespace torch
bool prev_mode;
};
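+// Serialization helpers shared between export and import: to_tuple wraps a
+// vector of IValues into a Tuple; Table builds a tuple of (name, value)
+// rows in the layout that expect_field() reads back.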
+TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
+TORCH_API IValue
+Table(const std::vector<std::pair<std::string, IValue>>& entries);
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/jit/serialization/python_print.h>
char const* toString(OpCode op);
-namespace {
-
-ExportModuleExtraFilesHook& GetExtraFilesHook() {
- static ExportModuleExtraFilesHook func = nullptr;
- return func;
-}
-
-static IValue Tup(std::vector<IValue> ivalues) {
+IValue to_tuple(std::vector<IValue> ivalues) {
return c10::ivalue::Tuple::create(std::move(ivalues));
}
-static IValue Table(
- const std::vector<std::pair<std::string, IValue>>& entries) {
+IValue Table(const std::vector<std::pair<std::string, IValue>>& entries) {
std::vector<IValue> ivalue_entries;
ivalue_entries.reserve(entries.size());
for (const auto& e : entries) {
- ivalue_entries.push_back(Tup({e.first, e.second}));
+ ivalue_entries.push_back(to_tuple({e.first, e.second}));
}
- return Tup(std::move(ivalue_entries));
+ return to_tuple(std::move(ivalue_entries));
+}
+
+namespace {
+
+ExportModuleExtraFilesHook& GetExtraFilesHook() {
+ static ExportModuleExtraFilesHook func = nullptr;
+ return func;
}
std::pair<IValue, IValue> getFunctionTuple(
std::vector<IValue> instructions;
instructions.reserve(instructions_copy.size());
for (Instruction ins : instructions_copy) {
- instructions.emplace_back(Tup({toString(ins.op), ins.X, ins.N}));
+ instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
}
// operators
num_args = it->second;
}
if (BytecodeEmitDefaultValueForUnspecifiedArgMode::is_enabled()) {
- operators.emplace_back(Tup({opname.name, opname.overload_name}));
+ operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
} else {
operators.emplace_back(
- Tup({opname.name, opname.overload_name, num_args}));
+ to_tuple({opname.name, opname.overload_name, num_args}));
}
}
auto register_size = static_cast<int>(code->register_size());
auto codeTable = Table(
- {{"instructions", Tup(instructions)},
- {"operators", Tup(operators)},
- {"constants", Tup(constants)},
- {"types", Tup(types)},
+ {{"instructions", to_tuple(instructions)},
+ {"operators", to_tuple(operators)},
+ {"constants", to_tuple(constants)},
+ {"types", to_tuple(types)},
{"register_size", register_size}});
// schema
{"default_value", arg.default_value()},
}));
}
- return Tup(argTables);
+ return to_tuple(argTables);
};
auto schemaTable = Table({
{"arguments", makeArgTuple(schema.arguments())},
});
// function tuple
- auto bytecode_vals = Tup({qn, codeTable, schemaTable});
+ auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
c10::optional<IValue> debug_info_vals;
// module debug info
IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles);
auto function_debug_info =
Table({{"function_debug_handles", module_debug_tuple}});
- debug_info_vals = Tup({qn, function_debug_info});
+ debug_info_vals = to_tuple({qn, function_debug_info});
return std::make_pair(bytecode_vals, debug_info_vals);
}
debug_info_elements,
debug_info_recorder,
type_name_uniquer_);
- auto telements = Tup(std::move(elements));
+ auto telements = to_tuple(std::move(elements));
writeArchive(
telements,
/*archive_name=*/"bytecode",
/*tensor_dir=*/"constants/",
/*use_storage_context=*/true);
- auto debug_info_telements = Tup(std::move(debug_info_elements));
+ auto debug_info_telements = to_tuple(std::move(debug_info_elements));
// At the moment keeping this feature experimental
// since we have not evaluated how this affect model size
#pragma once
+#include <ATen/core/ivalue.h>
// Functions that are used in both import and export processes
namespace torch {
namespace jit {
-void moduleMethodsTuple(
- const Module& module,
- std::vector<c10::IValue>& elements);
-IValue expect_field(IValue tup, const std::string& expected_name, size_t entry);
+using c10::IValue;
+IValue expect_field(
+ std::vector<IValue>& elements,
+ const std::string& expected_name,
+ size_t entry);
std::string operator_str(
const std::string& name,
const std::string& overloadname);
#include <torch/csrc/jit/frontend/parser.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/frontend/script_type_parser.h>
-#include <torch/csrc/jit/serialization/export.h>
#include <torch/custom_class.h>
#include <regex>