further wildcard cleanups (#16041)
authorMichael Suo <suo@fb.com>
Thu, 17 Jan 2019 22:38:42 +0000 (14:38 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 22:54:34 +0000 (14:54 -0800)
Summary:
Some cleanup to wildcard handling, including one bugfix: previously, we were not considering writes to the wildcard set as part of the potential write set for nodes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16041

Differential Revision: D13705738

Pulled By: suo

fbshipit-source-id: acb8ccbaa70fe47445577ddf24a69f84630de411

test/test_jit.py
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h

index 2afad2a..4f549d6 100644 (file)
@@ -9440,6 +9440,17 @@ a")
 
         self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
 
+    def test_view_write(self):
+        def fn(x, y):
+            l = []
+            l.append(x)
+            x_view = l[0]
+            a = x + x
+            x_view.add_(y)
+            b = x + x
+            return a == b
+        self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
+
 
 class MnistNet(nn.Module):
     def __init__(self):
index 56d7fce..c98929c 100644 (file)
@@ -36,20 +36,6 @@ AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
       aliasToValue_[aliasSet].insert(value);
     }
   }
-  // - Set of all nodes with a wildcard
-  buildWildcardIndex(graph_->block());
-}
-
-void AliasDb::buildWildcardIndex(const Block* b) {
-  for (const auto node : b->nodes()) {
-    for (const auto block : node->blocks()) {
-      buildWildcardIndex(block);
-    }
-
-    if (hasWildcardImpl(node)) {
-      wildcardNodes_.insert(node);
-    }
-  }
 }
 
 bool AliasDb::hasWildcard(const Node* n) const {
@@ -111,7 +97,7 @@ bool AliasDb::hasWritersBefore(const Node* n) const {
   }
   const auto writers = getWriters(n);
   return std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
-    return writer->isBefore(n);
+    return isBeforeSameGraph(writer, n);
   });
 }
 
@@ -185,6 +171,15 @@ std::unordered_set<Node*> AliasDb::getWriters(const Node* n) const {
       }
     }
   }
+
+  // A write to the wildcard set should be considered a write to `n`
+  if (aliasToWrites_.count(AliasInfo::wildcardSet())) {
+    const auto& wildcardWriters = aliasToWrites_.at(AliasInfo::wildcardSet());
+    for (auto writer : wildcardWriters) {
+      writers.insert(writer);
+    }
+  }
+
   return writers;
 }
 
@@ -447,6 +442,10 @@ void AliasDb::analyze(Node* node) {
 
     addAlias(actual, outputAlias);
   }
+  // Keep the wildcard index up to date.
+  if (hasWildcardImpl(node)) {
+    wildcardNodes_.insert(node);
+  }
 }
 
 void AliasDb::analyzeIf(Node* node) {
@@ -508,7 +507,11 @@ void AliasDb::analyzeLoop(Node* node) {
 }
 
 void AliasDb::analyzeSubgraph(Node* node) {
-  const auto subgraphBlock = node->g(attr::Subgraph)->block();
+  const auto subgraph = node->g(attr::Subgraph).get();
+
+  subgraphToOwner_.insert({subgraph, node});
+
+  const auto subgraphBlock = subgraph->block();
   mapAliases(subgraphBlock->inputs(), node->inputs());
 
   analyze(subgraphBlock);
@@ -789,6 +792,7 @@ class AliasDb::WorkingSet {
   // outside), then return nullptr. Since we can only reorder nodes within a
   // block, `target` would be irrelevant.
   static Node* findSameBlock(Node* target, Node* n) {
+    JIT_ASSERT(target->owningGraph() == n->owningGraph());
     if (target->owningBlock() == n->owningBlock()) {
       return target;
     } else {
@@ -927,20 +931,53 @@ void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
   }
 }
 
+c10::optional<const Node*> AliasDb::getLastWildcard() const {
+  auto it = std::max_element(
+      wildcardNodes_.cbegin(),
+      wildcardNodes_.cend(),
+      [this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); });
+  if (it != wildcardNodes_.end()) {
+    return *it;
+  } else {
+    return c10::nullopt;
+  }
+}
+
 bool AliasDb::hasUntrackedEffects(Node* node) const {
   bool touchesWildcard = false;
-  if (!wildcardNodes_.empty()) {
-    auto lastWildcard = *wildcardNodes_.begin();
-    for (const auto wildcard : wildcardNodes_) {
-      if (wildcard->isAfter(lastWildcard)) {
-        lastWildcard = wildcard;
-      }
-    }
+  if (const auto lastWildcard = getLastWildcard()) {
     touchesWildcard = hasWrites(node) &&
-        (node->isBefore(lastWildcard) || node == lastWildcard);
+        (isBeforeSameGraph(node, *lastWildcard) || node == *lastWildcard);
   }
 
   return writesToInputAlias(node) || touchesWildcard;
 }
+
+// Nodes must be in the same graph in order to do `isBefore` or `isAfter`. This
+// traverses the subgraph "chain" upward until we find two nodes that share an
+// owning graph.
+//
+// NOTE: this is n^2 in subgraph depth. Right now the maximum depth is like 2,
+// but if we ever do huge nested subgraphs we'll need to reconsider this.
+bool AliasDb::isBeforeSameGraph(const Node* a, const Node* b) const {
+  auto lhs = a;
+  while (true) {
+    auto rhs = b;
+    while (true) {
+      if (lhs->owningGraph() == rhs->owningGraph()) {
+        return lhs->isBefore(rhs);
+      }
+      if (!subgraphToOwner_.count(rhs->owningGraph())) {
+        break;
+      }
+      rhs = subgraphToOwner_.at(rhs->owningGraph());
+    }
+    if (!subgraphToOwner_.count(lhs->owningGraph())) {
+      break;
+    }
+    lhs = subgraphToOwner_.at(lhs->owningGraph());
+  }
+  JIT_ASSERT(false);
+}
 } // namespace jit
 } // namespace torch
index a399a02..39357e1 100644 (file)
@@ -84,6 +84,8 @@ class AliasDb {
 
   // Does `n` use or write to any wildcard aliases?
   bool hasWildcard(const Node* n) const;
+  // Returns nullopt if there are no wildcard nodes
+  c10::optional<const Node*> getLastWildcard() const;
 
   // Does `n` write to a value that may alias one of the graph inputs?
   bool writesToInputAlias(Node* n) const;
@@ -113,6 +115,8 @@ class AliasDb {
   bool hasWildcardImpl(const Node* n) const;
   bool writesTo(Node* n, const Value* v) const;
 
+  bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;
+
   std::shared_ptr<Graph> graph_;
   Symbol latestSymbol_ = Symbol::fromQualString("alias::0");
   std::unordered_map<const Value*, AliasInfo> valueToAlias_;
@@ -120,6 +124,7 @@ class AliasDb {
   std::unordered_map<Symbol, std::unordered_set<Node*>> aliasToWrites_;
   std::unordered_set<const Node*> wildcardNodes_;
   std::unordered_set<Symbol> graphInputAliases_;
+  std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
 };
 
 inline TORCH_API AliasDb AliasAnalysis(std::shared_ptr<Graph> graph) {