[PyTorch Edge] Support default args with out arg, flag off (#63540)
authorChen Lai <chenlai@fb.com>
Thu, 2 Sep 2021 07:50:40 +0000 (00:50 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 2 Sep 2021 08:36:16 +0000 (01:36 -0700)
Summary:
1. Allow consuming operators with defaults arguments and out arguments. Flag is off to keep the same behavior as v6, in pr 63651, turn on the flag.
2. Add two unittests to cover this type of operators.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63540

ghstack-source-id: 137211562

Test Plan:
```
caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsWithOutArg
caffe2/test/cpp/jit:jit - LiteInterpreterTest.DefaultArgsPinvWithOutArg
```

Reviewed By: raziel, iseeyuan, tugsbayasgalan

Differential Revision: D30414156

fbshipit-source-id: 0f3a219a22aee10ac53184cbd95940726c459d1f

caffe2/serialize/versions.h
test/cpp/jit/test_lite_interpreter.cpp
torch/csrc/jit/mobile/function.cpp
torch/csrc/jit/runtime/interpreter.cpp
torch/csrc/jit/runtime/interpreter.h
torch/csrc/jit/runtime/interpreter/code_impl.h

index 61c8c46..ed57958 100644 (file)
@@ -85,7 +85,7 @@ static_assert(kProducedBytecodeVersion >= kProducedFileFormatVersion,
 // we should support this model_version. For example, we provide a wrapper to
 // handle an updated operator.
 constexpr uint64_t kMinSupportedBytecodeVersion = 0x3L;
-constexpr uint64_t kMaxSupportedBytecodeVersion = 0x6L;
+constexpr uint64_t kMaxSupportedBytecodeVersion = 0x7L;
 
 } // namespace serialize
 } // namespace caffe2
index 26100b3..b362c8a 100644 (file)
@@ -1035,6 +1035,68 @@ TEST(LiteInterpreterTest, DefaultArgsPinvSpecifyDefault) {
   testLiteModuleCompareResultTensors(m, inputs);
 }
 
+void testDefaultArgsPinvWithOutArg(int num_args) {
+  Module m("m");
+  if (num_args == 1) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, out=input)
+    )");
+  } else if (num_args == 2) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, out=input)
+    )");
+  } else if (num_args == 3) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, True, out=input)
+    )");
+  }
+
+  const int N = 28;
+  auto input = torch::range(1, N * N, 1);
+  input[0] = 10000; // a more stable matrix
+  input = input.view({N, N});
+  auto ref = m.run_method("forward", input);
+  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
+  TORCH_CHECK(input.equal(ref.toTensor()));
+}
+
+TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) {
+  // Test with different number of specified arguments + out arg.
+  // Arguments not specified take default value.
+  for (int num_args = 1; num_args <= 3; ++num_args) {
+    testDefaultArgsPinvWithOutArg(num_args);
+  }
+}
+
+TEST(LiteInterpreterTest, DefaultArgsWithOutArg) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, x, h):
+      torch.add(x, h, out=x)
+  )");
+
+  std::vector<IValue> inputs;
+  auto input_x = 2 * torch::ones({});
+  auto input_h = torch::ones({});
+  auto ref = m.run_method("forward", input_x, input_h);
+
+  std::stringstream ss;
+
+  m._save_for_mobile(ss, {}, true);
+  mobile::Module bc = _load_for_mobile(ss);
+  bc.run_method("forward", input_x, input_h);
+  AT_ASSERT(input_x.equal(4 * torch::ones({})));
+
+  auto ops = _get_model_ops_and_info(ss);
+  auto op = ops.find("aten::add.out");
+  TORCH_CHECK(
+      op != ops.end() && op->second.num_schema_args.has_value() &&
+      op->second.num_schema_args.value() == 4);
+}
+
 TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
   Module a("A");
   a.define(R"(
index 127bd5f..fad8c39 100644 (file)
@@ -99,21 +99,35 @@ bool Function::append_operator(
     // from model. We can use it to handle backward compatibility.
     if (num_specified_args &&
         num_specified_args.value() < static_cast<int64_t>(args.size())) {
-      // Sanity check at load time, to save perf at runtime
-      for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
-        auto default_val = args[i].default_value();
-        TORCH_CHECK(
-            default_val.has_value(),
-            "Error happened at preparing for default values for the argument. The ",
-            i,
-            "th arguement of operator",
-            opname,
-            " does not have a specified value or default value. ");
-      }
       fn = [fn, num_specified_args, args](Stack& stack) {
-        for (size_t i = num_specified_args.value(); i < args.size(); ++i) {
+        std::vector<IValue> out_args;
+        // The following logic pops and temporarily stores all out arguments
+        // from the stack (which can be 0 or more, and always appended to the
+        // schema), in order to push the necessary default values. Finally, the
+        // out arguments are pushed back into the stack.
+        for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
+          out_args.push_back(stack.back());
+          stack.pop_back();
+        }
+        size_t start_index = num_specified_args.value() - out_args.size();
+        TORCH_CHECK(
+            start_index >= 0,
+            "The number of output arguments is: ",
+            out_args.size(),
+            ", which is more then the number of specified arguments: ",
+            num_specified_args.value());
+        for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) {
+          TORCH_CHECK(
+              args[i].default_value().has_value(),
+              "Error happened at preparing for default values for the argument. The ",
+              i,
+              "th argument ",
+              args[i].name(),
+              " does not have a specified value or default value. ");
+
           stack.push_back(args[i].default_value());
         }
+        stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
         fn(stack);
       };
     }
index 70c9c6c..b348271 100644 (file)
@@ -978,11 +978,13 @@ MobileCode::MobileCode(
     const std::shared_ptr<Graph>& graph,
     std::string function_name,
     bool emit_default_input_instructions,
+    bool support_default_args_before_out,
     size_t remaining_bailout_depth)
     : Code(new interpreter::MobileCodeImpl(
           graph,
           std::move(function_name),
           emit_default_input_instructions,
+          support_default_args_before_out,
           remaining_bailout_depth)) {}
 
 MobileCode::~MobileCode() = default;
index 80720ea..3471e55 100644 (file)
@@ -82,6 +82,7 @@ struct TORCH_API MobileCode : Code {
       const std::shared_ptr<Graph>& graph,
       std::string function_name,
       bool emit_default_input_instructions = true,
+      bool support_default_args_before_out = false,
       size_t remaining_bailout_depth = 0);
   ~MobileCode();
 };
index 682c695..15ba0ce 100644 (file)
@@ -721,9 +721,11 @@ struct MobileCodeImpl : CodeImpl {
       const std::shared_ptr<Graph>& graph,
       std::string function_name,
       bool emit_default_input_instructions,
+      bool support_default_args_before_out,
       size_t remaining_bailout_depth)
       : CodeImpl(graph, function_name, remaining_bailout_depth, false),
-        emit_default_input_instructions_(emit_default_input_instructions) {
+        emit_default_input_instructions_(emit_default_input_instructions),
+        support_default_args_before_out_(support_default_args_before_out) {
     // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
     run();
   }
@@ -746,11 +748,12 @@ struct MobileCodeImpl : CodeImpl {
         // skip if schema has vararg
         if (!op_schema.is_vararg()) {
           auto specifiedArgs = CalculateNecessaryArgs(
-              op_schema.arguments(), node->inputs(), false);
-          // preserving the old behavior
-          auto numInclude = specifiedArgs.first;
-          // TODO uncomment this
-          // auto numInclude = specifiedArgs.first + specifiedArgs.second;
+              op_schema.arguments(),
+              node->inputs(),
+              support_default_args_before_out_);
+
+          size_t numInclude = specifiedArgs.first +
+              (support_default_args_before_out_ ? specifiedArgs.second : 0);
           auto unique_name = op_schema.overload_name() != ""
               ? op_schema.name() + "." + op_schema.overload_name()
               : op_schema.name();
@@ -782,21 +785,27 @@ struct MobileCodeImpl : CodeImpl {
         if (it != op_to_num_specified_args_.end()) {
           num_include = it->second;
         }
-        emitLoadInputs(node->inputs(), num_include);
-        // TODO: uncomment this
-        // auto num_out = op_to_num_out_args_.find(unique_op_name)->second;
-        // auto num_specified_before_out = num_include - num_out;
-        // emitLoadInputs(node->inputs(), 0, num_specified_before_out);
-        // emitLoadInputs(node->inputs(), node->inputs().size() - num_out,
-        // node->inputs().size());
-
+        if (support_default_args_before_out_) {
+          auto num_out = op_to_num_out_args_.find(unique_op_name)->second;
+          auto num_specified_before_out = num_include - num_out;
+          emitLoadInputs(node->inputs(), 0, num_specified_before_out);
+          emitLoadInputs(
+              node->inputs(),
+              node->inputs().size() - num_out,
+              node->inputs().size());
+        } else {
+          emitLoadInputs(node->inputs(), num_include);
+        }
         insertInstruction(OP, operator_table_.size());
       }
       operator_table_.emplace_back(op.getOperation(node));
     }
   }
 
+  // To support forward compatibility for bytecode version bump from v5 to v6
   bool emit_default_input_instructions_;
+  // To support forward compatibility for bytecode version bump from v6 to v7
+  bool support_default_args_before_out_;
 };
 
 } // namespace interpreter