From f9d3f1dca528863d65301c99d260ad299be500e0 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Thu, 28 Feb 2019 11:28:16 -0800 Subject: [PATCH] allow "before" and "after" alias annotations (#17480) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17480 This was always part of our "spec" but not implemented Reviewed By: houseroad Differential Revision: D14214301 fbshipit-source-id: 118db320b43ec099dc3e730c67d39487474c23ea --- aten/src/ATen/core/alias_info.h | 53 ++++++++++++---------- test/cpp/jit/test_misc.h | 35 +++++++++++++- torch/csrc/jit/passes/alias_analysis.cpp | 6 +-- .../jit/passes/utils/check_alias_annotation.cpp | 10 +++- torch/csrc/jit/script/schema_type_parser.cpp | 21 ++++++++- 5 files changed, 92 insertions(+), 33 deletions(-) diff --git a/aten/src/ATen/core/alias_info.h b/aten/src/ATen/core/alias_info.h index 70d1eed..61f482e 100644 --- a/aten/src/ATen/core/alias_info.h +++ b/aten/src/ATen/core/alias_info.h @@ -5,6 +5,16 @@ #include namespace c10 { +/** + * class AliasInfo + * + * Data structure to hold aliasing information for an `Argument`. They can be + * nested to represent aliasing information on contained types. + * + * There is a `beforeSet` which describes the aliasing information before the + * operator executes, and an `afterSet` that describes aliasing info + * after execution. + */ class AliasInfo { public: // Symbol for the set that can alias anything @@ -14,7 +24,7 @@ class AliasInfo { } static AliasInfo createWildcard() { AliasInfo ret; - ret.addSet(wildcardSet()); + ret.addBeforeSet(wildcardSet()); return ret; } @@ -26,39 +36,31 @@ class AliasInfo { return isWrite_; } - void addSet(Symbol aliasSet) { - sets_.insert(aliasSet); + void addBeforeSet(Symbol aliasSet) { + beforeSets_.insert(aliasSet); } - const std::unordered_set& sets() const { - return sets_; + void addAfterSet(Symbol aliasSet) { + afterSets_.insert(aliasSet); } - Symbol set() const { - AT_ASSERT(sets_.size() == 1); - return *sets_.begin(); + const std::unordered_set& beforeSets() const { + return beforeSets_; } - bool isWildcard() const { - return sets_.count(wildcardSet()) != 0; + const std::unordered_set& afterSets() const { + return afterSets_; } - void unionWith(const AliasInfo& other) { - for (const auto& alias : other.sets()) { - sets_.insert(alias); - } + Symbol beforeSet() const { + AT_ASSERT(beforeSets_.size() == 1); + return *beforeSets_.begin(); } - // TODO this doesn't check any contained types yet - // non-strict: returns true if self.sets() == other.sets() - bool isSubsetOf(const AliasInfo& other) const { - for (const auto& alias : this->sets()) { - if (other.sets().count(alias) == 0) { - return false; - } - } - return true; + bool isWildcard() const { + return beforeSets_.count(wildcardSet()) != 0; } + // the alias info for the contained types of the type // e.g. if this is an annotation on List[T], `sets` refers to // the alias sets that the list may be in @@ -72,7 +74,8 @@ class AliasInfo { } private: - std::unordered_set sets_; + std::unordered_set beforeSets_; + std::unordered_set afterSets_; std::vector containedTypes_; bool isWrite_ = false; }; @@ -81,7 +84,7 @@ class AliasInfo { inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { out << "("; bool first = true; - for (const auto& set : aliasInfo.sets()) { + for (const auto& set : aliasInfo.beforeSets()) { if (first) { first = false; } else { diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index 6a064b9..36c22dc 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -1474,7 +1474,7 @@ void testSchemaParser() { // The list itself is annotated with `a` const auto& aliasInfo = *s.arguments().at(0).alias_info(); ASSERT_TRUE( - aliasInfo.sets() == + aliasInfo.beforeSets() == std::unordered_set{Symbol::fromQualString("alias::a")}); ASSERT_TRUE(aliasInfo.isWrite()); @@ -1485,7 +1485,38 @@ void testSchemaParser() { Symbol::fromQualString("alias::b"), Symbol::fromQualString("alias::c"), }; - ASSERT_TRUE(containedAliasInfo.sets() == expected); + ASSERT_TRUE(containedAliasInfo.beforeSets() == expected); + ASSERT_TRUE(containedAliasInfo.afterSets() == expected); + ASSERT_FALSE(containedAliasInfo.isWrite()); + } + { + const auto s = parseSchema( + "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)" + " -> (Tensor(b|c)[](a!))"); + + // The list itself is annotated with `a` + const auto& aliasInfo = *s.arguments().at(0).alias_info(); + ASSERT_EQ( + aliasInfo.beforeSets(), + std::unordered_set{Symbol::fromQualString("alias::a")}); + ASSERT_EQ( + aliasInfo.afterSets(), + std::unordered_set{Symbol::fromQualString("alias::a")}); + ASSERT_TRUE(aliasInfo.isWrite()); + ASSERT_EQ(aliasInfo.containedTypes().size(), 1); + + // Check the contained types + ASSERT_TRUE(!aliasInfo.containedTypes().empty()); + const auto& containedAliasInfo = aliasInfo.containedTypes()[0]; + const auto expectedBefore = std::unordered_set{ + Symbol::fromQualString("alias::b"), + }; + const auto expectedAfter = std::unordered_set{ + Symbol::fromQualString("alias::b"), + Symbol::fromQualString("alias::c") + }; + ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore); + ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter); ASSERT_FALSE(containedAliasInfo.isWrite()); } } diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 88680d2..850862d 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -380,7 +380,7 @@ void AliasDb::analyzeImpl(Node* node) { // TODO neither unions nor wildcards make sense on an input. We should // disallow them in function schema AT_ASSERT(!formal->isWildcard()) - const auto& formalAlias = formal->set(); + const auto& formalAlias = formal->beforeSet(); // skip if we've already bound this alias if (formalToActual.count(formalAlias) != 0) { @@ -419,13 +419,13 @@ void AliasDb::analyzeImpl(Node* node) { continue; } - for (const auto& formalAlias : formal->sets()) { + for (const auto& formalAlias : formal->beforeSets()) { // If we encounter an alias annotation that wasn't in the inputs: if (!formalToActual.count(formalAlias)) { // If this alias is not seen elsewhere and is the only annotation on // the output, it's equivalent to being fresh: // e.g. foo(Tensor(a) self) -> Tensor(b) - if (formal->sets().size() == 1) { + if (formal->beforeSets().size() == 1) { giveFreshAlias(actual); } // Or it is the form of a|fresh, which we can ignore, taking the diff --git a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp index 3c7accd..85d0fd1 100644 --- a/torch/csrc/jit/passes/utils/check_alias_annotation.cpp +++ b/torch/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -49,6 +49,7 @@ IValue deepCopy(const IValue& self) { Stack deepCopy(const Stack& stack) { Stack ret; + ret.reserve(stack.size()); for (const auto& v : stack) { ret.push_back(deepCopy(v)); } @@ -104,7 +105,14 @@ void checkAliases( const auto inputSet = input.aliasInfo; const auto outputSet = output.aliasInfo; AT_ASSERT(inputSet && outputSet); - AT_ASSERT(inputSet->isSubsetOf(*outputSet)); + bool found = false; + for (const auto& set : inputSet->beforeSets()) { + if (outputSet->beforeSets().count(set)) { + found = true; + break; + } + } + AT_ASSERT(found); } } } diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp index c71c0b6..19e9ccb 100644 --- a/torch/csrc/jit/script/schema_type_parser.cpp +++ b/torch/csrc/jit/script/schema_type_parser.cpp @@ -53,16 +53,33 @@ c10::optional SchemaTypeParser::parseAliasAnnotation() { // If we found a wildcard, ignore all subsequent annotations } else if (!alias_info.isWildcard()) { - alias_info.addSet( + alias_info.addBeforeSet( Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text())); } }); if (L.nextIf('!')) { alias_info.setIsWrite(true); } + if (L.nextIf(TK_ARROW)) { + // optional 'alias set annotation' + parseList(TK_NOTHING, '|', TK_NOTHING, [&] { + if (L.cur().kind == '*') { + L.reportError("Wildcard not allowed as part of the output set"); + } + alias_info.addAfterSet( + Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text())); + }); + } else { + // We didn't encounter an ->, so assume the "after set" is identical + // to the "before set" + AT_ASSERT(alias_info.afterSets().empty()); + for (const auto& set : alias_info.beforeSets()) { + alias_info.addAfterSet(set); + } + } L.expect(')'); } else if (L.nextIf('!')) { - alias_info.addSet( + alias_info.addBeforeSet( Symbol::fromQualString("alias::$" + std::to_string(next_id++))); alias_info.setIsWrite(true); } else { -- 2.7.4