From 880098a7e34a20628f960daa8eab0eb1ad566c39 Mon Sep 17 00:00:00 2001
From: Chen Lai
Date: Mon, 20 Sep 2021 22:22:17 -0700
Subject: [PATCH] [PyTorch Edge] Backport function for default args with out
 args, flag on (#63651)

Summary:
1. Enable support for operators with default args and out args. For `torch.add(x, h, out=x)`, the number of specified arguments will be 3 instead of 4.
2. Bump the bytecode version from 6 to 7.
3. Implement the backport_v7_to_v6 function. Also slightly refactor the thread_local flags to allow re-emitting operators.
4. Add a unit test to cover the backport function.
5. Update the expected result from 4 to 3 in the unit test DefaultArgsWithOutArg to cover the number of specified arguments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63651

ghstack-source-id: 138539912

Test Plan:
```
caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsWithOutArg
caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsPinvWithOutArg
caffe2/test/cpp/jit:jit - LiteInterpreterTest.BackPortByteCodeModelAllVersions
```

Reviewed By: raziel, tugsbayasgalan

Differential Revision: D30454080

fbshipit-source-id: 357c50b96682430675142d20d688d1f64e1de307
---
 caffe2/serialize/versions.h                    |   3 +-
 test/cpp/jit/test_interpreter.cpp              |   2 +-
 test/cpp/jit/test_lite_interpreter.cpp         |  12 +-
 torch/csrc/jit/mobile/backport_manager.cpp     | 460 +++++++++++++++++--------
 torch/csrc/jit/runtime/interpreter.h           |   2 +-
 torch/csrc/jit/serialization/export.h          |  41 ++-
 torch/csrc/jit/serialization/export_module.cpp |  21 +-
 7 files changed, 373 insertions(+), 168 deletions(-)

diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h
index ed57958..9fea8fc 100644
--- a/caffe2/serialize/versions.h
+++ b/caffe2/serialize/versions.h
@@ -74,7 +74,8 @@ constexpr uint64_t kProducedFileFormatVersion = 0x3L;
 // 0x6L: Implicit operator versioning using the number of specified arguments.
 // Refer to the summary of https://github.com/pytorch/pytorch/pull/56845
 // for details.
-constexpr uint64_t kProducedBytecodeVersion = 0x6L;
+// 0x7L: Enable support for operators with default arguments plus out arguments.
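+// For example, for `torch.add(x, h, out=x)` the bytecode records 3 specified
+// arguments instead of 4, since `alpha` keeps its schema default. Refer to
+// the summary of https://github.com/pytorch/pytorch/pull/63651 for details.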
+constexpr uint64_t kProducedBytecodeVersion = 0x7L;

 static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
     "kProducedBytecodeVersion must be higher or equal to kProducedFileFormatVersion.");

diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp
index bfdc1f3..78174d8 100644
--- a/test/cpp/jit/test_interpreter.cpp
+++ b/test/cpp/jit/test_interpreter.cpp
@@ -181,7 +181,7 @@ TEST(InterpreterTest, IgnorableArgsInSchemaWithOut) {
   auto op_to_specified_args = function.op_to_num_specified_args();
   ASSERT_TRUE(op_to_specified_args.size() == 1);
   // this should be 3 when the add_out flag is set to True
-  ASSERT_TRUE(op_to_specified_args["aten::add.out"] == 4);
+  ASSERT_TRUE(op_to_specified_args["aten::add.out"] == 3);
 }

 TEST(InterpreterTest, runAsyncBasicTest) {
diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp
index 63c3146..41aa77a 100644
--- a/test/cpp/jit/test_lite_interpreter.cpp
+++ b/test/cpp/jit/test_lite_interpreter.cpp
@@ -494,6 +494,7 @@ void compareModelOutput(
   AT_ASSERT(
       actual_result_list[1].toTensor().dim() == expect_result_list[1].dim());
   AT_ASSERT(actual_result_list[2].toTensor().equal(expect_result_list[2]));
+  AT_ASSERT(actual_result_list[3].toTensor().equal(expect_result_list[3]));
 }

 void runAndCheckTorchScriptModel(
@@ -588,7 +589,12 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
       x1 = torch.zeros(2, 2)
       x2 = torch.empty_like(torch.empty(2, 2))
       x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
-      return (x1, x2, x3)
+      # Add the torch.add operator to cover the bytecode version bump from 6 to 7.
+      # For bytecode version 7, the main change is to support default arguments with out arguments.
+      x = 2 * torch.ones(1)
+      h = torch.ones(1)
+      torch.add(x, h, out=x)
+      return (x1, x2, x3, x)
   )");
   torch::jit::Module module_freeze = freeze(module);
@@ -602,6 +608,8 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
   expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
   expect_result_list.emplace_back(
       at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
+  expect_result_list.emplace_back(3 * at::ones({1}));
+
   backportAllVersionCheck(
       input_model_stream,
       input_data,
@@ -1216,7 +1224,7 @@ TEST(LiteInterpreterTest, DefaultArgsWithOutArg) {
   auto op = ops.find("aten::add.out");
   TORCH_CHECK(
       op != ops.end() && op->second.num_schema_args.has_value() &&
-          op->second.num_schema_args.value() == 4);
+          op->second.num_schema_args.value() == 3);
 }

 TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp
index 1871daf..a0e6544 100644
--- a/torch/csrc/jit/mobile/backport_manager.cpp
+++ b/torch/csrc/jit/mobile/backport_manager.cpp
@@ -26,21 +26,14 @@ namespace {
 constexpr int64_t kBytecodeVersionV4 = 0x4L;
 constexpr int64_t kBytecodeVersionV5 = 0x5L;
 constexpr int64_t kBytecodeVersionV6 = 0x6L;
+constexpr int64_t kBytecodeVersionV7 = 0x7L;
 } // namespace

+/********************** Utility Functions **********************/
+
 // Utility function that can be reused by backport_vn_to_vn-1(). If any utility
 // function can be reused by other backport functions, move it here.
 namespace {
-bool update_bytecode_version(
-    std::vector<IValue>& bytecode_values,
-    const int64_t to_version) {
-  if (!bytecode_values.empty() && bytecode_values[0].isInt()) {
-    bytecode_values[0] = c10::IValue(to_version);
-    return true;
-  }
-  return false;
-}
-
 // Copy files from source to destination except the files and dirs
 void selective_copy(
     PyTorchStreamReader& reader,
@@ -50,7 +43,7 @@ void selective_copy(
   auto records = reader.getAllRecords();
   for (const auto& record : records) {
     // Don't copy archive in excluded_files, usually archive `version` and
-    // `bytecode`. Archvie `version` will be written when PyTorchStreamWriter is
+    // `bytecode`. Archive `version` will be written when PyTorchStreamWriter is
     // going to finalize and run writeEndOfFile()

     // records is the list of all file names in the zip file, and each record
@@ -96,8 +89,157 @@ void get_model_stream(PyTorchStreamReader& reader, std::stringstream& out) {
       std::unordered_set<std::string>());
 }

+// The write_archive_current function is used for bytecode from version v5 to
+// v7 (the latest bytecode version). Pre-v5, we serialized things differently.
+// This write archive function may change in export_module.cpp; however, we
+// don't have a way to keep the old export function in the codebase. To be
+// able to export the model in the old format, we keep a record of the export
+// function here.
+void write_archive_current(
+    PyTorchStreamWriter& writer,
+    const IValue& value,
+    const std::string& archive_name,
+    const std::string& archive_dir,
+    const std::string& tensor_dir,
+    bool use_storage_context,
+    SerializationStorageContext& storage_context) {
+  std::vector<char> data;
+  // Vector to capture the run-time class types during pickling the IValues
+  std::vector<c10::ClassTypePtr> memoizedClassTypes;
+  std::vector<std::string> tensor_names;
+  Pickler data_pickle(
+      [&](const char* buf, size_t size) {
+        data.insert(data.end(), buf, buf + size);
+      },
+      nullptr,
+      nullptr,
+      &memoizedClassTypes,
+      [&](const at::Tensor& tensor) {
+        // returns a string to use in pickler.cpp as the storage obj key
+        if (use_storage_context) {
+          std::string string_id =
+              std::to_string(reinterpret_cast<std::intptr_t>(
                  tensor.storage().unsafeGetStorageImpl()));
+          tensor_names.push_back(string_id + ".storage");
+          storage_context.getOrAddStorage(tensor.storage());
+        } else {
+          tensor_names.push_back(std::to_string(tensor_names.size()));
+        }
+        return tensor_names.back();
+      });
+  data_pickle.protocol();
+  data_pickle.pushIValue(value);
+  data_pickle.stop();
+  // write out tensor data
+  size_t i = 0;
+  std::string prefix = archive_name + "/";
+
+  TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
+  const std::unordered_set<std::string>& pre_serialized_files =
+      writer.getAllWrittenRecords();
+
+  for (const auto& td : data_pickle.tensorData()) {
+    WriteableTensorData writable_td = getWriteableTensorData(td);
+    std::string fname = tensor_dir + tensor_names[i++];
+    if (use_storage_context &&
+        pre_serialized_files.find(fname) != pre_serialized_files.end()) {
+      // storage has been serialized already, skip
+      continue;
+    }
+    writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
+  }
+
+  std::string fname = archive_dir + archive_name + ".pkl";
+  writer.writeRecord(fname, data.data(), data.size());
+}
+
+/*
+inputs: 1) the bytecode tuple from bytecode.pkl 2) the target bytecode version
+return: a boolean indicating whether the bytecode tuple was updated successfully
+*/
+bool update_bytecode_version(
+    std::vector<IValue>& bytecode_values,
+    const int64_t to_version) {
+  if (!bytecode_values.empty() &&
+      bytecode_values[0].isInt()) {
+    bytecode_values[0] = c10::IValue(to_version);
+    return true;
+  }
+  return false;
+}
+
+/*
+inputs: 1) the input model stringstream 2) the target bytecode version
+return: a model stringstream with the updated bytecode version in bytecode.pkl
+
+Example bytecode.pkl:
+(${bytecode_version},
+ ('__torch__.m.forward',
+  (('instructions',
+    (('STOREN', 1, 2),
+     ('DROPR', 1, 0),
+     ('MOVE', 2, 0),
+     ('OP', 0, 0),
+     ('RET', 0, 0))),
+   ('operators', (('aten::Int', 'Tensor'),)),
+   ('constants', ()),
+   ('types', ()),
+   ('register_size', 2))))
+*/
+std::stringstream update_bytecode_version(
+    std::stringstream& input_model,
+    const int64_t to_version) {
+  PyTorchStreamReader reader_bytecode(&input_model);
+  std::vector<IValue> constants_values =
+      readArchive(kArchiveNameConstants, reader_bytecode).toTuple()->elements();
+
+  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader_bytecode);
+  std::unordered_set<std::string> excluded_files{
+      "constants.pkl",
+      "bytecode.pkl",
+      "version",
+  };
+
+  std::unordered_set<std::string> excluded_dirs{
+      "constants",
+      "bytecode",
+  };
+
+  std::stringstream output_model_stream;
+  auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
+    output_model_stream.write(static_cast<const char*>(buf), nbytes);
+    return !output_model_stream ? 0 : nbytes;
+  };
+
+  PyTorchStreamWriter writer_bytecode(writer_func);
+
+  selective_copy(
+      reader_bytecode, writer_bytecode, excluded_files, excluded_dirs);
+
+  update_bytecode_version(bytecode_values, to_version);
+  auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));
+  SerializationStorageContext storage_context;
+  write_archive_current(
+      writer_bytecode,
+      c10::ivalue::Tuple::create(constants_values),
+      /*archive_name=*/"constants",
+      /*archive_dir=*/"",
+      /*tensor_dir=*/"constants/",
+      /*use_storage_context=*/true,
+      storage_context);
+  write_archive_current(
+      writer_bytecode,
+      bytecode_tuple,
+      /*archive_name=*/"bytecode",
+      /*archive_dir=*/"",
+      /*tensor_dir=*/"constants/",
+      /*use_storage_context=*/true,
+      storage_context);
+
+  return output_model_stream;
+}
 } // namespace

+/******************** backport_v{i}_to_v{i-1} Functions **********************/
+
 /*
 To add next backport function, for example, backport_vn_to_vn-1, create an
 anonymous namespace with a backport_vn_to_vn-1 function + other necessary
@@ -127,39 +269,42 @@
 */

-// The functions needed for backport model from v5 to v4.
 namespace {
-void writeArchiveV4(
-    PyTorchStreamWriter& writer,
-    const std::string& archive_name,
-    const c10::IValue& value) {
-  std::vector<char> data;
-
-  // Vector to capture the run-time class types during pickling the IValues
-  std::vector<c10::ClassTypePtr> memoizedClassTypes;
-  Pickler data_pickle(
-      [&](const char* buf, size_t size) {
-        data.insert(data.end(), buf, buf + size);
-      },
-      nullptr,
-      nullptr,
-      &memoizedClassTypes);
-  data_pickle.protocol();
-  data_pickle.pushIValue(value);
-  data_pickle.stop();
-  size_t i = 0;
-  std::string prefix = archive_name + "/";
-
-  for (const auto& td : data_pickle.tensorData()) {
-    WriteableTensorData writable_td = getWriteableTensorData(td);
-    std::string fname = prefix + c10::to_string(i++);
-    writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
-  }
-  std::string fname = archive_name + ".pkl";
-  writer.writeRecord(fname, data.data(), data.size());
-}
-
+/*
+The following functions are needed to backport a model from v5 to v4.
+Backport function for bytecode v5, which deduplicated the constant table.
+Previously, in v4, the constant table was exported twice, in both the
+bytecode archive and the constants archive, and the majority (almost all)
+of the entries were duplicates. Currently, in v5, JIT and mobile share the
+constants archive, and all constant tensors are exported in this archive.
+The bump was needed because the v5 bytecode exports the tensor storage path
+in the schema, since the runtime code is now able to query which archive a
+tensor is stored in and read from the correct archive.
+
+Previously, in v4, we deserialized tensors without an archive path, and
+mobile would always read tensors from the bytecode archive:
+(torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage,
+'0', 'cpu', 8),),
+    0,
+    (2, 4),
+    (4, 1),
+    False,
+    collections.OrderedDict()),
+ 1)),
+Currently, in v5, we deserialize the bytecode with the archive path, and
+mobile can read tensors from the given path:
+(torch._utils._rebuild_tensor_v2(pers.obj(('storage', torch.DoubleStorage,
+'constants/0', 'cpu', 8),),
+    0,
+    (2, 4),
+    (4, 1),
+    False,
+    collections.OrderedDict()),
+ 1)),
+Thus, the backport is necessary so that the runtime can read tensors from
+the correct archive.
+*/
 std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) {
   // 1) read from the `bytecode` archive
   PyTorchStreamReader reader(&input_model_stream);
@@ -195,6 +340,39 @@ std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) {
   update_bytecode_version(bytecode_values, kBytecodeVersionV4);
   // Construct the list of ivalues to a big tuple
   auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));
+
+  // The export function to generate bytecode.pkl for version 4. After the
+  // bytecode version bump, the old export function doesn't exist anymore, so
+  // keep a copy here for backport purposes.
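+  // Note: the lambda below reproduces the v4 serialization layout: each
+  // tensor is written as `<archive_name>/<index>` inside its own archive,
+  // with no storage deduplication between the bytecode and constants
+  // archives.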
+  auto writeArchiveV4 = [](PyTorchStreamWriter& writer,
+                           const std::string& archive_name,
+                           const c10::IValue& value) {
+    std::vector<char> data;
+
+    // Vector to capture the run-time class types during pickling the IValues
+    std::vector<c10::ClassTypePtr> memoizedClassTypes;
+    Pickler data_pickle(
+        [&](const char* buf, size_t size) {
+          data.insert(data.end(), buf, buf + size);
+        },
+        nullptr,
+        nullptr,
+        &memoizedClassTypes);
+    data_pickle.protocol();
+    data_pickle.pushIValue(value);
+    data_pickle.stop();
+    size_t i = 0;
+    std::string prefix = archive_name + "/";
+
+    for (const auto& td : data_pickle.tensorData()) {
+      WriteableTensorData writable_td = getWriteableTensorData(td);
+      std::string fname = prefix + c10::to_string(i++);
+      writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
+    }
+    std::string fname = archive_name + ".pkl";
+    writer.writeRecord(fname, data.data(), data.size());
+  };
+
   // write `bytecode` archive
   writeArchiveV4(writer, kArchiveNameBytecode, bytecode_tuple);
   // write `constants` archive
@@ -204,72 +382,33 @@ std::stringstream backport_v5_to_v4(std::stringstream& input_model_stream) {
   return ouput_model_stream;
 }

-void writeArchiveV5(
-    PyTorchStreamWriter& writer,
-    const IValue& value,
-    const std::string& archive_name,
-    const std::string& archive_dir,
-    const std::string& tensor_dir,
-    bool use_storage_context,
-    SerializationStorageContext& storage_context) {
-  std::vector<char> data;
-  // Vector to capture the run-time class types during pickling the IValues
-  std::vector<c10::ClassTypePtr> memoizedClassTypes;
-  std::vector<std::string> tensor_names;
-  Pickler data_pickle(
-      [&](const char* buf, size_t size) {
-        data.insert(data.end(), buf, buf + size);
-      },
-      nullptr,
-      nullptr,
-      &memoizedClassTypes,
-      [&](const at::Tensor& tensor) {
-        // returns a string to use in picker.cpp as storage obj key
-        if (use_storage_context) {
-          std::string string_id =
-              std::to_string(reinterpret_cast<std::intptr_t>(
-                  tensor.storage().unsafeGetStorageImpl()));
-          tensor_names.push_back(string_id + ".storage");
-          storage_context.getOrAddStorage(tensor.storage());
-        } else {
-          tensor_names.push_back(std::to_string(tensor_names.size()));
-        }
-        return tensor_names.back();
-      });
-  data_pickle.protocol();
-  data_pickle.pushIValue(value);
-  data_pickle.stop();
-  // write out tensor data
-  size_t i = 0;
-  std::string prefix = archive_name + "/";
-
-  TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
-  const std::unordered_set<std::string>& pre_serialized_files =
-      writer.getAllWrittenRecords();
-
-  for (const auto& td : data_pickle.tensorData()) {
-    WriteableTensorData writable_td = getWriteableTensorData(td);
-    std::string fname = tensor_dir + tensor_names[i++];
-    if (use_storage_context &&
-        std::find(
-            pre_serialized_files.begin(), pre_serialized_files.end(), fname) !=
-            pre_serialized_files.end()) {
-      // storage has been serialzed already, skip
-      continue;
-    }
-    writer.writeRecord(fname, writable_td.data(), writable_td.sizeInBytes());
-  }
-
-  std::string fname = archive_dir + archive_name + ".pkl";
-  writer.writeRecord(fname, data.data(), data.size());
-}
-
+/*
+Backport function for bytecode v6, which introduced support for operators with
+default arguments in mobile. Previously, in v5, there was no count of the
+specified arguments for operators in the bytecode operator table. In v6,
+operators are aware of the number of specified arguments present in the
+schema.
+
+The bump was needed because the v6 bytecode specifies the number of specified
+arguments for operators in the schema, since the runtime code is now able to
+query the number of specified arguments and supports default arguments.
+
+For example, aten::foo's schema in v5 is
+foo(Tensor a, Tensor b) -> Tensor
+and in v6, it's
+foo(Tensor a, Tensor b, int groups=1) -> Tensor
+
+Accordingly, the operator table in v5 is:
+('operators', (('aten::foo', ''),))
+and in v6, it's:
+('operators', (('aten::foo', '', 2),))
+
+Thus, the backport is necessary so that the bytecode operator table contains
+the number of specified arguments.
+*/
 std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
   std::shared_ptr<IStreamAdapter> rai =
       std::make_shared<IStreamAdapter>(&input_model_stream);
   auto reader = std::make_shared<PyTorchStreamReader>(rai);
-  std::vector<IValue> constants_values =
-      readArchive(kArchiveNameConstants, *reader.get()).toTuple()->elements();

   // If there are debug info files in the original model file, it should also
   // show up in the backported model
@@ -296,61 +435,89 @@ std::stringstream backport_v6_to_v5(std::stringstream& input_model_stream) {
   // resolved at runtime init stage for better operator compatibility.
   std::stringstream intermediate_model_stream;
   {
-    BytecodeEmitDefaultInputsGuard argNumGuard(true);
+    BytecodeEmitModeGuard argNumGuard(
+        true /*emit_default_input_instructions*/,
+        false /*enable_defaults_args_with_out_args*/);
     torch_script._save_for_mobile(
         intermediate_model_stream, extra_files, hasBytecodeDebug);
   }

   // Update the bytecode version (from 6 to 5)
+  std::stringstream output_model_stream =
+      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV5);
+  return output_model_stream;
+}

-  PyTorchStreamReader reader_bytecode(&intermediate_model_stream);
-  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader_bytecode);
-  std::unordered_set<std::string> excluded_files{
-      "constants.pkl",
-      "bytecode.pkl",
-      "version",
-  };
-
-  std::unordered_set<std::string> excluded_dirs{
-      "constants",
-      "bytecode",
-  };
-
-  std::stringstream ouput_model_stream;
-  auto writer_func = [&](const void* buf, size_t nbytes) -> size_t {
-    ouput_model_stream.write(static_cast<const char*>(buf), nbytes);
-    return !ouput_model_stream ? 0 : nbytes;
-  };
+/*
+Backport function for bytecode v7, which introduced support for operators with
+out arguments. Previously, in v6, operators with out arguments forced the
+serialization of all arguments in the schema, even when optional arguments
+were not provided (as they had default values). Currently, in v7, operators
+are aware of out arguments being present in the schema (always appended),
+allowing the serialization of only the required arguments (as default values
+will be provided by the runtime).
+
+The bump was needed because the v7 bytecode specifies fewer arguments for ops
+with out arguments in the schema, since the runtime code is now able to query
+whether an argument is of type "out" and insert the necessary default values in
+the right order in the interpreter stack (i.e. before the out arguments).
+
+For example, the schema is torch.add(x, h, alpha=1.0, out=x), and the program
+defines torch.add(x, h, out=x). Previously, in v6, we serialized the bytecode
+to contain all 4 arguments. Currently, in v7, we serialize the bytecode with
+only 3 arguments, since alpha is optional and has a default value that the
+runtime will push onto the stack. Thus, the backport is necessary so that the
+bytecode contains all the arguments as before.
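+
+In bytecode terms (illustrative), the operator table entry for this example
+changes from
+('operators', (('aten::add', 'out', 4),)) in v6 to
+('operators', (('aten::add', 'out', 3),)) in v7.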
+*/
+std::stringstream backport_v7_to_v6(std::stringstream& input_model_stream) {
+  std::shared_ptr<IStreamAdapter> rai =
+      std::make_shared<IStreamAdapter>(&input_model_stream);
+  auto reader = std::make_shared<PyTorchStreamReader>(rai);
+  std::vector<IValue> constants_values =
+      readArchive(kArchiveNameConstants, *reader.get()).toTuple()->elements();

-  PyTorchStreamWriter writer_bytecode(writer_func);
+  // If there are debug info files in the original model file, it should also
+  // show up in the backported model
+  bool hasBytecodeDebug = reader->hasRecord("mobile_debug_handles.pkl");

-  selective_copy(
-      reader_bytecode, writer_bytecode, excluded_files, excluded_dirs);
+  // extra_files are kept
+  auto records = reader->getAllRecords();
+  ExtraFilesMap extra_files;
+  for (const auto& record : records) {
+    std::size_t found = record.find_last_of("/\\");
+    auto path = record.substr(0, found);
+    if ("extra" == path) {
+      extra_files.emplace(record.substr(found + 1), "");
+    }
+  }
+  // Loading the TS module is required for this backport, because the bytecode
+  // needs to be re-emitted (refer to the comments below).
+  Module torch_script = torch::jit::load(rai, c10::nullopt, extra_files);

-  update_bytecode_version(bytecode_values, kBytecodeVersionV5);
-  auto bytecode_tuple = c10::ivalue::Tuple::create(std::move(bytecode_values));
-  SerializationStorageContext storage_context;
-  writeArchiveV5(
-      writer_bytecode,
-      c10::ivalue::Tuple::create(constants_values),
-      /*archive_name=*/"constants",
-      /*archive_dir=*/"",
-      /*tensor_dir=*/"constants/",
-      /*use_storage_context=*/true,
-      storage_context);
-  writeArchiveV5(
-      writer_bytecode,
-      bytecode_tuple,
-      /*archive_name=*/"bytecode",
-      /*archive_dir=*/"",
-      /*tensor_dir=*/"constants/",
-      /*use_storage_context=*/true,
-      storage_context);
+  // The RAII guard sets the flag emit_default_input_instructions to false to
+  // keep the same behavior as bytecode version 6, and sets the flag
+  // enable_defaults_args_with_out_args to false so that the number of
+  // specified arguments for operators with both out arguments and default
+  // arguments is serialized as #all_args, instead of
+  // (#all_args - #default_args).
+  std::stringstream intermediate_model_stream;
+  {
+    BytecodeEmitModeGuard argNumGuard(
+        false /*emit_default_input_instructions*/,
+        false /*enable_defaults_args_with_out_args*/);
+    torch_script._save_for_mobile(
+        intermediate_model_stream, extra_files, hasBytecodeDebug);
+  }

-  return ouput_model_stream;
+  // Update the bytecode version (from 7 to 6)
+  std::stringstream output_model_stream =
+      update_bytecode_version(intermediate_model_stream, kBytecodeVersionV6);
+  return output_model_stream;
 }
+
 } // namespace

+/********************** BackportManager **********************/
+
 // A generic contract for backport logic to the previous bytecode version.
 // Args:
 // * PyTorchStreamReader has access to the input model from N bytecode version.
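 // * The backport function takes a stringstream of the input model and returns
 //   a stringstream holding the output model at bytecode version N-1.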
@@ -362,6 +529,7 @@ using BytecodeBackportFunction =
 BackportManager::BackportManager() {
   registerBytecodeBackportFunction(kBytecodeVersionV5, backport_v5_to_v4);
   registerBytecodeBackportFunction(kBytecodeVersionV6, backport_v6_to_v5);
+  registerBytecodeBackportFunction(kBytecodeVersionV7, backport_v7_to_v6);
 }

 std::unordered_map<
diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h
index 3471e55..790c328 100644
--- a/torch/csrc/jit/runtime/interpreter.h
+++ b/torch/csrc/jit/runtime/interpreter.h
@@ -82,7 +82,7 @@ struct TORCH_API MobileCode : Code {
       const std::shared_ptr<Graph>& graph,
       std::string function_name,
       bool emit_default_input_instructions = true,
-      bool support_default_args_before_out = false,
+      bool support_default_args_before_out = true,
       size_t remaining_bailout_depth = 0);
   ~MobileCode();
 };
diff --git a/torch/csrc/jit/serialization/export.h b/torch/csrc/jit/serialization/export.h
index 5545f14..80f0884 100644
--- a/torch/csrc/jit/serialization/export.h
+++ b/torch/csrc/jit/serialization/export.h
@@ -188,24 +188,45 @@ TORCH_API void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook);
  */
 TORCH_API std::vector<std::string> export_opnames(const Module& m);

-struct TORCH_API BytecodeEmitDefaultValueForUnspecifiedArgMode {
-  static bool is_enabled();
-  static void set_enabled(bool enabled);
+struct TORCH_API BytecodeEmitMode {
+  static bool is_default_value_for_unspecified_arg_enabled();
+  static void set_default_value_for_unspecified_arg_enabled(bool enabled);
+
+  static bool is_default_args_before_out_args_enabled();
+  static void set_default_args_before_out_args_enabled(bool enabled);
 };

 // RAII guard to switch the way JIT emits the bytecode for inputs.
+// default_value_for_unspecified_arg:
 // true: instruction of default argument values (like LOADC) is emitted.
 // false: instruction of default argument values are not emitted. Instead
 // they are fetched from operator schema.
-struct TORCH_API BytecodeEmitDefaultInputsGuard {
-  BytecodeEmitDefaultInputsGuard(bool enable)
-      : prev_mode(BytecodeEmitDefaultValueForUnspecifiedArgMode::is_enabled()) {
-    BytecodeEmitDefaultValueForUnspecifiedArgMode::set_enabled(enable);
+// default_args_before_out_args (to support forward compatibility for
+// operators allowing out arguments and default arguments):
+// true: the number of specified arguments will be serialized as (#all_args -
+// #default_args).
+// false: the number of specified arguments will be serialized as (#all_args).
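+// Example: for aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1,
+// Tensor(a!) out) invoked as torch.add(x, h, out=x), true serializes 3
+// specified arguments and false serializes 4.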
+struct TORCH_API BytecodeEmitModeGuard {
+  BytecodeEmitModeGuard(
+      bool enable_default_value_for_unspecified_arg,
+      bool enable_default_args_before_out_args)
+      : prev_default_value_for_unspecified_arg_mode(
+            BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()),
+        prev_default_args_before_out_args(
+            BytecodeEmitMode::is_default_args_before_out_args_enabled()) {
+    BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
+        enable_default_value_for_unspecified_arg);
+    BytecodeEmitMode::set_default_args_before_out_args_enabled(
+        enable_default_args_before_out_args);
   }
-  ~BytecodeEmitDefaultInputsGuard() {
-    BytecodeEmitDefaultValueForUnspecifiedArgMode::set_enabled(prev_mode);
+  ~BytecodeEmitModeGuard() {
+    BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
+        prev_default_value_for_unspecified_arg_mode);
+    BytecodeEmitMode::set_default_args_before_out_args_enabled(
+        prev_default_args_before_out_args);
   }
-  bool prev_mode;
+  bool prev_default_value_for_unspecified_arg_mode;
+  bool prev_default_args_before_out_args;
 };

 TORCH_API IValue to_tuple(std::vector<IValue> ivalues);
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index d06be70..fbd2546 100644
--- a/torch/csrc/jit/serialization/export_module.cpp
+++ b/torch/csrc/jit/serialization/export_module.cpp
@@ -69,10 +69,7 @@ std::pair<IValue, IValue> getFunctionTuple(

   std::shared_ptr<MobileCode> code;
   code = std::make_shared<MobileCode>(
-      graph,
-      func.name(),
-      BytecodeEmitDefaultValueForUnspecifiedArgMode::
-          is_enabled() /* emit_default_input_instructions */);
+      graph, func.name(), BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() /* emit_default_input_instructions */, BytecodeEmitMode::is_default_args_before_out_args_enabled() /* enable_defaults_args_with_out_args */);
   auto instructions_copy = code->instructions();

   // operator names
@@ -171,7 +168,7 @@ std::pair<IValue, IValue> getFunctionTuple(
     if (it != op_to_specified_args.end()) {
       num_args = it->second;
     }
-    if (BytecodeEmitDefaultValueForUnspecifiedArgMode::is_enabled()) {
+    if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
       operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
     } else {
       operators.emplace_back(
@@ -872,12 +869,22 @@ std::vector<std::string> export_opnames(const script::Module& m) {
 // or not. It's the major difference between bytecode v5 and v6.
 thread_local bool emitBytecodeDefaultInputs =
     caffe2::serialize::kProducedBytecodeVersion <= 5 ? true : false;
-bool BytecodeEmitDefaultValueForUnspecifiedArgMode::is_enabled() {
+bool BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() {
   return emitBytecodeDefaultInputs;
 }
-void BytecodeEmitDefaultValueForUnspecifiedArgMode::set_enabled(bool enabled) {
+void BytecodeEmitMode::set_default_value_for_unspecified_arg_enabled(
+    bool enabled) {
   emitBytecodeDefaultInputs = enabled;
 }
+
+thread_local bool emitDefaultArgsWithOutArgs =
+    caffe2::serialize::kProducedBytecodeVersion <= 6 ? false : true;
+bool BytecodeEmitMode::is_default_args_before_out_args_enabled() {
+  return emitDefaultArgsWithOutArgs;
+}
+void BytecodeEmitMode::set_default_args_before_out_args_enabled(bool enabled) {
+  emitDefaultArgsWithOutArgs = enabled;
+}
+
 } // namespace jit
 } // namespace torch
--
2.7.4
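
Usage note (illustrative, not part of the diff): a minimal sketch of driving the v7-to-v6 backport path that this patch registers. It assumes the stream-based `torch::jit::_backport_for_mobile` overload declared in `torch/csrc/jit/mobile/backport.h` and `_get_model_bytecode_version` from `torch/csrc/jit/mobile/model_compatibility.h`; the `model_v7.ptl` filename is a placeholder.

```
// Sketch: backport a bytecode-v7 mobile model to v6 and reload it.
#include <torch/csrc/jit/mobile/backport.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/model_compatibility.h>

#include <fstream>
#include <sstream>

int main() {
  // Read a model exported at bytecode version 7 (kProducedBytecodeVersion).
  std::ifstream in("model_v7.ptl", std::ios::binary);
  std::stringstream input_model_stream;
  input_model_stream << in.rdbuf();

  // Backport to bytecode version 6; internally this dispatches to the
  // backport_v7_to_v6 function registered in BackportManager.
  std::stringstream output_model_stream;
  if (!torch::jit::_backport_for_mobile(
          input_model_stream, output_model_stream, /*to_version=*/6)) {
    return 1;
  }

  // The backported model should report bytecode version 6 and still load.
  auto version = torch::jit::_get_model_bytecode_version(output_model_stream);
  auto module = torch::jit::_load_for_mobile(output_model_stream);
  (void)module;
  return version == 6 ? 0 : 1;
}
```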