From d0c63e857d12f3ddc04a80defb2530694b4f263d Mon Sep 17 00:00:00 2001
From: "Tugsbayasgalan (Tugsuu) Manlaibaatar"
Date: Sat, 28 Aug 2021 11:44:58 -0700
Subject: [PATCH] Enhancement for smart serialization for out schemas (#63096)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63096

Extends the smart serialization introduced in #56079 to operator schemas
with out arguments. CalculateNecessaryArgs now also reports how many
trailing out arguments a schema has, so python_print can elide default
arguments that precede them (e.g. alpha in add.out) while still printing
the out arguments themselves as keyword arguments: torch.add(x, y, out=y)
instead of torch.add(x, y, alpha=1, out=y). MobileCode additionally
records the per-operator out-arg count; the emitter preserves the old
behavior for now (see the TODOs in code_impl.h).

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D30415255

Pulled By: tugsbayasgalan

fbshipit-source-id: eb40440a3b46258394d035479f5fc4a4baa12bcc
---
 test/cpp/jit/test_interpreter.cpp                 |  9 ++++
 test/cpp/jit/test_utils.cpp                       | 15 +++++++
 test/cpp/jit/test_utils.h                         |  1 +
 test/jit/test_ignorable_args.py                   |  7 ++++
 torch/csrc/jit/runtime/calculate_necessary_args.h | 43 +++++++++++++++----
 torch/csrc/jit/runtime/interpreter/code_impl.h    | 25 +++++++++++-
 torch/csrc/jit/serialization/python_print.cpp     | 50 +++++++++++++++------
 7 files changed, 126 insertions(+), 24 deletions(-)

diff --git a/test/cpp/jit/test_interpreter.cpp b/test/cpp/jit/test_interpreter.cpp
index a241891..bfdc1f3 100644
--- a/test/cpp/jit/test_interpreter.cpp
+++ b/test/cpp/jit/test_interpreter.cpp
@@ -175,6 +175,15 @@ TEST(InterpreterTest, IgnorableArgsInSchema) {
   ASSERT_TRUE(op_to_specified_args_non_const["aten::conv2d"] == 6);
 }
 
+TEST(InterpreterTest, IgnorableArgsInSchemaWithOut) {
+  auto graph = build_mobile_export_with_out();
+  MobileCode function(graph, "");
+  auto op_to_specified_args = function.op_to_num_specified_args();
+  ASSERT_TRUE(op_to_specified_args.size() == 1);
+  // this should be 3 when the add_out flag is set to True
+  ASSERT_TRUE(op_to_specified_args["aten::add.out"] == 4);
+}
+
 TEST(InterpreterTest, runAsyncBasicTest) {
   /*
   TODO: there are some problem with C++ parsing script program involving
diff --git a/test/cpp/jit/test_utils.cpp b/test/cpp/jit/test_utils.cpp
index 27667f0..f2fb9e1 100644
--- a/test/cpp/jit/test_utils.cpp
+++ b/test/cpp/jit/test_utils.cpp
@@ -123,6 +123,21 @@ std::shared_ptr<Graph> build_mobile_export_analysis_graph() {
   return g;
 }
 
+std::shared_ptr<Graph> build_mobile_export_with_out() {
+  const auto graph_string = R"IR(
+    graph(%x.1 : Tensor,
+          %y.1 : Tensor):
+      %8 : NoneType = prim::Constant()
+      %6 : int = prim::Constant[value=1]()
+      %7 : Tensor = aten::add(%x.1, %y.1, %6, %y.1)
+      return (%8))IR";
+
+  auto g = std::make_shared<Graph>();
+  torch::jit::parseIR(graph_string, g.get());
+  g->lint();
+  return g;
+}
+
 std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested() {
   // this is pretty much the same test as build_mobile_export_analysis_graph(),
   // but some aten::slice operators are hidden under block statement to check
diff --git a/test/cpp/jit/test_utils.h b/test/cpp/jit/test_utils.h
index 5e640ae..1a1e1b8 100644
--- a/test/cpp/jit/test_utils.h
+++ b/test/cpp/jit/test_utils.h
@@ -74,6 +74,7 @@ std::pair<tensor_list, tensor_list> runGradient(
 
 std::shared_ptr<Graph> build_lstm();
 std::shared_ptr<Graph> build_mobile_export_analysis_graph();
+std::shared_ptr<Graph> build_mobile_export_with_out();
 std::shared_ptr<Graph> build_mobile_export_analysis_graph_with_vararg();
 std::shared_ptr<Graph> build_mobile_export_analysis_graph_nested();
 std::shared_ptr<Graph> build_mobile_export_analysis_graph_non_const();
diff --git a/test/jit/test_ignorable_args.py b/test/jit/test_ignorable_args.py
index b195e3c..fb63c19 100644
--- a/test/jit/test_ignorable_args.py
+++ b/test/jit/test_ignorable_args.py
@@ -1,5 +1,6 @@
 import os
 import sys
+import torch
 from torch._C import parse_ir
 from torch.testing import FileCheck
 
@@ -43,3 +44,9 @@ class TestIgnorableArgs(JitTestCase):
         # because in %16, %15 and %0 are default values for the schema.
         FileCheck().check("torch.slice(torch.slice(torch.tensor(_0), 0, 2), 1, None, 1)").run(src)
         self.assertEqual(function(), function_copy())
+
+    def test_add_out_ignorable_args(self):
+        @torch.jit.script
+        def fn(x: torch.Tensor, y: torch.Tensor):
+            torch.add(x, y, out=y)
+        FileCheck().check("torch.add(x, y, out=y)").run(fn.code)
diff --git a/torch/csrc/jit/runtime/calculate_necessary_args.h b/torch/csrc/jit/runtime/calculate_necessary_args.h
index 5f37660..07df670 100644
--- a/torch/csrc/jit/runtime/calculate_necessary_args.h
+++ b/torch/csrc/jit/runtime/calculate_necessary_args.h
@@ -7,18 +7,42 @@ namespace torch {
 namespace jit {
 
-inline size_t CalculateNecessaryArgs(
+inline std::pair<size_t, size_t> CalculateNecessaryArgs(
     const std::vector<Argument>& schema_args,
-    at::ArrayRef<Value*> actual_inputs) {
+    at::ArrayRef<Value*> actual_inputs,
+    bool allow_trailing_out_args) {
+  if (schema_args.size() == 0) {
+    return std::make_pair(0, 0);
+  }
+
+  // count number of out arguments (signed index, so the walk can go below 0)
+  int64_t schema_idx = static_cast<int64_t>(schema_args.size()) - 1;
+  if (allow_trailing_out_args) {
+    // skip over out arguments in the end.
+    while (schema_idx >= 0) {
+      auto current_arg = schema_args.at(schema_idx);
+      if (!current_arg.is_out()) {
+        break;
+      }
+      schema_idx--;
+    }
+  }
+
+  auto num_out = schema_args.size() - schema_idx - 1;
+
   if (schema_args.size() < actual_inputs.size()) {
-    return actual_inputs.size();
+    return std::make_pair(actual_inputs.size(), num_out);
+  }
+
+  // if out args were not skipped above, reset the index to the last element
+  if (!allow_trailing_out_args) {
+    schema_idx = schema_args.size() - 1;
   }
 
   // keeps track of trailing unnecessary args
-  int schema_size = schema_args.size();
-  for (int schema_idx = schema_size - 1; schema_idx > -1; schema_idx--) {
+  while (schema_idx >= 0) {
     // this means it is not a default argument, so it is necessary
     if (!schema_args.at(schema_idx).default_value().has_value()) {
-      return schema_idx + 1;
+      return std::make_pair(schema_idx + 1, num_out);
     } else {
       auto schema_value =
           schema_args.at(schema_idx).default_value().value().toIValue();
@@ -27,16 +51,17 @@ inline size_t CalculateNecessaryArgs(
       // well.
       auto actual_value = toIValue(actual_inputs[schema_idx]);
       if (!actual_value.has_value()) {
-        return schema_idx + 1;
+        return std::make_pair(schema_idx + 1, num_out);
       }
       // if the IR has the same value as the default value of the schema,
       // it is not a necessary argument.
       if (schema_value != actual_value.value()) {
-        return schema_idx + 1;
+        return std::make_pair(schema_idx + 1, num_out);
       }
     }
+    schema_idx--;
   }
-  return 0;
+  return std::make_pair(0, num_out);
 }
 
 } // namespace jit
diff --git a/torch/csrc/jit/runtime/interpreter/code_impl.h b/torch/csrc/jit/runtime/interpreter/code_impl.h
index 00648de..682c695 100644
--- a/torch/csrc/jit/runtime/interpreter/code_impl.h
+++ b/torch/csrc/jit/runtime/interpreter/code_impl.h
@@ -105,6 +105,8 @@ struct CodeImpl {
   // This is because for all usages, at most 3 args are used.
   std::unordered_map<std::string, size_t> op_to_num_specified_args_;
 
+  std::unordered_map<std::string, size_t> op_to_num_out_args_;
+
   // running count of uses as we emit. When we reach use_count_[v] =
   // v.uses().size() we know it is the final use and we can move rather than
   // load.
@@ -292,6 +294,12 @@ struct CodeImpl {
     }
   }
 
+  void emitLoadInputs(at::ArrayRef<Value*> inputs, size_t start, size_t end) {
+    for (size_t i = start; i < end; i++) {
+      emitUse(inputs[i], false);
+    }
+  }
+
   virtual void emitOperator(Node* node) {
     emitLoadInputs(node->inputs());
     const Operator& op = node->getOperator();
@@ -737,13 +745,19 @@ struct MobileCodeImpl : CodeImpl {
       auto op_schema = node->getOperator().schema();
       // skip if schema has vararg
       if (!op_schema.is_vararg()) {
-        auto numInclude =
-            CalculateNecessaryArgs(op_schema.arguments(), node->inputs());
+        auto specifiedArgs = CalculateNecessaryArgs(
+            op_schema.arguments(), node->inputs(), false);
+        // preserving the old behavior
+        auto numInclude = specifiedArgs.first;
+        // TODO uncomment this
+        // auto numInclude = specifiedArgs.first + specifiedArgs.second;
         auto unique_name = op_schema.overload_name() != ""
            ? op_schema.name() + "." + op_schema.overload_name()
            : op_schema.name();
         auto it = op_to_num_specified_args_.insert(
             std::pair<std::string, size_t>(unique_name, 0));
+        op_to_num_out_args_.insert(std::pair<std::string, size_t>(
+            unique_name, specifiedArgs.second));
         auto prev_value = it.first->second;
         it.first->second = std::max(numInclude, prev_value);
       }
@@ -769,6 +783,13 @@ struct MobileCodeImpl : CodeImpl {
       num_include = it->second;
     }
     emitLoadInputs(node->inputs(), num_include);
+    // TODO: uncomment this
+    // auto num_out = op_to_num_out_args_.find(unique_op_name)->second;
+    // auto num_specified_before_out = num_include - num_out;
+    // emitLoadInputs(node->inputs(), 0, num_specified_before_out);
+    // emitLoadInputs(node->inputs(), node->inputs().size() - num_out,
+    //     node->inputs().size());
+
     insertInstruction(OP, operator_table_.size());
   }
   operator_table_.emplace_back(op.getOperation(node));
diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp
index 1ab9689..80123c6 100644
--- a/torch/csrc/jit/serialization/python_print.cpp
+++ b/torch/csrc/jit/serialization/python_print.cpp
@@ -1162,23 +1162,47 @@ struct PythonPrintImpl {
       // calculate how many args are specified.
       // see (https://github.com/pytorch/pytorch/pull/56079) for more
       // details.
-      size_t necessary_args =
-          CalculateNecessaryArgs(schema.arguments(), node->inputs());
-      for (const auto i : c10::irange(necessary_args)) {
-        if (i > 0)
+      size_t num_schema_args = schema.arguments().size();
+
+      // we only want to do this extra logic when necessary.
+      if (num_schema_args > 0) {
+        // calculate how many args are specified.
+        // see (https://github.com/pytorch/pytorch/pull/56079) for more
+        // details.
+        auto specified_args =
+            CalculateNecessaryArgs(schema.arguments(), node->inputs(), true);
+
+        auto num_necessary = specified_args.first;
+        auto num_out = specified_args.second;
+
+        for (size_t i = 0; i < num_necessary; ++i) {
+          if (i > 0)
+            stmt << ", ";
+          auto v = useOf(node->inputs().at(i));
+          // print the kwarg name if it is a kwarg only argument.
+          if (i < num_schema_args) {
+            auto arg = schema.arguments().at(i);
+            if (arg.kwarg_only()) {
+              stmt << arg.name() << "=";
+            }
+          } else {
+            // vararg functions like format can have extra arguments
+            AT_ASSERT(schema.is_vararg());
+          }
+          stmt << *v;
+        }
+
+        // print out args
+        for (size_t i = num_schema_args - num_out; i < num_schema_args; i++) {
           stmt << ", ";
-        auto v = useOf(node->inputs().at(i));
-        // print the kwarg name if it is a kwarg only argument.
-        if (i < schema.arguments().size()) {
           auto arg = schema.arguments().at(i);
-          if (arg.kwarg_only()) {
-            stmt << arg.name() << "=";
+          TORCH_INTERNAL_ASSERT(arg.is_out());
+          // figure out the corresponding input at this index
+          auto input_idx = node->inputs().size() - (num_schema_args - i);
+          if (input_idx < node->inputs().size()) {
+            stmt << arg.name() << "=" << *useOf(node->inputs().at(input_idx));
           }
-        } else {
-          // vararg functions like format can have extra arguments
-          AT_ASSERT(schema.is_vararg());
         }
-        stmt << *v;
       }
       stmt << ")";
     } break;
-- 
2.7.4
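
For reference, the new JIT test above distills to the stand-alone snippet
below. This is a minimal sketch and assumes a PyTorch build that includes
this patch; without it, the serialized code may still spell out the
elidable alpha=1 argument before out=y.

    import torch
    from torch.testing import FileCheck

    @torch.jit.script
    def fn(x: torch.Tensor, y: torch.Tensor):
        # alpha defaults to 1, so smart serialization can drop it; the
        # trailing out argument is kept and printed as a keyword argument.
        torch.add(x, y, out=y)

    # fn.code is the TorchScript source produced by the python printer.
    print(fn.code)
    FileCheck().check("torch.add(x, y, out=y)").run(fn.code)

The FileCheck call asserts that the out argument survives serialization as
out=y rather than being printed positionally or dropped.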