From 19c1b45f25af32fc6f6d6da315f0055ab7e30222 Mon Sep 17 00:00:00 2001 From: "Tugsbayasgalan (Tugsuu) Manlaibaatar" Date: Fri, 27 Aug 2021 11:18:52 -0700 Subject: [PATCH] Detect out argument in the schema (#62755) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62755 After this change, out argument can be checked by calling is_out() Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D30415256 Pulled By: tugsbayasgalan fbshipit-source-id: b2e1fa46bab7c813aaede1f44149081ef2df566d --- aten/src/ATen/core/function_schema.h | 10 ++++++++++ aten/src/ATen/core/function_schema_inl.h | 32 +++++++++++++++++++++++++++----- test/cpp/jit/test_misc.cpp | 22 ++++++++++++++++++++++ test/test_function_schema.py | 21 +++++++++++++++++++++ 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 68e177a..a7b5149 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -34,6 +34,9 @@ struct Argument { default_value_(std::move(default_value)), kwarg_only_(kwarg_only), alias_info_(std::move(alias_info)) { + // this is an softly-enforced invariant for out arguments. + bool is_alias = alias_info_.has_value() && alias_info_.value().isWrite(); + is_out_ = kwarg_only_ && is_alias; } const std::string& name() const { return name_; @@ -50,6 +53,11 @@ struct Argument { bool kwarg_only() const { return kwarg_only_; } + + bool is_out() const { + return is_out_; + } + const c10::optional& alias_info() const { return alias_info_; } @@ -116,6 +124,8 @@ struct Argument { // is this only specifiable as a keyword argument? bool kwarg_only_; c10::optional alias_info_; + // marks if the argument is out variant of the schema + bool is_out_; }; inline bool operator==(const Argument& lhs, const Argument& rhs) { diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index 168ecb4..6e26e8c 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -51,6 +51,16 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) return out; } +inline size_t findFirstOutArg(const std::vector& args) { + // find the start of out args in the schema + for (size_t out_start_idx = 0; out_start_idx < args.size(); out_start_idx++) { + if (args.at(out_start_idx).is_out()) { + return out_start_idx; + } + } + return args.size(); +} + inline bool Argument::isBackwardCompatibleWith( const Argument& old, std::ostream* why_not) const { @@ -121,17 +131,20 @@ inline bool FunctionSchema::isBackwardCompatibleWith( } } - // Make sure that all the old arguments have their corresponding backward - // compatible arguments in this schema. - for (size_t i = 0; i < old.arguments().size(); ++i) { + // we want to test both out and default args seperately + size_t old_out_start_idx = findFirstOutArg(old.arguments()); + size_t new_out_start_idx = findFirstOutArg(arguments()); + + // make sure among the default args, they are backward compatible + for (size_t i = 0; i < old_out_start_idx; i++) { if (!arguments().at(i).isBackwardCompatibleWith( old.arguments().at(i), why_not)) { return false; } } - // Validate that all new arguments provided a default value. - for (size_t i = old.arguments().size(); i < arguments().size(); ++i) { + // // Validate that all new arguments provided has a default value + for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) { if (!arguments().at(i).default_value()) { if (why_not) { *why_not @@ -144,6 +157,15 @@ inline bool FunctionSchema::isBackwardCompatibleWith( } } + // now compare the out args + for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) { + if (!arguments() + .at(i - old_out_start_idx + new_out_start_idx) + .isBackwardCompatibleWith(old.arguments().at(i), why_not)) { + return false; + } + } + return true; } diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 82f70fe..9f8a732 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -520,6 +520,28 @@ TEST(SchemaParserTest, NestedArrays) { .getElementType())); } +TEST(SchemaParserTest, OutVariant) { + auto schema_with_out = parseSchema( + "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)"); + ASSERT_TRUE(schema_with_out.arguments().at(1).is_out()); + ASSERT_TRUE(schema_with_out.arguments().at(2).is_out()); + + auto schema_without_out = + parseSchema("at::foo(Tensor self, *, int scalar) -> (int)"); + + for (const auto& arg : schema_without_out.arguments()) { + ASSERT_TRUE(!arg.is_out()); + } + + auto schema_with_is_write = parseSchema( + "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))"); + + for (const auto& arg : schema_with_is_write.arguments()) { + ASSERT_TRUE(!arg.is_out()); + } +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) TEST(SchemaParserTest, NamedReturns) { // named returns parseSchema("at::what(Tensor! i_will_be_written_to) -> ()"); diff --git a/test/test_function_schema.py b/test/test_function_schema.py index 0451deb..7c7a0f7 100644 --- a/test/test_function_schema.py +++ b/test/test_function_schema.py @@ -86,6 +86,27 @@ class TestFunctionSchema(TestCase): new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor') self.assertFalse(new_schema.is_backward_compatible_with(old_schema)) + def test_backward_compatible_with_smart_serialization(self): + # cases where out arg is provided + old_schema = parse_schema('foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)') + new_schema_same_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)') + new_schema_wrong_default = parse_schema('foo(Tensor self, *, int b=1, int a, Tensor(a!) out) -> Tensor(a!)') + new_schema_more_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)') + new_schema_wrong_pos = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(b!) b, Tensor(a!) out) -> Tensor(a!)') + self.assertTrue(new_schema_same_out.is_backward_compatible_with(old_schema)) + self.assertTrue(new_schema_more_out.is_backward_compatible_with(old_schema)) + self.assertFalse(new_schema_wrong_default.is_backward_compatible_with(old_schema)) + self.assertFalse(new_schema_wrong_pos.is_backward_compatible_with(old_schema)) + + # cases where out arg is not provided + old_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1) -> int') + new_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1, int c=2) -> int') + new_schema_without_arg_multiple_default = parse_schema('foo(Tensor self, int a, int b=1, int c=2, int d=3) -> int') + new_schema_without_arg_wrong_pos = parse_schema('foo(Tensor self, int a, int c=2, int b=1) -> int') + self.assertTrue(new_schema_without_arg.is_backward_compatible_with(old_schema_without_arg)) + self.assertTrue(new_schema_without_arg_multiple_default.is_backward_compatible_with(old_schema_without_arg)) + self.assertFalse(new_schema_without_arg_wrong_pos.is_backward_compatible_with(old_schema_without_arg)) + def test_string_optional_parameter_default_value(self): schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)") schema_b = parse_schema(str(schema_a)) -- 2.7.4