From 8d5b95019d69d43963b33a1b188ad1fec8079664 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Thu, 2 Sep 2021 00:50:40 -0700 Subject: [PATCH] [PyTorch Edge] Support default args with out arg, flag off (#63540) Summary: 1. Allow consuming operators with defaults arguments and out arguments. Flag is off to keep the same behavior as v6, in pr 63651, turn on the flag. 2. Add two unittests to cover this type of operators. Pull Request resolved: https://github.com/pytorch/pytorch/pull/63540 ghstack-source-id: 137211562 Test Plan: ``` caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsWithOutArg caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsPinvWithOutArg ``` Reviewed By: raziel, iseeyuan, tugsbayasgalan Differential Revision: D30414156 fbshipit-source-id: 0f3a219a22aee10ac53184cbd95940726c459d1f --- caffe2/serialize/versions.h | 2 +- test/cpp/jit/test_lite_interpreter.cpp | 62 ++++++++++++++++++++++++++ torch/csrc/jit/mobile/function.cpp | 38 +++++++++++----- torch/csrc/jit/runtime/interpreter.cpp | 2 + torch/csrc/jit/runtime/interpreter.h | 1 + torch/csrc/jit/runtime/interpreter/code_impl.h | 37 +++++++++------ 6 files changed, 115 insertions(+), 27 deletions(-) diff --git a/caffe2/serialize/versions.h b/caffe2/serialize/versions.h index 61c8c46..ed57958 100644 --- a/caffe2/serialize/versions.h +++ b/caffe2/serialize/versions.h @@ -85,7 +85,7 @@ static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion, // we should support this model_version. For example, we provide a wrapper to // handle an updated operator. constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L; -constexpr uint64_t kMaxSupportedBytecodeVersion = 0x6L; +constexpr uint64_t kMaxSupportedBytecodeVersion = 0x7L; } // namespace serialize } // namespace caffe2 diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index 26100b3..b362c8a 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -1035,6 +1035,68 @@ TEST(LiteInterpreterTest, DefaultArgsPinvSpecifyDefault) { testLiteModuleCompareResultTensors(m, inputs); } +void testDefaultArgsPinvWithOutArg(int num_args) { + Module m("m"); + if (num_args == 1) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, out=input) + )"); + } else if (num_args == 2) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5, out=input) + )"); + } else if (num_args == 3) { + m.define(R"( + def forward(self, input): + return torch.linalg_pinv(input, 1e-5, True, out=input) + )"); + } + + const int N = 28; + auto input = torch::range(1, N * N, 1); + input[0] = 10000; // a more stable matrix + input = input.view({N, N}); + auto ref = m.run_method("forward", input); + TORCH_CHECK(!input.equal(torch::range(1, N * N, 1))); + TORCH_CHECK(input.equal(ref.toTensor())); +} + +TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) { + // Test with different number of specified arguments + out arg. + // Arguments not specified take default value. + for (int num_args = 1; num_args <= 3; ++num_args) { + testDefaultArgsPinvWithOutArg(num_args); + } +} + +TEST(LiteInterpreterTest, DefaultArgsWithOutArg) { + Module m("m"); + m.define(R"( + def forward(self, x, h): + torch.add(x, h, out=x) + )"); + + std::vector inputs; + auto input_x = 2 * torch::ones({}); + auto input_h = torch::ones({}); + auto ref = m.run_method("forward", input_x, input_h); + + std::stringstream ss; + + m._save_for_mobile(ss, {}, true); + mobile::Module bc = _load_for_mobile(ss); + bc.run_method("forward", input_x, input_h); + AT_ASSERT(input_x.equal(4 * torch::ones({}))); + + auto ops = _get_model_ops_and_info(ss); + auto op = ops.find("aten::add.out"); + TORCH_CHECK( + op != ops.end() && op->second.num_schema_args.has_value() && + op->second.num_schema_args.value() == 4); +} + TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) { Module a("A"); a.define(R"( diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp index 127bd5f..fad8c39 100644 --- a/torch/csrc/jit/mobile/function.cpp +++ b/torch/csrc/jit/mobile/function.cpp @@ -99,21 +99,35 @@ bool Function::append_operator( // from model. We can use it to handle backward compatibility. if (num_specified_args && num_specified_args.value() < static_cast(args.size())) { - // Sanity check at load time, to save perf at runtime - for (size_t i = num_specified_args.value(); i < args.size(); ++i) { - auto default_val = args[i].default_value(); - TORCH_CHECK( - default_val.has_value(), - "Error happened at preparing for default values for the argument. The ", - i, - "th arguement of operator", - opname, - " does not have a specified value or default value. "); - } fn = [fn, num_specified_args, args](Stack& stack) { - for (size_t i = num_specified_args.value(); i < args.size(); ++i) { + std::vector out_args; + // The following logic pops and temporarily stores all out arguments + // from the stack (which can be 0 or more, and always appended to the + // schema), in order to push the necessary default values. Finally, the + // out arguments are pushed back into the stack. + for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) { + out_args.push_back(stack.back()); + stack.pop_back(); + } + size_t start_index = num_specified_args.value() - out_args.size(); + TORCH_CHECK( + start_index >= 0, + "The number of output arguments is: ", + out_args.size(), + ", which is more then the number of specified arguments: ", + num_specified_args.value()); + for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) { + TORCH_CHECK( + args[i].default_value().has_value(), + "Error happened at preparing for default values for the argument. The ", + i, + "th argument ", + args[i].name(), + " does not have a specified value or default value. "); + stack.push_back(args[i].default_value()); } + stack.insert(stack.end(), out_args.rbegin(), out_args.rend()); fn(stack); }; } diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 70c9c6c..b348271 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -978,11 +978,13 @@ MobileCode::MobileCode( const std::shared_ptr& graph, std::string function_name, bool emit_default_input_instructions, + bool support_default_args_before_out, size_t remaining_bailout_depth) : Code(new interpreter::MobileCodeImpl( graph, std::move(function_name), emit_default_input_instructions, + support_default_args_before_out, remaining_bailout_depth)) {} MobileCode::~MobileCode() = default; diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index 80720ea..3471e55 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -82,6 +82,7 @@ struct TORCH_API MobileCode : Code { const std::shared_ptr& graph, std::string function_name, bool emit_default_input_instructions = true, + bool support_default_args_before_out = false, size_t remaining_bailout_depth = 0); ~MobileCode(); }; diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h index 682c695..15ba0ce 100644 --- a/torch/csrc/jit/runtime/interpreter/code_impl.h +++ b/torch/csrc/jit/runtime/interpreter/code_impl.h @@ -721,9 +721,11 @@ struct MobileCodeImpl : CodeImpl { const std::shared_ptr& graph, std::string function_name, bool emit_default_input_instructions, + bool support_default_args_before_out, size_t remaining_bailout_depth) : CodeImpl(graph, function_name, remaining_bailout_depth, false), - emit_default_input_instructions_(emit_default_input_instructions) { + emit_default_input_instructions_(emit_default_input_instructions), + support_default_args_before_out_(support_default_args_before_out) { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) run(); } @@ -746,11 +748,12 @@ struct MobileCodeImpl : CodeImpl { // skip if schema has vararg if (!op_schema.is_vararg()) { auto specifiedArgs = CalculateNecessaryArgs( - op_schema.arguments(), node->inputs(), false); - // preserving the old behavior - auto numInclude = specifiedArgs.first; - // TODO uncomment this - // auto numInclude = specifiedArgs.first + specifiedArgs.second; + op_schema.arguments(), + node->inputs(), + support_default_args_before_out_); + + size_t numInclude = specifiedArgs.first + + (support_default_args_before_out_ ? specifiedArgs.second : 0); auto unique_name = op_schema.overload_name() != "" ? op_schema.name() + "." + op_schema.overload_name() : op_schema.name(); @@ -782,21 +785,27 @@ struct MobileCodeImpl : CodeImpl { if (it != op_to_num_specified_args_.end()) { num_include = it->second; } - emitLoadInputs(node->inputs(), num_include); - // TODO: uncomment this - // auto num_out = op_to_num_out_args_.find(unique_op_name)->second; - // auto num_specified_before_out = num_include - num_out; - // emitLoadInputs(node->inputs(), 0, num_specified_before_out); - // emitLoadInputs(node->inputs(), node->inputs().size() - num_out, - // node->inputs().size()); - + if (support_default_args_before_out_) { + auto num_out = op_to_num_out_args_.find(unique_op_name)->second; + auto num_specified_before_out = num_include - num_out; + emitLoadInputs(node->inputs(), 0, num_specified_before_out); + emitLoadInputs( + node->inputs(), + node->inputs().size() - num_out, + node->inputs().size()); + } else { + emitLoadInputs(node->inputs(), num_include); + } insertInstruction(OP, operator_table_.size()); } operator_table_.emplace_back(op.getOperation(node)); } } + // To support forward compatibility for bytecode version bump from v5 to v6 bool emit_default_input_instructions_; + // To support forward compatibility for bytecode version bump from v6 to v7 + bool support_default_args_before_out_; }; } // namespace interpreter -- 2.7.4