Detect out argument in the schema (#62755)
authorTugsbayasgalan (Tugsuu) Manlaibaatar <tmanlaibaatar@fb.com>
Fri, 27 Aug 2021 18:18:52 +0000 (11:18 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 27 Aug 2021 18:20:33 +0000 (11:20 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62755

After this change, out argument can be checked by calling is_out()

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D30415256

Pulled By: tugsbayasgalan

fbshipit-source-id: b2e1fa46bab7c813aaede1f44149081ef2df566d

aten/src/ATen/core/function_schema.h
aten/src/ATen/core/function_schema_inl.h
test/cpp/jit/test_misc.cpp
test/test_function_schema.py

index 68e177a..a7b5149 100644 (file)
@@ -34,6 +34,9 @@ struct Argument {
         default_value_(std::move(default_value)),
         kwarg_only_(kwarg_only),
         alias_info_(std::move(alias_info)) {
+    // this is an softly-enforced invariant for out arguments.
+    bool is_alias = alias_info_.has_value() && alias_info_.value().isWrite();
+    is_out_ = kwarg_only_ && is_alias;
   }
   const std::string& name() const {
     return name_;
@@ -50,6 +53,11 @@ struct Argument {
   bool kwarg_only() const {
     return kwarg_only_;
   }
+
+  bool is_out() const {
+    return is_out_;
+  }
+
   const c10::optional<AliasInfo>& alias_info() const {
     return alias_info_;
   }
@@ -116,6 +124,8 @@ struct Argument {
   // is this only specifiable as a keyword argument?
   bool kwarg_only_;
   c10::optional<AliasInfo> alias_info_;
+  // marks if the argument is out variant of the schema
+  bool is_out_;
 };
 
 inline bool operator==(const Argument& lhs, const Argument& rhs) {
index 168ecb4..6e26e8c 100644 (file)
@@ -51,6 +51,16 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema)
   return out;
 }
 
+inline size_t findFirstOutArg(const std::vector<Argument>& args) {
+  // find the start of out args in the schema
+  for (size_t out_start_idx = 0; out_start_idx < args.size(); out_start_idx++) {
+    if (args.at(out_start_idx).is_out()) {
+      return out_start_idx;
+    }
+  }
+  return args.size();
+}
+
 inline bool Argument::isBackwardCompatibleWith(
       const Argument& old,
       std::ostream* why_not) const {
@@ -121,17 +131,20 @@ inline bool FunctionSchema::isBackwardCompatibleWith(
     }
   }
 
-  // Make sure that all the old arguments have their corresponding backward
-  // compatible arguments in this schema.
-  for (size_t i = 0; i < old.arguments().size(); ++i) {
+  // we want to test both out and default args seperately
+  size_t old_out_start_idx = findFirstOutArg(old.arguments());
+  size_t new_out_start_idx = findFirstOutArg(arguments());
+
+  // make sure among the default args, they are backward compatible
+  for (size_t i = 0; i < old_out_start_idx; i++) {
     if (!arguments().at(i).isBackwardCompatibleWith(
           old.arguments().at(i), why_not)) {
       return false;
     }
   }
 
-  // Validate that all new arguments provided a default value.
-  for (size_t i = old.arguments().size(); i < arguments().size(); ++i) {
+  // // Validate that all new arguments provided has a default value
+  for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
     if (!arguments().at(i).default_value()) {
       if (why_not) {
         *why_not
@@ -144,6 +157,15 @@ inline bool FunctionSchema::isBackwardCompatibleWith(
     }
   }
 
+  // now compare the out args
+  for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
+    if (!arguments()
+             .at(i - old_out_start_idx + new_out_start_idx)
+             .isBackwardCompatibleWith(old.arguments().at(i), why_not)) {
+      return false;
+    }
+  }
+
   return true;
 }
 
index 82f70fe..9f8a732 100644 (file)
@@ -520,6 +520,28 @@ TEST(SchemaParserTest, NestedArrays) {
                                               .getElementType()));
 }
 
+TEST(SchemaParserTest, OutVariant) {
+  auto schema_with_out = parseSchema(
+      "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
+  ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
+  ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());
+
+  auto schema_without_out =
+      parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");
+
+  for (const auto& arg : schema_without_out.arguments()) {
+    ASSERT_TRUE(!arg.is_out());
+  }
+
+  auto schema_with_is_write = parseSchema(
+      "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");
+
+  for (const auto& arg : schema_with_is_write.arguments()) {
+    ASSERT_TRUE(!arg.is_out());
+  }
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
 TEST(SchemaParserTest, NamedReturns) {
   // named returns
   parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
index 0451deb..7c7a0f7 100644 (file)
@@ -86,6 +86,27 @@ class TestFunctionSchema(TestCase):
         new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor')
         self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
 
+    def test_backward_compatible_with_smart_serialization(self):
+        # cases where out arg is provided
+        old_schema = parse_schema('foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)')
+        new_schema_same_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)')
+        new_schema_wrong_default = parse_schema('foo(Tensor self, *, int b=1, int a, Tensor(a!) out) -> Tensor(a!)')
+        new_schema_more_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)')
+        new_schema_wrong_pos = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(b!) b, Tensor(a!) out) -> Tensor(a!)')
+        self.assertTrue(new_schema_same_out.is_backward_compatible_with(old_schema))
+        self.assertTrue(new_schema_more_out.is_backward_compatible_with(old_schema))
+        self.assertFalse(new_schema_wrong_default.is_backward_compatible_with(old_schema))
+        self.assertFalse(new_schema_wrong_pos.is_backward_compatible_with(old_schema))
+
+        # cases where out arg is not provided
+        old_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1) -> int')
+        new_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1, int c=2) -> int')
+        new_schema_without_arg_multiple_default = parse_schema('foo(Tensor self, int a, int b=1, int c=2, int d=3) -> int')
+        new_schema_without_arg_wrong_pos = parse_schema('foo(Tensor self, int a, int c=2, int b=1) -> int')
+        self.assertTrue(new_schema_without_arg.is_backward_compatible_with(old_schema_without_arg))
+        self.assertTrue(new_schema_without_arg_multiple_default.is_backward_compatible_with(old_schema_without_arg))
+        self.assertFalse(new_schema_without_arg_wrong_pos.is_backward_compatible_with(old_schema_without_arg))
+
     def test_string_optional_parameter_default_value(self):
         schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)")
         schema_b = parse_schema(str(schema_a))