allow "before" and "after" alias annotations (#17480)
authorMichael Suo <suo@fb.com>
Thu, 28 Feb 2019 19:28:16 +0000 (11:28 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Feb 2019 20:00:34 +0000 (12:00 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17480

This was always part of our "spec" but not implemented

Reviewed By: houseroad

Differential Revision: D14214301

fbshipit-source-id: 118db320b43ec099dc3e730c67d39487474c23ea

aten/src/ATen/core/alias_info.h
test/cpp/jit/test_misc.h
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/utils/check_alias_annotation.cpp
torch/csrc/jit/script/schema_type_parser.cpp

index 70d1eed..61f482e 100644 (file)
@@ -5,6 +5,16 @@
 #include <c10/util/Exception.h>
 
 namespace c10 {
+/**
+ * class AliasInfo
+ *
+ * Data structure to hold aliasing information for an `Argument`. They can be
+ * nested to represent aliasing information on contained types.
+ *
+ * There is a `beforeSet` which describes the aliasing information before the
+ * operator executes, and an `afterSet` that describes aliasing info
+ * after execution.
+ */
 class AliasInfo {
  public:
   // Symbol for the set that can alias anything
@@ -14,7 +24,7 @@ class AliasInfo {
   }
   static AliasInfo createWildcard() {
     AliasInfo ret;
-    ret.addSet(wildcardSet());
+    ret.addBeforeSet(wildcardSet());
     return ret;
   }
 
@@ -26,39 +36,31 @@ class AliasInfo {
     return isWrite_;
   }
 
-  void addSet(Symbol aliasSet) {
-    sets_.insert(aliasSet);
+  void addBeforeSet(Symbol aliasSet) {
+    beforeSets_.insert(aliasSet);
   }
 
-  const std::unordered_set<Symbol>& sets() const {
-    return sets_;
+  void addAfterSet(Symbol aliasSet) {
+    afterSets_.insert(aliasSet);
   }
 
-  Symbol set() const {
-    AT_ASSERT(sets_.size() == 1);
-    return *sets_.begin();
+  const std::unordered_set<Symbol>& beforeSets() const {
+    return beforeSets_;
   }
 
-  bool isWildcard() const {
-    return sets_.count(wildcardSet()) != 0;
+  const std::unordered_set<Symbol>& afterSets() const {
+    return afterSets_;
   }
 
-  void unionWith(const AliasInfo& other) {
-    for (const auto& alias : other.sets()) {
-      sets_.insert(alias);
-    }
+  Symbol beforeSet() const {
+    AT_ASSERT(beforeSets_.size() == 1);
+    return *beforeSets_.begin();
   }
 
-  // TODO this doesn't check any contained types yet
-  // non-strict: returns true if self.sets() == other.sets()
-  bool isSubsetOf(const AliasInfo& other) const {
-    for (const auto& alias : this->sets()) {
-      if (other.sets().count(alias) == 0) {
-        return false;
-      }
-    }
-    return true;
+  bool isWildcard() const {
+    return beforeSets_.count(wildcardSet()) != 0;
   }
+
   // the alias info for the contained types of the type
   // e.g. if this is an annotation on List[T], `sets` refers to
   // the alias sets that the list may be in
@@ -72,7 +74,8 @@ class AliasInfo {
   }
 
  private:
-  std::unordered_set<Symbol> sets_;
+  std::unordered_set<Symbol> beforeSets_;
+  std::unordered_set<Symbol> afterSets_;
   std::vector<AliasInfo> containedTypes_;
   bool isWrite_ = false;
 };
@@ -81,7 +84,7 @@ class AliasInfo {
 inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) {
   out << "(";
   bool first = true;
-  for (const auto& set : aliasInfo.sets()) {
+  for (const auto& set : aliasInfo.beforeSets()) {
     if (first) {
       first = false;
     } else {
index 6a064b9..36c22dc 100644 (file)
@@ -1474,7 +1474,7 @@ void testSchemaParser() {
     // The list itself is annotated with `a`
     const auto& aliasInfo = *s.arguments().at(0).alias_info();
     ASSERT_TRUE(
-        aliasInfo.sets() ==
+        aliasInfo.beforeSets() ==
         std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
     ASSERT_TRUE(aliasInfo.isWrite());
 
@@ -1485,7 +1485,38 @@ void testSchemaParser() {
         Symbol::fromQualString("alias::b"),
         Symbol::fromQualString("alias::c"),
     };
-    ASSERT_TRUE(containedAliasInfo.sets() == expected);
+    ASSERT_TRUE(containedAliasInfo.beforeSets() == expected);
+    ASSERT_TRUE(containedAliasInfo.afterSets() == expected);
+    ASSERT_FALSE(containedAliasInfo.isWrite());
+  }
+  {
+    const auto s = parseSchema(
+        "at::what(Tensor(b -> b|c)[](a!) list, Tensor(c) element)"
+        " -> (Tensor(b|c)[](a!))");
+
+    // The list itself is annotated with `a`
+    const auto& aliasInfo = *s.arguments().at(0).alias_info();
+    ASSERT_EQ(
+        aliasInfo.beforeSets(),
+        std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
+    ASSERT_EQ(
+        aliasInfo.afterSets(),
+        std::unordered_set<Symbol>{Symbol::fromQualString("alias::a")});
+    ASSERT_TRUE(aliasInfo.isWrite());
+    ASSERT_EQ(aliasInfo.containedTypes().size(), 1);
+
+    // Check the contained types
+    ASSERT_TRUE(!aliasInfo.containedTypes().empty());
+    const auto& containedAliasInfo = aliasInfo.containedTypes()[0];
+    const auto expectedBefore = std::unordered_set<Symbol>{
+        Symbol::fromQualString("alias::b"),
+    };
+    const auto expectedAfter = std::unordered_set<Symbol>{
+        Symbol::fromQualString("alias::b"),
+        Symbol::fromQualString("alias::c")
+    };
+    ASSERT_TRUE(containedAliasInfo.beforeSets() == expectedBefore);
+    ASSERT_TRUE(containedAliasInfo.afterSets() == expectedAfter);
     ASSERT_FALSE(containedAliasInfo.isWrite());
   }
 }
index 88680d2..850862d 100644 (file)
@@ -380,7 +380,7 @@ void AliasDb::analyzeImpl(Node* node) {
     // TODO neither unions nor wildcards make sense on an input. We should
     // disallow them in function schema
     AT_ASSERT(!formal->isWildcard())
-    const auto& formalAlias = formal->set();
+    const auto& formalAlias = formal->beforeSet();
 
     // skip if we've already bound this alias
     if (formalToActual.count(formalAlias) != 0) {
@@ -419,13 +419,13 @@ void AliasDb::analyzeImpl(Node* node) {
       continue;
     }
 
-    for (const auto& formalAlias : formal->sets()) {
+    for (const auto& formalAlias : formal->beforeSets()) {
       // If we encounter an alias annotation that wasn't in the inputs:
       if (!formalToActual.count(formalAlias)) {
         // If this alias is not seen elsewhere and is the only annotation on
         // the output, it's equivalent to being fresh:
         //   e.g. foo(Tensor(a) self) -> Tensor(b)
-        if (formal->sets().size() == 1) {
+        if (formal->beforeSets().size() == 1) {
           giveFreshAlias(actual);
         }
         // Or it is the form of a|fresh, which we can ignore, taking the
index 3c7accd..85d0fd1 100644 (file)
@@ -49,6 +49,7 @@ IValue deepCopy(const IValue& self) {
 
 Stack deepCopy(const Stack& stack) {
   Stack ret;
+  ret.reserve(stack.size());
   for (const auto& v : stack) {
     ret.push_back(deepCopy(v));
   }
@@ -104,7 +105,14 @@ void checkAliases(
         const auto inputSet = input.aliasInfo;
         const auto outputSet = output.aliasInfo;
         AT_ASSERT(inputSet && outputSet);
-        AT_ASSERT(inputSet->isSubsetOf(*outputSet));
+        bool found = false;
+        for (const auto& set : inputSet->beforeSets()) {
+          if (outputSet->beforeSets().count(set)) {
+            found = true;
+            break;
+          }
+        }
+        AT_ASSERT(found);
       }
     }
   }
index c71c0b6..19e9ccb 100644 (file)
@@ -53,16 +53,33 @@ c10::optional<AliasInfo> SchemaTypeParser::parseAliasAnnotation() {
 
         // If we found a wildcard, ignore all subsequent annotations
       } else if (!alias_info.isWildcard()) {
-        alias_info.addSet(
+        alias_info.addBeforeSet(
             Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
       }
     });
     if (L.nextIf('!')) {
       alias_info.setIsWrite(true);
     }
+    if (L.nextIf(TK_ARROW)) {
+      // optional 'alias set annotation'
+      parseList(TK_NOTHING, '|', TK_NOTHING, [&] {
+        if (L.cur().kind == '*') {
+          L.reportError("Wildcard not allowed as part of the output set");
+        }
+        alias_info.addAfterSet(
+            Symbol::fromQualString("alias::" + L.expect(TK_IDENT).text()));
+      });
+    } else {
+      // We didn't encounter an ->, so assume the "after set" is identical
+      // to the "before set"
+      AT_ASSERT(alias_info.afterSets().empty());
+      for (const auto& set : alias_info.beforeSets()) {
+        alias_info.addAfterSet(set);
+      }
+    }
     L.expect(')');
   } else if (L.nextIf('!')) {
-    alias_info.addSet(
+    alias_info.addBeforeSet(
         Symbol::fromQualString("alias::$" + std::to_string(next_id++)));
     alias_info.setIsWrite(true);
   } else {