AliasDB interface cleanup (#15656)
authorMichael Suo <suo@fb.com>
Sat, 12 Jan 2019 04:04:14 +0000 (20:04 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 12 Jan 2019 04:06:53 +0000 (20:06 -0800)
Summary:
This is the first of several PRs to simplify AliasDb usage.
- Hide the concept wildcards from users. They are too hard to think about and too easy to forget about.
- Start moving "mutability-safe" graph mutation methods into AliasDb (right now, the various methods that deal with topological move).

Eventually I want to create a "mutability-aware" handle to the graph. If you only use that handle to transform the graph, you can be sure that all transformations are safe with respect to mutability.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15656

Differential Revision: D13615492

Pulled By: suo

fbshipit-source-id: 5c39a157b4ea76f1f976315d06a314a89cc4f22f

12 files changed:
test/cpp/jit/tests.h
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h
torch/csrc/jit/passes/batch_mm.cpp
torch/csrc/jit/passes/common_subexpression_elimination.cpp
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/graph_fuser.cpp
torch/csrc/jit/passes/shape_analysis.cpp

index 4ae3eb4..c751f55 100644 (file)
@@ -1840,7 +1840,7 @@ struct TopoMoveTestFixture {
       const std::string& insertPoint) {
     std::function<bool(Node*, Node*)> func =
         [this](Node* toInsert, Node* insertPoint) {
-          return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb);
+          return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
         };
     return moveWithChecks(toInsert, insertPoint, func);
   }
@@ -1850,7 +1850,7 @@ struct TopoMoveTestFixture {
       const std::string& insertPoint) {
     std::function<bool(Node*, Node*)> func =
         [this](Node* toInsert, Node* insertPoint) {
-          return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb);
+          return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
         };
     return moveWithChecks(toInsert, insertPoint, func);
   }
@@ -2035,15 +2035,15 @@ void testAliasAnalysis() {
 
     graph->lint();
 
-    const auto aliasDb = AliasAnalysis(graph);
+    auto aliasDb = AliasAnalysis(graph);
     // Can't move past a mutation of a used value
-    JIT_ASSERT(!c->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb));
-    JIT_ASSERT(d->node()->moveAfterTopologicallyValid(c->node(), aliasDb));
+    JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node()));
+    JIT_ASSERT(aliasDb.moveAfterTopologicallyValid(d->node(), c->node()));
 
     // b should alias to a (since they are both inputs)
     JIT_ASSERT(
-        !addsB->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb));
-    JIT_ASSERT(addsB->node()->moveAfterTopologicallyValid(c->node(), aliasDb));
+        !aliasDb.moveAfterTopologicallyValid(addsB->node(), aMut->node()));
+    JIT_ASSERT(aliasDb.moveAfterTopologicallyValid(addsB->node(), c->node()));
 
     graph->lint();
   }
@@ -2060,12 +2060,11 @@ void testAliasAnalysis() {
     auto c = graph->insert(aten::add, {fresh, aliasesB});
     graph->lint();
 
-    const auto aliasDb = AliasAnalysis(graph);
-
-    JIT_ASSERT(!aliasesB->node()->moveAfterTopologicallyValid(
-        mutatesAliasOfB->node(), aliasDb));
-    JIT_ASSERT(!usesB->node()->moveAfterTopologicallyValid(
-        mutatesAliasOfB->node(), aliasDb));
+    auto aliasDb = AliasAnalysis(graph);
+    JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
+        aliasesB->node(), mutatesAliasOfB->node()));
+    JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
+        usesB->node(), mutatesAliasOfB->node()));
   }
 }
 } // namespace
index 17d640d..28ee32d 100644 (file)
@@ -969,306 +969,6 @@ Node* Node::insertAfter(Node* n) {
   return this;
 }
 
-bool Node::moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb) {
-  return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/false);
-}
-
-bool Node::couldMoveAfterTopologically(Node* n, const AliasDb& aliasDb) {
-  return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/true);
-}
-
-bool Node::moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb) {
-  // We have to distinguish the move side (instead of just moving after
-  // n->prev()). Consider the following example:
-  //   If the dependency graph looks like this -> n -> o then moveBefore(o) will
-  //   end up with [this, o, n], but moveAfter(n) will return false.
-  return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/false);
-}
-
-bool Node::couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb) {
-  return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/true);
-}
-
-// Helper for topologically-safe node moves. See `tryMove()` for details.
-namespace {
-struct WorkingSet {
- public:
-  explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
-    add(mover);
-  }
-
-  // Add `n` to the working set
-  void add(Node* n) {
-    nodes_.push_back(n);
-    for (const auto user : getUsersSameBlock(n)) {
-      users_[user]++;
-    }
-
-    for (const auto& writer : getWritersSameBlock(n)) {
-      writers_[writer]++;
-    }
-    if (aliasDb_.hasWildcard(n)) {
-      numWildcards_++;
-    }
-    if (aliasDb_.hasWrites(n)) {
-      numWriterNodes_++;
-    }
-  }
-
-  void eraseMover() {
-    auto mover = nodes_.front();
-    for (const auto user : getUsersSameBlock(mover)) {
-      // If this user node only uses the mover, we can remove it
-      if (users_[user] == 1) {
-        users_.erase(user);
-      }
-    }
-
-    for (const auto& writer : getWritersSameBlock(mover)) {
-      if (writers_[writer] == 1) {
-        writers_.erase(writer);
-      }
-    }
-    if (aliasDb_.hasWildcard(mover)) {
-      numWildcards_--;
-    }
-    if (aliasDb_.hasWrites(mover)) {
-      numWriterNodes_--;
-    }
-    nodes_.pop_front();
-  }
-
-  const std::list<Node*>& nodes() {
-    return nodes_;
-  }
-
-  // Does the working set depend on `n`?
-  bool dependsOn(Node* n) const {
-    if (nodes_.empty()) {
-      return false;
-    }
-
-    return hasDataDependency(n) || hasMutabilityDependency(n);
-  }
-
- private:
-  bool hasDataDependency(Node* n) const {
-    if (n->isAfter(nodes_.front())) {
-      return producesFor(n);
-    } else {
-      return consumesFrom(n);
-    }
-  }
-
-  bool hasMutabilityDependency(Node* n) const {
-    // 1. Handle wildcard dependencies:
-    // If the working set has a wildcard, `n` can't write to anything.
-    if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) {
-      return true;
-    }
-
-    // If `n` has a wildcard, the working set can't write to anything.
-    if (aliasDb_.hasWildcard(n) && numWriterNodes_ > 0) {
-      return true;
-    }
-
-    // 2. Handle regular mutable dependencies
-    // Check that this node does not write to anything used by the working set
-    if (writers_.count(n) != 0) {
-      return true;
-    }
-
-    // Check that the working set does not write to anything used by this node
-    const auto writersToNode = getWritersSameBlock(n);
-    return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
-      return writersToNode.count(node) != 0;
-    });
-  }
-
-  // Does the working set produce any values consumed by `n`?
-  bool producesFor(Node* n) const {
-    // This equivalent to asking: does the total use-set of all the nodes in the
-    // working set include `n`?
-    return users_.count(n) != 0;
-  }
-
-  // Does the working set consume any values produced by `n`?
-  bool consumesFrom(Node* n) const {
-    const auto users = getUsersSameBlock(n);
-    return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
-      return users.count(node) != 0;
-    });
-  }
-
-  // Get all users of outputs of `n`, in the same block as `n`.
-  // This means if there is an `if` node that uses an output of `n` in some
-  // inner sub-block, we will consider the whole `if` node a user of `n`.
-  std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
-    std::unordered_set<Node*> users;
-    for (const auto output : n->outputs()) {
-      for (const auto& use : output->uses()) {
-        if (auto sameBlock = findSameBlock(use.user, n)) {
-          users.insert(sameBlock);
-        }
-      }
-    }
-    return users;
-  }
-
-  std::unordered_set<Node*> getWritersSameBlock(Node* n) const {
-    std::unordered_set<Node*> writers;
-    for (const auto writer : aliasDb_.getWriters(n)) {
-      if (auto sameBlock = findSameBlock(writer, n)) {
-        writers.insert(sameBlock);
-      }
-    }
-    return writers;
-  }
-
-  // Traverse `target`'s blockchain upward until we find a node that shares a
-  // block with `n`.
-  //
-  // If one can't be found (say, because `n` is an inner block and target is
-  // outside), then return nullptr. Since we can only reorder nodes within a
-  // block, `target` would be irrelevant.
-  static Node* findSameBlock(Node* target, Node* n) {
-    if (target->owningBlock() == n->owningBlock()) {
-      return target;
-    } else {
-      // This user is in a sub-block. Traverse the blockchain upward until
-      // we arrive at a node that shares a block with `this`
-      auto curNode = target;
-      while (curNode->owningBlock() != n->owningBlock()) {
-        curNode = curNode->owningBlock()->owningNode();
-        if (curNode == nullptr) {
-          return curNode;
-        }
-      }
-      return curNode;
-    }
-  }
-
-  const AliasDb& aliasDb_;
-  std::list<Node*> nodes_;
-  // users => # of working set nodes it uses
-  std::unordered_map<Node*, size_t> users_;
-  std::unordered_map<Node*, size_t> writers_;
-  size_t numWildcards_ = 0;
-  size_t numWriterNodes_ = 0;
-};
-} // namespace
-
-// Try to move `this` before/after `movePoint` while preserving value
-// dependencies. Returns false iff such a move could not be made
-//
-// The basic approach is: have a "working set" that we are moving forward, one
-// node at a time. When we can't move past a node (because it depends on the
-// working set), then add it to the working set and keep moving until we hit
-// `moveAfter`.
-bool Node::tryMove(
-    Node* movePoint,
-    MoveSide moveSide,
-    const AliasDb& aliasDb,
-    bool dryRun) {
-  JIT_ASSERT(this->inBlockList() && movePoint->inBlockList());
-  JIT_ASSERT(this->owningBlock() == movePoint->owningBlock());
-  if (this == movePoint) {
-    return true;
-  }
-
-  // 1. Move from `this` toward movePoint, building up the working set of
-  // dependencies
-  WorkingSet workingSet(this, aliasDb);
-
-  int direction;
-  if (this->isAfter(movePoint)) {
-    direction = kPrevDirection;
-  } else {
-    direction = kNextDirection;
-  }
-
-  auto curNode = this->next_in_graph[direction];
-  // Move forward one node at a time
-  while (curNode != movePoint) {
-    if (workingSet.dependsOn(curNode)) {
-      // If we can't move past this node, add it to the working set
-      workingSet.add(curNode);
-    }
-    curNode = curNode->next_in_graph[direction];
-  }
-
-  // 2. Decide whether we can move it all to `movePoint`.
-
-  // Say we are moving directly before movePoint and `this` starts before
-  // movePoint in the graph. The move looks like
-  //
-  //  `this`              `this`           |
-  //  <dependencies>  ->  `movePoint`      | `this` and deps are split
-  //  `movePoint`         <dependencies>   |
-  //
-  // Contrast with the case where `this` starts AFTER movePoint:
-  //
-  //  `movePoint`         <dependencies>   |
-  //  <dependencies>  ->  `this`           | `this` and deps are together
-  //  `this`              `movePoint`      |
-  //
-  // In the first case, we need to split `this` off from its dependencies, so we
-  // can move the dependencies below `movePoint` and keep `this` above.
-  const bool splitThisAndDeps =
-      (moveSide == MoveSide::BEFORE && this->isBefore(movePoint)) ||
-      (moveSide == MoveSide::AFTER && this->isAfter(movePoint));
-
-  if (splitThisAndDeps) {
-    // remove `this` from dependencies to be moved past `movePoint`
-    workingSet.eraseMover();
-  }
-
-  // Check if we can move the working set past the move point
-  if (workingSet.dependsOn(movePoint)) {
-    // if we can't, then there are intermediate dependencies between the
-    // `this` and `movePoint`, so we can't do the move
-    return false;
-  }
-
-  if (dryRun) {
-    return true;
-  }
-
-  // 3. Execute the move
-  JIT_ASSERT(curNode == movePoint);
-  if (splitThisAndDeps) {
-    // Move `this`
-    this->move(movePoint, moveSide);
-
-    // Then move all of its dependencies on the other side of `movePoint`
-    const auto reversed =
-        moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
-    for (auto toMove : workingSet.nodes()) {
-      toMove->move(curNode, reversed);
-      curNode = toMove;
-    }
-  } else {
-    // Just append/prepend everything to `movePoint`
-    for (auto toMove : workingSet.nodes()) {
-      toMove->move(curNode, moveSide);
-      curNode = toMove;
-    }
-  }
-  return true;
-}
-
-// Helper function so we can generalize `tryMove`
-void Node::move(Node* movePoint, MoveSide moveSide) {
-  switch (moveSide) {
-    case MoveSide::BEFORE:
-      this->moveBefore(movePoint);
-      break;
-    case MoveSide::AFTER:
-      this->moveAfter(movePoint);
-      break;
-  }
-}
-
 void Node::moveAfter(Node* n) {
   removeFromList();
   insertAfter(n);
index 0d4f777..969aa44 100644 (file)
@@ -206,22 +206,6 @@ struct Node : public Attributes<Node> {
   friend const_graph_node_list_iterator;
 
  private:
-  // each node but Return/Param
-  // is associated with exactly one place in the node list...
-  // of the graph_
-  // this circular is a doubly-linked list. The Return node is used as the
-  // sentinel for the beginning and end of the list such that the list never has
-  // null pointers.
-  // - next_in_graph[0] is next pointer
-  // - next_in_graph[1] is prev pointer
-  //
-  // Using an array to allow the same iterator class for forward and
-  // reverse node lists
-  //
-  // This list represents a topological sort
-
-  Node* next_in_graph[2] = {nullptr, nullptr};
-
   const NodeKind kind_;
   std::vector<Value*> inputs_;
   std::vector<Value*> outputs_;
@@ -241,6 +225,16 @@ struct Node : public Attributes<Node> {
  protected:
   TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph
  public:
+  // each node but Return/Param
+  // is associated with exactly one place in the node list...
+  // of the graph_
+  // this circular is a doubly-linked list, the Return node is used as the
+  // sentinel for the beginning and end of the list such that the list never has
+  // null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev
+  // pointer using an array to allow the same iterator class for forward and
+  // reverse node lists This list represents a topological sort
+  Node* next_in_graph[2] = {nullptr, nullptr};
+
   Node*& next() {
     return next_in_graph[kNextDirection];
   }
@@ -481,7 +475,7 @@ struct Node : public Attributes<Node> {
   // Move 'this' (already in the graph) after 'n' in the topological order.
   //
   // NOTE: Does not check that value dependencies are preserved, see
-  //   moveAfterTopologicallyValid
+  //   AliasDb::moveAfterTopologicallyValid
   //
   // Given: %2 = f(%1)
   //        %3 = g(%1)
@@ -491,26 +485,11 @@ struct Node : public Attributes<Node> {
   //
   TORCH_API void moveAfter(Node* n);
 
-  // Move 'this' (already in the graph) after 'n' in the topological order.
-  //
-  // Tries to preserve value dependencies, so other nodes might be moved. We
-  // make two gurantees about the postcondition of the node list:
-  //   - `this` is directly after `n`.
-  //   - only nodes between `this` and `n` have been moved
-  //
-  // Returns `false` if it's impossible to move `this` after `n` without
-  // violating dependencies, otherwise executes the move and returns `true`
-  TORCH_API bool moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb);
-
-  // Like moveAfterTopologicallyValid, but only returns if the move is
-  // possible, without actually performing it.
-  TORCH_API bool couldMoveAfterTopologically(Node* n, const AliasDb& aliasdb);
-
   // Move a node 'n' (already in the graph) before 'this' in the topological
   // order.
   //
   // NOTE: Does not check that value dependencies are preserved, see
-  //   moveBeforeTopologicallyValid
+  //   AliasDb::moveBeforeTopologicallyValid
   //
   // Given: %2 = f(%1)
   //        %3 = g(%1)
@@ -519,21 +498,6 @@ struct Node : public Attributes<Node> {
   //         %2 = f(%1)
   TORCH_API void moveBefore(Node* n);
 
-  // Move 'this' (already in the graph) before 'n' in the topological order.
-  //
-  // Tries to preserve value dependencies, so other nodes might be moved. We
-  // make two gurantees about the postcondition of the node list:
-  //   - `this` is directly before `n`.
-  //   - only nodes between `this` and `n` have been moved
-  //
-  // Returns `false` if it's impossible to move `this` after `n` without
-  // violating dependencies, otherwise executes the move and returns `true`
-  TORCH_API bool moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb);
-
-  // Like moveBeforeTopologicallyValid, but only returns if the move is
-  // possible, without actually performing it.
-  TORCH_API bool couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb);
-
   // Remove the input at 'i' from this node.
   //
   // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
@@ -617,12 +581,6 @@ struct Node : public Attributes<Node> {
 
  private:
   enum class MoveSide { BEFORE, AFTER };
-  bool tryMove(
-      Node* movePoint,
-      MoveSide moveSide,
-      const AliasDb& aliasDb,
-      bool dryRun);
-  void move(Node* movePoint, MoveSide moveSide);
   bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
 
   std::pair<Value*, const Argument&> findInput(Symbol name);
index 384fc2b..56d7fce 100644 (file)
@@ -95,6 +95,26 @@ bool AliasDb::writesTo(Node* n, const Value* v) const {
   return writers.count(n) != 0;
 }
 
+bool AliasDb::hasWriters(const Node* n) const {
+  if (hasWildcard(n)) {
+    // If `n` has a wildcard, any write in the graph may write to it.
+    // So the only way we know there are no writers is if there are no writes
+    // at all.
+    return !aliasToWrites_.empty();
+  }
+  return getWriters(n).size() != 0;
+}
+
+bool AliasDb::hasWritersBefore(const Node* n) const {
+  if (hasWildcard(n)) {
+    return true;
+  }
+  const auto writers = getWriters(n);
+  return std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
+    return writer->isBefore(n);
+  });
+}
+
 bool AliasDb::hasWrites(Node* n) const {
   for (const auto input : n->inputs()) {
     if (writesTo(n, input)) {
@@ -603,5 +623,324 @@ void AliasDb::giveFreshAlias(const Value* value) {
   }
   addAlias(value, getFreshAlias());
 }
+
+bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
+  return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
+}
+
+bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) {
+  return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/true);
+}
+
+bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) {
+  // We have to distinguish the move side (instead of just moving after
+  // n->prev()). Consider the following example:
+  // If the dependency graph looks like
+  //   n -> movePoint -> o
+  // then moveBefore(o) will end up with
+  //   n, o, movePoint
+  // but moveAfter(n) will return false.
+  return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/false);
+}
+
+bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) {
+  return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/true);
+}
+
+// Helper for topologically-safe node moves. See `tryMove()` for details.
+class AliasDb::WorkingSet {
+ public:
+  explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
+    add(mover);
+  }
+
+  // Add `n` to the working set
+  void add(Node* n) {
+    nodes_.push_back(n);
+    for (const auto user : getUsersSameBlock(n)) {
+      users_[user]++;
+    }
+
+    for (const auto& writer : getWritersSameBlock(n)) {
+      writers_[writer]++;
+    }
+    if (aliasDb_.hasWildcard(n)) {
+      numWildcards_++;
+    }
+    if (aliasDb_.hasWrites(n)) {
+      numWriterNodes_++;
+    }
+  }
+
+  void eraseMover() {
+    auto mover = nodes_.front();
+    for (const auto user : getUsersSameBlock(mover)) {
+      // If this user node only uses the mover, we can remove it
+      if (users_[user] == 1) {
+        users_.erase(user);
+      }
+    }
+
+    for (const auto& writer : getWritersSameBlock(mover)) {
+      if (writers_[writer] == 1) {
+        writers_.erase(writer);
+      }
+    }
+    if (aliasDb_.hasWildcard(mover)) {
+      numWildcards_--;
+    }
+    if (aliasDb_.hasWrites(mover)) {
+      numWriterNodes_--;
+    }
+    nodes_.pop_front();
+  }
+
+  const std::list<Node*>& nodes() {
+    return nodes_;
+  }
+
+  // Does the working set depend on `n`?
+  bool dependsOn(Node* n) const {
+    if (nodes_.empty()) {
+      return false;
+    }
+
+    return hasDataDependency(n) || hasMutabilityDependency(n);
+  }
+
+ private:
+  bool hasDataDependency(Node* n) const {
+    if (n->isAfter(nodes_.front())) {
+      return producesFor(n);
+    } else {
+      return consumesFrom(n);
+    }
+  }
+
+  bool hasMutabilityDependency(Node* n) const {
+    // 1. Handle wildcard dependencies:
+    // If the working set has a wildcard, `n` can't write to anything.
+    if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) {
+      return true;
+    }
+
+    // If `n` has a wildcard, the working set can't write to anything.
+    if (aliasDb_.hasWildcard(n) && numWriterNodes_ > 0) {
+      return true;
+    }
+
+    // 2. Handle regular mutable dependencies
+    // Check that this node does not write to anything used by the working set
+    if (writers_.count(n) != 0) {
+      return true;
+    }
+
+    // Check that the working set does not write to anything used by this node
+    const auto writersToNode = getWritersSameBlock(n);
+    return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
+      return writersToNode.count(node) != 0;
+    });
+  }
+
+  // Does the working set produce any values consumed by `n`?
+  bool producesFor(Node* n) const {
+    // This equivalent to asking: does the total use-set of all the nodes in the
+    // working set include `n`?
+    return users_.count(n) != 0;
+  }
+
+  // Does the working set consume any values produced by `n`?
+  bool consumesFrom(Node* n) const {
+    const auto users = getUsersSameBlock(n);
+    return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
+      return users.count(node) != 0;
+    });
+  }
+
+  // Get all users of outputs of `n`, in the same block as `n`.
+  // This means if there is an `if` node that uses an output of `n` in some
+  // inner sub-block, we will consider the whole `if` node a user of `n`.
+  std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
+    std::unordered_set<Node*> users;
+    for (const auto output : n->outputs()) {
+      for (const auto& use : output->uses()) {
+        if (auto sameBlock = findSameBlock(use.user, n)) {
+          users.insert(sameBlock);
+        }
+      }
+    }
+    return users;
+  }
+
+  std::unordered_set<Node*> getWritersSameBlock(Node* n) const {
+    std::unordered_set<Node*> writers;
+    for (const auto writer : aliasDb_.getWriters(n)) {
+      if (auto sameBlock = findSameBlock(writer, n)) {
+        writers.insert(sameBlock);
+      }
+    }
+    return writers;
+  }
+
+  // Traverse `target`'s blockchain upward until we find a node that shares a
+  // block with `n`.
+  //
+  // If one can't be found (say, because `n` is an inner block and target is
+  // outside), then return nullptr. Since we can only reorder nodes within a
+  // block, `target` would be irrelevant.
+  static Node* findSameBlock(Node* target, Node* n) {
+    if (target->owningBlock() == n->owningBlock()) {
+      return target;
+    } else {
+      // This user is in a sub-block. Traverse the blockchain upward until
+      // we arrive at a node that shares a block with `this`
+      auto curNode = target;
+      while (curNode->owningBlock() != n->owningBlock()) {
+        curNode = curNode->owningBlock()->owningNode();
+        if (curNode == nullptr) {
+          return curNode;
+        }
+      }
+      return curNode;
+    }
+  }
+
+  const AliasDb& aliasDb_;
+  std::list<Node*> nodes_;
+  // users => # of working set nodes it uses
+  std::unordered_map<Node*, size_t> users_;
+  std::unordered_map<Node*, size_t> writers_;
+  size_t numWildcards_ = 0;
+  size_t numWriterNodes_ = 0;
+};
+
+// Try to move `toMove` before/after `movePoint` while preserving value
+// dependencies. Returns false iff such a move could not be made.
+//
+// If `dryRun` is set, don't actually execute the move, just check if the move
+// is possible
+//
+// The basic approach is: have a "working set" that we are moving forward, one
+// node at a time. When we can't move past a node (because it depends on the
+// working set), then add it to the working set and keep moving until we hit
+// `moveAfter`.
+bool AliasDb::tryMove(
+    Node* toMove,
+    Node* movePoint,
+    MoveSide moveSide,
+    bool dryRun) {
+  JIT_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
+  if (toMove == movePoint) {
+    return true;
+  }
+
+  // 1. Move from `this` toward movePoint, building up the working set of
+  // dependencies
+  WorkingSet workingSet(toMove, *this);
+
+  int direction;
+  if (toMove->isAfter(movePoint)) {
+    direction = kPrevDirection;
+  } else {
+    direction = kNextDirection;
+  }
+
+  auto curNode = toMove->next_in_graph[direction];
+  // Move forward one node at a time
+  while (curNode != movePoint) {
+    if (workingSet.dependsOn(curNode)) {
+      // If we can't move past this node, add it to the working set
+      workingSet.add(curNode);
+    }
+    curNode = curNode->next_in_graph[direction];
+  }
+
+  // 2. Decide whether we can move it all to `movePoint`.
+
+  // Say we are moving directly before movePoint and `toMove` starts before
+  // movePoint in the graph. The move looks like
+  //
+  //  `toMove`            `toMove`         |
+  //  <dependencies>  ->  `movePoint`      | `toMove` and deps are split
+  //  `movePoint`         <dependencies>   |
+  //
+  // Contrast with the case where `toMove` starts AFTER movePoint:
+  //
+  //  `movePoint`           <dependencies>   |
+  //  <dependencies>  ->    `toMove`         | `toMove` and deps are together
+  //  `toMove`              `movePoint`      |
+  //
+  // In the first case, we need to split `this` off from its dependencies, so we
+  // can move the dependencies below `movePoint` and keep `toMove` above.
+  const bool splitToMoveAndDeps =
+      (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
+      (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));
+
+  if (splitToMoveAndDeps) {
+    // remove `this` from dependencies to be moved past `movePoint`
+    workingSet.eraseMover();
+  }
+
+  // Check if we can move the working set past the move point
+  if (workingSet.dependsOn(movePoint)) {
+    // if we can't, then there are intermediate dependencies between the
+    // `this` and `movePoint`, so we can't do the move
+    return false;
+  }
+
+  if (dryRun) {
+    return true;
+  }
+
+  // 3. Execute the move
+  JIT_ASSERT(curNode == movePoint);
+  if (splitToMoveAndDeps) {
+    // Move `toMove`
+    move(toMove, movePoint, moveSide);
+
+    // Then move all of its dependencies on the other side of `movePoint`
+    const auto reversed =
+        moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
+    for (auto n : workingSet.nodes()) {
+      move(n, curNode, reversed);
+      curNode = n;
+    }
+  } else {
+    // Just append/prepend everything to `movePoint`
+    for (auto n : workingSet.nodes()) {
+      move(n, curNode, moveSide);
+      curNode = n;
+    }
+  }
+  return true;
+}
+
+// Helper function so we can generalize `tryMove`
+void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
+  switch (moveSide) {
+    case MoveSide::BEFORE:
+      toMove->moveBefore(movePoint);
+      break;
+    case MoveSide::AFTER:
+      toMove->moveAfter(movePoint);
+      break;
+  }
+}
+
+bool AliasDb::hasUntrackedEffects(Node* node) const {
+  bool touchesWildcard = false;
+  if (!wildcardNodes_.empty()) {
+    auto lastWildcard = *wildcardNodes_.begin();
+    for (const auto wildcard : wildcardNodes_) {
+      if (wildcard->isAfter(lastWildcard)) {
+        lastWildcard = wildcard;
+      }
+    }
+    touchesWildcard = hasWrites(node) &&
+        (node->isBefore(lastWildcard) || node == lastWildcard);
+  }
+
+  return writesToInputAlias(node) || touchesWildcard;
+}
 } // namespace jit
 } // namespace torch
index dca904e..a399a02 100644 (file)
@@ -10,9 +10,9 @@ namespace jit {
  * Alias analysis pass.
  *
  * This pass produces an AliasDb that contains aliasing and mutation
- * information about the graph. Callers (right now moveAfterTopologicallyValid)
- * can use this information to determine whether mutations to the graph are
- * safe, in that they don't reorder/change nodes in a way that affects output.
+ * information about the graph. Users can use this information to determine
+ * whether mutations to the graph are safe, i.e. they don't reorder/change
+ * nodes in a way that affects output.
  *
  * Every value with a mutable type (Tensors, Lists, Tuples, etc.) will be
  * associated with one or more "alias sets". If two values share an alias set,
@@ -28,18 +28,18 @@ class AliasDb {
  public:
   explicit AliasDb(std::shared_ptr<Graph> graph);
 
-  // Does `n` use or write to any wildcard aliases?
-  bool hasWildcard(const Node* n) const;
-
-  const std::unordered_set<const Node*>& getWildcardNodes() const {
-    return wildcardNodes_;
-  }
-
   // Does `n` write to any alias sets?
   bool hasWrites(Node* n) const;
 
-  // Does `n` write to a value that may alias one of the graph inputs?
-  bool writesToInputAlias(Node* n) const;
+  // There are limitations to what effects the alias analysis can track. Two
+  // kinds of nodes may have untracked effects:
+  // 1. Nodes that write to a value that may alias the graph inputs (since
+  //    the inputs can be used outside the graph).
+  // 2. Nodes that write to something in the wildcard set.
+  //
+  // These nodes are considered not safe to eliminate or mutate under any
+  // circumstances.
+  bool hasUntrackedEffects(Node* n) const;
 
   // Get all nodes that write to any alias set inputed/outputed by `n`
   std::unordered_set<Node*> getWriters(const Node* n) const;
@@ -50,15 +50,44 @@ class AliasDb {
   // Get all values that may alias to `v`.
   std::unordered_set<const Value*> getAliases(const Value* v) const;
 
-  // Do any nodes  write to an alias set inputed/outputed by `n`?
-  bool hasWriters(const Node* n) const {
-    return getWriters(n).size() != 0;
-  }
+  // Do any nodes write to an alias set inputed/outputed by `n`?
+  bool hasWriters(const Node* n) const;
+
+  // Same as hasWriters() but ignores writes after `n`.
+  bool hasWritersBefore(const Node* n) const;
+
+  // Move 'n' (already in the graph) after 'movePoint' in the topological order.
+  //
+  // Tries to preserve value dependencies, so other nodes might be moved. We
+  // make two gurantees about the postcondition of the node list:
+  //   - `n` is directly after `movePoint`.
+  //   - only nodes between `n` and `movePoint` have been moved.
+  //
+  // Returns `false` if it's impossible to move `n` after `MovePoint` without
+  // violating dependencies, otherwise executes the move and returns `true`
+  bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
+  bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
+
+  bool couldMoveAfterTopologically(Node* n, Node* movePoint);
+  bool couldMoveBeforeTopologically(Node* n, Node* movePoint);
 
   // For debugging: print alias db state to stdout
   void dump() const;
 
  private:
+  // Helper for topologically-safe node moves.
+  class WorkingSet;
+  enum class MoveSide { BEFORE, AFTER };
+  bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun);
+  void move(Node* toMove, Node* movePoint, MoveSide moveSide);
+  bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
+
+  // Does `n` use or write to any wildcard aliases?
+  bool hasWildcard(const Node* n) const;
+
+  // Does `n` write to a value that may alias one of the graph inputs?
+  bool writesToInputAlias(Node* n) const;
+
   void analyze(const std::shared_ptr<Graph>& graph);
   void analyze(Block* block);
   void analyze(Node* node);
@@ -73,6 +102,7 @@ class AliasDb {
 
   Symbol getFreshAlias(bool isGraphInput = false);
   void addAlias(const Value* value, AliasInfo alias);
+
   void addAlias(const Value* value, Symbol alias);
   void addAlias(const Value* value, const Value* from);
   void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
index 2c99d89..3f1693b 100644 (file)
@@ -329,7 +329,7 @@ RegisterOperators mm_batch_side_reg(
 
 std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
     Value* value,
-    const AliasDb& alias_db) {
+    AliasDb& alias_db) {
   const auto postprocess = [&](std::vector<Node*> mms) {
     if (mms.size() == 0) {
       return mms;
@@ -346,7 +346,7 @@ std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
       for (size_t j = i + 1; j < mms.size(); ++j) {
         if (mms[j] == nullptr)
           continue;
-        if (!mms[j]->couldMoveBeforeTopologically(mms[i], alias_db)) {
+        if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
           mms[j] = nullptr;
         }
       }
@@ -370,13 +370,13 @@ std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
   return std::make_pair(postprocess(lhses), postprocess(rhses));
 }
 
-void BatchMMSide(Block* block, const AliasDb& alias_db) {
+void BatchMMSide(Block* block, AliasDb& alias_db) {
   // NB: 8 is the current loop unrolling factor
   static constexpr size_t how_many_is_many = 8;
   const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
     JIT_ASSERT(!mms.empty());
     for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
-      bool move_ok = mms[i]->moveBeforeTopologicallyValid(mms[i + 1], alias_db);
+      bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
       JIT_ASSERT(move_ok);
     }
     WithInsertPoint insert_guard{mms[0]};
@@ -435,7 +435,7 @@ void BatchMM(std::shared_ptr<Graph>& graph) {
     // TODO(suo): make BatchMM mutability-safe
     return;
   }
-  const auto alias_db = AliasAnalysis(graph);
+  auto alias_db = AliasAnalysis(graph);
   BatchMMTreeReduce(graph->block());
   BatchMMSide(graph->block(), alias_db);
   EliminateDeadCode(graph);
index 7a87c1f..f2be647 100644 (file)
@@ -24,7 +24,7 @@ void EliminateCommonSubexpression(
   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
     auto node = *it;
     if (node->hasSideEffects() || node->isNondeterministic() ||
-        aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
+        aliasDb.hasWriters(node)) {
       // Do NOT have enough information to do CSE on these nodes.
       continue;
     }
index 1fa33f6..b8d6f2b 100644 (file)
@@ -123,7 +123,7 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
       });
   bool supported_node = !n->kind().is_onnx() &&
       skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
-      !n->hasSideEffects() && !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
+      !n->hasSideEffects() && !aliasDb.hasWriters(n);
   auto run_blocks = [&]() {
     if (recurse) {
       for (Block* block : n->blocks()) {
index 198323a..292a491 100644 (file)
@@ -123,7 +123,7 @@ class SubgraphSlicer {
 
   std::pair<graph_node_list::iterator, bool> scanNode(
       Node* consumer,
-      const AliasDb& aliasDb) {
+      AliasDb& aliasDb) {
     if (shouldConsiderForMerge(consumer)) {
       if (consumer->kind() != prim::DifferentiableGraph) {
         consumer = SubgraphUtils::createSingletonSubgraph(
@@ -147,10 +147,10 @@ class SubgraphSlicer {
   c10::optional<Node*> tryMerge(
       Node* consumer,
       Node* producer,
-      const AliasDb& aliasDb) {
+      AliasDb& aliasDb) {
     JIT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
     bool canMerge = shouldConsiderForMerge(producer) &&
-        producer->moveBeforeTopologicallyValid(consumer, aliasDb);
+        aliasDb.moveBeforeTopologicallyValid(producer, consumer);
 
     if (!canMerge) {
       return c10::nullopt;
index 3389dce..d43694b 100644 (file)
@@ -18,11 +18,6 @@ class DeadCodeEliminator {
   // we mark "live" nodes that are necessary for the output. Nodes that have
   // side effects are also marked.
   void run(Block* block, bool recurse) {
-    // Find the last wildcard in the block. We cannot eliminate any mutable ops
-    // that precede the last wildcard (since they may have written to the
-    // wildcard alias set)
-    setLastWildcard();
-
     // Initialize by marking the return node and all its consumed values as live
     mark(block->return_node());
 
@@ -40,24 +35,6 @@ class DeadCodeEliminator {
   }
 
  private:
-  void setLastWildcard() {
-    if (!aliasDb_) {
-      return;
-    }
-
-    const auto& wildcards = aliasDb_->getWildcardNodes();
-    if (wildcards.empty()) {
-      return;
-    }
-
-    lastWildcard_ = *wildcards.begin();
-    for (const auto wildcard : wildcards) {
-      if (wildcard->isAfter(*lastWildcard_)) {
-        lastWildcard_ = wildcard;
-      }
-    }
-  }
-
   // Special handling for block return nodes. Unlike other nodes, the block
   // return node doesn't really "use" its inputs. Consider:
   //
@@ -227,16 +204,7 @@ class DeadCodeEliminator {
       auto schema = node->maybeSchema();
       return schema && schema->is_mutable();
     } else {
-      // Otherwise, there are two kinds of nodes with untracked effects:
-      // 1. Nodes that write to a value that may alias the graph inputs (since
-      //    the inputs can be used outside the graph).
-      // 2. Anything that could clobber a wildcard value.
-      bool touchesWildcard = false;
-      if (lastWildcard_) {
-        touchesWildcard = aliasDb_->hasWrites(node) &&
-            (node->isBefore(*lastWildcard_) || node == *lastWildcard_);
-      }
-      return aliasDb_->writesToInputAlias(node) || touchesWildcard;
+      return aliasDb_->hasUntrackedEffects(node);
     }
   }
 
@@ -300,7 +268,6 @@ class DeadCodeEliminator {
   std::unordered_set<Node*> marked_;
   std::unordered_set<const Value*> liveValues_;
   std::unordered_set<const Value*> liveAliases_;
-  c10::optional<const Node*> lastWildcard_;
   std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
       [](const std::unordered_set<const Value*>&) {};
 };
index fd7e3e1..d59ce74 100644 (file)
@@ -439,8 +439,7 @@ struct GraphFuser {
         // consumer. Fusion will rewrite those later uses to use the version of
         // producer generated by the fused blob. In this case, producer becomes
         // an output of the fusion group.
-        producer->node()->moveBeforeTopologicallyValid(
-            consumer, aliasDb_.value());
+        aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);
 
     if (!shouldFuse) {
       return at::nullopt;
@@ -1012,8 +1011,8 @@ struct GraphFuser {
     }
     // NB: it is important that this check happens after isFusable, which checks
     // that the blocks match, and it's not a special node like prim::Param
-    if (!producer->node()->couldMoveBeforeTopologically(
-            before_check, aliasDb_.value())) {
+    if (!aliasDb_->couldMoveBeforeTopologically(
+            producer->node(), before_check)) {
       return false;
     }
     // Fusion groups can be merged with concat's group if and only if
index 69ed1ef..4a0fa77 100644 (file)
@@ -204,12 +204,7 @@ class ShapePropagator {
       return dependsOnMutationMemo_[node];
     }
 
-    const auto writers = aliasDb_.getWriters(node);
-    const auto hasWritersBefore =
-        std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
-          return writer->isBefore(node);
-        });
-    if (hasWritersBefore || aliasDb_.hasWildcard(node)) {
+    if (aliasDb_.hasWritersBefore(node)) {
       // If something could have written to a value used by this node, we can't
       // guarantee the result is the same when running it in isolation.
       dependsOnMutationMemo_[node] = true;