#include <c10/util/Exception.h>
namespace c10 {
+/**
+ * class AliasInfo
+ *
+ * Data structure to hold aliasing information for an `Argument`. They can be
+ * nested to represent aliasing information on contained types.
+ *
+ * There is a `beforeSet` which describes the aliasing information before the
+ * operator executes, and an `afterSet` that describes aliasing info
+ * after execution.
+ */
class AliasInfo {
public:
// Symbol for the set that can alias anything
}
static AliasInfo createWildcard() {
AliasInfo ret;
- ret.addSet(wildcardSet());
+ ret.addBeforeSet(wildcardSet());
return ret;
}
return isWrite_;
}
- void addSet(Symbol aliasSet) {
- sets_.insert(aliasSet);
+ void addBeforeSet(Symbol aliasSet) {
+ beforeSets_.insert(aliasSet);
}
- const std::unordered_set<Symbol>& sets() const {
- return sets_;
+ void addAfterSet(Symbol aliasSet) {
+ afterSets_.insert(aliasSet);
}
- Symbol set() const {
- AT_ASSERT(sets_.size() == 1);
- return *sets_.begin();
+ const std::unordered_set<Symbol>& beforeSets() const {
+ return beforeSets_;
}
- bool isWildcard() const {
- return sets_.count(wildcardSet()) != 0;
+ const std::unordered_set<Symbol>& afterSets() const {
+ return afterSets_;
}
- void unionWith(const AliasInfo& other) {
- for (const auto& alias : other.sets()) {
- sets_.insert(alias);
- }
+ Symbol beforeSet() const {
+ AT_ASSERT(beforeSets_.size() == 1);
+ return *beforeSets_.begin();
}
- // TODO this doesn't check any contained types yet
- // non-strict: returns true if self.sets() == other.sets()
- bool isSubsetOf(const AliasInfo& other) const {
- for (const auto& alias : this->sets()) {
- if (other.sets().count(alias) == 0) {
- return false;
- }
- }
- return true;
+ bool isWildcard() const {
+ return beforeSets_.count(wildcardSet()) != 0;
}
+
// the alias info for the contained types of the type
// e.g. if this is an annotation on List[T], `sets` refers to
// the alias sets that the list may be in
}
private:
- std::unordered_set<Symbol> sets_;
+ std::unordered_set<Symbol> beforeSets_;
+ std::unordered_set<Symbol> afterSets_;
std::vector<AliasInfo> containedTypes_;
bool isWrite_ = false;
};
inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
out << "(";
bool first = true;
- for (const auto& set : aliasInfo.sets()) {
+ for (const auto& set : aliasInfo.beforeSets()) {
if (first) {
first = false;
} else {
// The list itself is annotated with `a`
const auto& aliasInfo = *s.arguments().at(0).alias_info();
ASSERT_TRUE(
- aliasInfo.sets() ==
+ aliasInfo.beforeSets() ==
std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
ASSERT_TRUE(aliasInfo.isWrite());
Symbol::fromQualString("alias::b"),
Symbol::fromQualString("alias::c"),
};
- ASSERT_TRUE(containedAliasInfo.sets() == expected);
+ ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
+ ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
+ ASSERT_FALSE(containedAliasInfo.isWrite());
+ }
+ {
+ const auto s = parseSchema(
+ "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
+ " -> (Tensor(b|c)[](a!))");
+
+ // The list itself is annotated with `a`
+ const auto& aliasInfo = *s.arguments().at(0).alias_info();
+ ASSERT_EQ(
+ aliasInfo.beforeSets(),
+ std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
+ ASSERT_EQ(
+ aliasInfo.afterSets(),
+ std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
+ ASSERT_TRUE(aliasInfo.isWrite());
+ ASSERT_EQ(aliasInfo.containedTypes().size(), 1);
+
+ // Check the contained types
+ ASSERT_TRUE(!aliasInfo.containedTypes().empty());
+ const auto& containedAliasInfo = aliasInfo.containedTypes()[0];
+ const auto expectedBefore = std::unordered_set<Symbol>{
+ Symbol::fromQualString("alias::b"),
+ };
+ const auto expectedAfter = std::unordered_set<Symbol>{
+ Symbol::fromQualString("alias::b"),
+ Symbol::fromQualString("alias::c")
+ };
+ ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
+ ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
ASSERT_FALSE(containedAliasInfo.isWrite());
}
}
// TODO neither unions nor wildcards make sense on an input. We should
// disallow them in function schema
AT_ASSERT(!formal->isWildcard())
- const auto& formalAlias = formal->set();
+ const auto& formalAlias = formal->beforeSet();
// skip if we've already bound this alias
if (formalToActual.count(formalAlias) != 0) {
continue;
}
- for (const auto& formalAlias : formal->sets()) {
+ for (const auto& formalAlias : formal->beforeSets()) {
// If we encounter an alias annotation that wasn't in the inputs:
if (!formalToActual.count(formalAlias)) {
// If this alias is not seen elsewhere and is the only annotation on
// the output, it's equivalent to being fresh:
// e.g. foo(Tensor(a) self) -> Tensor(b)
- if (formal->sets().size() == 1) {
+ if (formal->beforeSets().size() == 1) {
giveFreshAlias(actual);
}
// Or it is the form of a|fresh, which we can ignore, taking the
Stack deepCopy(const Stack& stack) {
Stack ret;
+ ret.reserve(stack.size());
for (const auto& v : stack) {
ret.push_back(deepCopy(v));
}
const auto inputSet = input.aliasInfo;
const auto outputSet = output.aliasInfo;
AT_ASSERT(inputSet && outputSet);
- AT_ASSERT(inputSet->isSubsetOf(*outputSet));
+ bool found = false;
+ for (const auto& set : inputSet->beforeSets()) {
+ if (outputSet->beforeSets().count(set)) {
+ found = true;
+ break;
+ }
+ }
+ AT_ASSERT(found);
}
}
}
// If we found a wildcard, ignore all subsequent annotations
} else if (!alias_info.isWildcard()) {
- alias_info.addSet(
+ alias_info.addBeforeSet(
Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
}
});
if (L.nextIf('!')) {
alias_info.setIsWrite(true);
}
+ if (L.nextIf(TK_ARROW)) {
+ // optional 'alias set annotation'
+ parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
+ if (L.cur().kind == '*') {
+ L.reportError("Wildcard not allowed as part of the output set");
+ }
+ alias_info.addAfterSet(
+ Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
+ });
+ } else {
+ // We didn't encounter an ->, so assume the "after set" is identical
+ // to the "before set"
+ AT_ASSERT(alias_info.afterSets().empty());
+ for (const auto& set : alias_info.beforeSets()) {
+ alias_info.addAfterSet(set);
+ }
+ }
L.expect(')');
} else if (L.nextIf('!')) {
- alias_info.addSet(
+ alias_info.addBeforeSet(
Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
alias_info.setIsWrite(true);
} else {