default_value_(std::move(default_value)),
kwarg_only_(kwarg_only),
alias_info_(std::move(alias_info)) {
+  // this is a softly-enforced invariant for out arguments.
+ bool is_alias = alias_info_.has_value() && alias_info_.value().isWrite();
+ is_out_ = kwarg_only_ && is_alias;
}
const std::string& name() const {
return name_;
bool kwarg_only() const {
return kwarg_only_;
}
+
+ bool is_out() const {
+ return is_out_;
+ }
+
const c10::optional<AliasInfo>& alias_info() const {
return alias_info_;
}
// is this only specifiable as a keyword argument?
bool kwarg_only_;
c10::optional<AliasInfo> alias_info_;
+  // marks whether the argument is an out argument of the schema
+ bool is_out_;
};
inline bool operator==(const Argument& lhs, const Argument& rhs) {
return out;
}
+// Returns the index of the first argument for which is_out() is true,
+// or args.size() when the schema has no out arguments.
+// NOTE(review): callers appear to assume out args form a contiguous
+// trailing group of the argument list — confirm against the schema parser.
+inline size_t findFirstOutArg(const std::vector<Argument>& args) {
+  // find the start of out args in the schema
+  for (size_t out_start_idx = 0; out_start_idx < args.size(); out_start_idx++) {
+    if (args.at(out_start_idx).is_out()) {
+      return out_start_idx;
+    }
+  }
+  // no out args present
+  return args.size();
+}
+
inline bool Argument::isBackwardCompatibleWith(
const Argument& old,
std::ostream* why_not) const {
}
}
- // Make sure that all the old arguments have their corresponding backward
- // compatible arguments in this schema.
- for (size_t i = 0; i < old.arguments().size(); ++i) {
+  // we want to test both out and default args separately
+ size_t old_out_start_idx = findFirstOutArg(old.arguments());
+ size_t new_out_start_idx = findFirstOutArg(arguments());
+
+ // make sure among the default args, they are backward compatible
+ for (size_t i = 0; i < old_out_start_idx; i++) {
if (!arguments().at(i).isBackwardCompatibleWith(
old.arguments().at(i), why_not)) {
return false;
}
}
- // Validate that all new arguments provided a default value.
- for (size_t i = old.arguments().size(); i < arguments().size(); ++i) {
+  // Validate that all new arguments provide a default value
+ for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
if (!arguments().at(i).default_value()) {
if (why_not) {
*why_not
}
}
+ // now compare the out args
+ for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
+ if (!arguments()
+ .at(i - old_out_start_idx + new_out_start_idx)
+ .isBackwardCompatibleWith(old.arguments().at(i), why_not)) {
+ return false;
+ }
+ }
+
return true;
}
.getElementType()));
}
+// Verifies that is_out() is set only for true out arguments — args that are
+// both kwarg-only (after '*') and write-aliased (e.g. Tensor(a!)) — and not
+// for plain args or for write-aliased positional args like an in-place self.
+TEST(SchemaParserTest, OutVariant) {
+  // kwarg-only + write alias annotation => is_out() is true
+  auto schema_with_out = parseSchema(
+      "at::foo(Tensor self, *, Tensor(a!) f, Tensor(b!) l) -> (Tensor(a!) f, Tensor(b!) l)");
+  ASSERT_TRUE(schema_with_out.arguments().at(1).is_out());
+  ASSERT_TRUE(schema_with_out.arguments().at(2).is_out());
+
+  // no alias annotations at all => no arg is an out arg
+  auto schema_without_out =
+      parseSchema("at::foo(Tensor self, *, int scalar) -> (int)");
+
+  for (const auto& arg : schema_without_out.arguments()) {
+    ASSERT_TRUE(!arg.is_out());
+  }
+
+  // write alias on a positional arg (in-place self) is NOT an out arg
+  auto schema_with_is_write = parseSchema(
+      "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))");
+
+  for (const auto& arg : schema_with_is_write.arguments()) {
+    ASSERT_TRUE(!arg.is_out());
+  }
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(SchemaParserTest, NamedReturns) {
// named returns
parseSchema("at::what(Tensor! i_will_be_written_to) -> ()");
new_schema = parse_schema('any(Tensor self, *, Tensor b, int[] c) -> Tensor')
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
+    def test_backward_compatible_with_smart_serialization(self):
+        """BC check treats default args and out args as separate groups."""
+        # cases where out arg is provided
+        old_schema = parse_schema('foo(Tensor self, *, int a, Tensor(a!) out) -> Tensor(a!)')
+        new_schema_same_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out) -> Tensor(a!)')
+        new_schema_wrong_default = parse_schema('foo(Tensor self, *, int b=1, int a, Tensor(a!) out) -> Tensor(a!)')
+        new_schema_more_out = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(a!) out, Tensor(b!) b) -> Tensor(a!)')
+        new_schema_wrong_pos = parse_schema('foo(Tensor self, *, int a, int b=1, Tensor(b!) b, Tensor(a!) out) -> Tensor(a!)')
+        # adding defaulted args before the out args, or appending extra out
+        # args at the end, keeps backward compatibility
+        self.assertTrue(new_schema_same_out.is_backward_compatible_with(old_schema))
+        self.assertTrue(new_schema_more_out.is_backward_compatible_with(old_schema))
+        # a new arg without a default, or moving an existing arg/out arg to a
+        # different position, breaks backward compatibility
+        self.assertFalse(new_schema_wrong_default.is_backward_compatible_with(old_schema))
+        self.assertFalse(new_schema_wrong_pos.is_backward_compatible_with(old_schema))
+
+        # cases where out arg is not provided
+        old_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1) -> int')
+        new_schema_without_arg = parse_schema('foo(Tensor self, int a, int b=1, int c=2) -> int')
+        new_schema_without_arg_multiple_default = parse_schema('foo(Tensor self, int a, int b=1, int c=2, int d=3) -> int')
+        new_schema_without_arg_wrong_pos = parse_schema('foo(Tensor self, int a, int c=2, int b=1) -> int')
+        # trailing defaulted args may be added freely; reordering breaks BC
+        self.assertTrue(new_schema_without_arg.is_backward_compatible_with(old_schema_without_arg))
+        self.assertTrue(new_schema_without_arg_multiple_default.is_backward_compatible_with(old_schema_without_arg))
+        self.assertFalse(new_schema_without_arg_wrong_pos.is_backward_compatible_with(old_schema_without_arg))
+
def test_string_optional_parameter_default_value(self):
schema_a = parse_schema("example::op(str? order=\"NCHW\") -> (Tensor)")
schema_b = parse_schema(str(schema_a))