From 83c054de481d4f65a8a73a903edd6beaac18e8bc Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Fri, 11 Jan 2019 20:04:14 -0800 Subject: [PATCH] AliasDB interface cleanup (#15656) Summary: This is the first of several PRs to simplify AliasDb usage. - Hide the concept of wildcards from users. They are too hard to think about and too easy to forget about. - Start moving "mutability-safe" graph mutation methods into AliasDb (right now, the various methods that deal with topological move). Eventually I want to create a "mutability-aware" handle to the graph. If you only use that handle to transform the graph, you can be sure that all transformations are safe with respect to mutability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15656 Differential Revision: D13615492 Pulled By: suo fbshipit-source-id: 5c39a157b4ea76f1f976315d06a314a89cc4f22f --- test/cpp/jit/tests.h | 25 +- torch/csrc/jit/ir.cpp | 300 ------------------ torch/csrc/jit/ir.h | 66 +--- torch/csrc/jit/passes/alias_analysis.cpp | 339 +++++++++++++++++++++ torch/csrc/jit/passes/alias_analysis.h | 62 +++- torch/csrc/jit/passes/batch_mm.cpp | 10 +- .../passes/common_subexpression_elimination.cpp | 2 +- torch/csrc/jit/passes/constant_propagation.cpp | 2 +- .../csrc/jit/passes/create_autodiff_subgraphs.cpp | 6 +- torch/csrc/jit/passes/dead_code_elimination.cpp | 35 +-- torch/csrc/jit/passes/graph_fuser.cpp | 7 +- torch/csrc/jit/passes/shape_analysis.cpp | 7 +- 12 files changed, 424 insertions(+), 437 deletions(-) diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 4ae3eb4..c751f55 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -1840,7 +1840,7 @@ struct TopoMoveTestFixture { const std::string& insertPoint) { std::function func = [this](Node* toInsert, Node* insertPoint) { - return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb); + return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint); }; return moveWithChecks(toInsert, insertPoint, func); } @@
-1850,7 +1850,7 @@ struct TopoMoveTestFixture { const std::string& insertPoint) { std::function func = [this](Node* toInsert, Node* insertPoint) { - return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb); + return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint); }; return moveWithChecks(toInsert, insertPoint, func); } @@ -2035,15 +2035,15 @@ void testAliasAnalysis() { graph->lint(); - const auto aliasDb = AliasAnalysis(graph); + auto aliasDb = AliasAnalysis(graph); // Can't move past a mutation of a used value - JIT_ASSERT(!c->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb)); - JIT_ASSERT(d->node()->moveAfterTopologicallyValid(c->node(), aliasDb)); + JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node())); + JIT_ASSERT(aliasDb.moveAfterTopologicallyValid(d->node(), c->node())); // b should alias to a (since they are both inputs) JIT_ASSERT( - !addsB->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb)); - JIT_ASSERT(addsB->node()->moveAfterTopologicallyValid(c->node(), aliasDb)); + !aliasDb.moveAfterTopologicallyValid(addsB->node(), aMut->node())); + JIT_ASSERT(aliasDb.moveAfterTopologicallyValid(addsB->node(), c->node())); graph->lint(); } @@ -2060,12 +2060,11 @@ void testAliasAnalysis() { auto c = graph->insert(aten::add, {fresh, aliasesB}); graph->lint(); - const auto aliasDb = AliasAnalysis(graph); - - JIT_ASSERT(!aliasesB->node()->moveAfterTopologicallyValid( - mutatesAliasOfB->node(), aliasDb)); - JIT_ASSERT(!usesB->node()->moveAfterTopologicallyValid( - mutatesAliasOfB->node(), aliasDb)); + auto aliasDb = AliasAnalysis(graph); + JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid( + aliasesB->node(), mutatesAliasOfB->node())); + JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid( + usesB->node(), mutatesAliasOfB->node())); } } } // namespace diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index 17d640d..28ee32d 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -969,306 +969,6 @@ 
Node* Node::insertAfter(Node* n) { return this; } -bool Node::moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb) { - return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/false); -} - -bool Node::couldMoveAfterTopologically(Node* n, const AliasDb& aliasDb) { - return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/true); -} - -bool Node::moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb) { - // We have to distinguish the move side (instead of just moving after - // n->prev()). Consider the following example: - // If the dependency graph looks like this -> n -> o then moveBefore(o) will - // end up with [this, o, n], but moveAfter(n) will return false. - return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/false); -} - -bool Node::couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb) { - return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/true); -} - -// Helper for topologically-safe node moves. See `tryMove()` for details. -namespace { -struct WorkingSet { - public: - explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) { - add(mover); - } - - // Add `n` to the working set - void add(Node* n) { - nodes_.push_back(n); - for (const auto user : getUsersSameBlock(n)) { - users_[user]++; - } - - for (const auto& writer : getWritersSameBlock(n)) { - writers_[writer]++; - } - if (aliasDb_.hasWildcard(n)) { - numWildcards_++; - } - if (aliasDb_.hasWrites(n)) { - numWriterNodes_++; - } - } - - void eraseMover() { - auto mover = nodes_.front(); - for (const auto user : getUsersSameBlock(mover)) { - // If this user node only uses the mover, we can remove it - if (users_[user] == 1) { - users_.erase(user); - } - } - - for (const auto& writer : getWritersSameBlock(mover)) { - if (writers_[writer] == 1) { - writers_.erase(writer); - } - } - if (aliasDb_.hasWildcard(mover)) { - numWildcards_--; - } - if (aliasDb_.hasWrites(mover)) { - numWriterNodes_--; - } - nodes_.pop_front(); - } - - const std::list& nodes() { - 
return nodes_; - } - - // Does the working set depend on `n`? - bool dependsOn(Node* n) const { - if (nodes_.empty()) { - return false; - } - - return hasDataDependency(n) || hasMutabilityDependency(n); - } - - private: - bool hasDataDependency(Node* n) const { - if (n->isAfter(nodes_.front())) { - return producesFor(n); - } else { - return consumesFrom(n); - } - } - - bool hasMutabilityDependency(Node* n) const { - // 1. Handle wildcard dependencies: - // If the working set has a wildcard, `n` can't write to anything. - if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) { - return true; - } - - // If `n` has a wildcard, the working set can't write to anything. - if (aliasDb_.hasWildcard(n) && numWriterNodes_ > 0) { - return true; - } - - // 2. Handle regular mutable dependencies - // Check that this node does not write to anything used by the working set - if (writers_.count(n) != 0) { - return true; - } - - // Check that the working set does not write to anything used by this node - const auto writersToNode = getWritersSameBlock(n); - return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) { - return writersToNode.count(node) != 0; - }); - } - - // Does the working set produce any values consumed by `n`? - bool producesFor(Node* n) const { - // This equivalent to asking: does the total use-set of all the nodes in the - // working set include `n`? - return users_.count(n) != 0; - } - - // Does the working set consume any values produced by `n`? - bool consumesFrom(Node* n) const { - const auto users = getUsersSameBlock(n); - return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) { - return users.count(node) != 0; - }); - } - - // Get all users of outputs of `n`, in the same block as `n`. - // This means if there is an `if` node that uses an output of `n` in some - // inner sub-block, we will consider the whole `if` node a user of `n`. 
- std::unordered_set getUsersSameBlock(Node* n) const { - std::unordered_set users; - for (const auto output : n->outputs()) { - for (const auto& use : output->uses()) { - if (auto sameBlock = findSameBlock(use.user, n)) { - users.insert(sameBlock); - } - } - } - return users; - } - - std::unordered_set getWritersSameBlock(Node* n) const { - std::unordered_set writers; - for (const auto writer : aliasDb_.getWriters(n)) { - if (auto sameBlock = findSameBlock(writer, n)) { - writers.insert(sameBlock); - } - } - return writers; - } - - // Traverse `target`'s blockchain upward until we find a node that shares a - // block with `n`. - // - // If one can't be found (say, because `n` is an inner block and target is - // outside), then return nullptr. Since we can only reorder nodes within a - // block, `target` would be irrelevant. - static Node* findSameBlock(Node* target, Node* n) { - if (target->owningBlock() == n->owningBlock()) { - return target; - } else { - // This user is in a sub-block. Traverse the blockchain upward until - // we arrive at a node that shares a block with `this` - auto curNode = target; - while (curNode->owningBlock() != n->owningBlock()) { - curNode = curNode->owningBlock()->owningNode(); - if (curNode == nullptr) { - return curNode; - } - } - return curNode; - } - } - - const AliasDb& aliasDb_; - std::list nodes_; - // users => # of working set nodes it uses - std::unordered_map users_; - std::unordered_map writers_; - size_t numWildcards_ = 0; - size_t numWriterNodes_ = 0; -}; -} // namespace - -// Try to move `this` before/after `movePoint` while preserving value -// dependencies. Returns false iff such a move could not be made -// -// The basic approach is: have a "working set" that we are moving forward, one -// node at a time. When we can't move past a node (because it depends on the -// working set), then add it to the working set and keep moving until we hit -// `moveAfter`. 
-bool Node::tryMove( - Node* movePoint, - MoveSide moveSide, - const AliasDb& aliasDb, - bool dryRun) { - JIT_ASSERT(this->inBlockList() && movePoint->inBlockList()); - JIT_ASSERT(this->owningBlock() == movePoint->owningBlock()); - if (this == movePoint) { - return true; - } - - // 1. Move from `this` toward movePoint, building up the working set of - // dependencies - WorkingSet workingSet(this, aliasDb); - - int direction; - if (this->isAfter(movePoint)) { - direction = kPrevDirection; - } else { - direction = kNextDirection; - } - - auto curNode = this->next_in_graph[direction]; - // Move forward one node at a time - while (curNode != movePoint) { - if (workingSet.dependsOn(curNode)) { - // If we can't move past this node, add it to the working set - workingSet.add(curNode); - } - curNode = curNode->next_in_graph[direction]; - } - - // 2. Decide whether we can move it all to `movePoint`. - - // Say we are moving directly before movePoint and `this` starts before - // movePoint in the graph. The move looks like - // - // `this` `this` | - // -> `movePoint` | `this` and deps are split - // `movePoint` | - // - // Contrast with the case where `this` starts AFTER movePoint: - // - // `movePoint` | - // -> `this` | `this` and deps are together - // `this` `movePoint` | - // - // In the first case, we need to split `this` off from its dependencies, so we - // can move the dependencies below `movePoint` and keep `this` above. 
- const bool splitThisAndDeps = - (moveSide == MoveSide::BEFORE && this->isBefore(movePoint)) || - (moveSide == MoveSide::AFTER && this->isAfter(movePoint)); - - if (splitThisAndDeps) { - // remove `this` from dependencies to be moved past `movePoint` - workingSet.eraseMover(); - } - - // Check if we can move the working set past the move point - if (workingSet.dependsOn(movePoint)) { - // if we can't, then there are intermediate dependencies between the - // `this` and `movePoint`, so we can't do the move - return false; - } - - if (dryRun) { - return true; - } - - // 3. Execute the move - JIT_ASSERT(curNode == movePoint); - if (splitThisAndDeps) { - // Move `this` - this->move(movePoint, moveSide); - - // Then move all of its dependencies on the other side of `movePoint` - const auto reversed = - moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE; - for (auto toMove : workingSet.nodes()) { - toMove->move(curNode, reversed); - curNode = toMove; - } - } else { - // Just append/prepend everything to `movePoint` - for (auto toMove : workingSet.nodes()) { - toMove->move(curNode, moveSide); - curNode = toMove; - } - } - return true; -} - -// Helper function so we can generalize `tryMove` -void Node::move(Node* movePoint, MoveSide moveSide) { - switch (moveSide) { - case MoveSide::BEFORE: - this->moveBefore(movePoint); - break; - case MoveSide::AFTER: - this->moveAfter(movePoint); - break; - } -} - void Node::moveAfter(Node* n) { removeFromList(); insertAfter(n); diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 0d4f777..969aa44 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -206,22 +206,6 @@ struct Node : public Attributes { friend const_graph_node_list_iterator; private: - // each node but Return/Param - // is associated with exactly one place in the node list... - // of the graph_ - // this circular is a doubly-linked list. 
The Return node is used as the - // sentinel for the beginning and end of the list such that the list never has - // null pointers. - // - next_in_graph[0] is next pointer - // - next_in_graph[1] is prev pointer - // - // Using an array to allow the same iterator class for forward and - // reverse node lists - // - // This list represents a topological sort - - Node* next_in_graph[2] = {nullptr, nullptr}; - const NodeKind kind_; std::vector inputs_; std::vector outputs_; @@ -241,6 +225,16 @@ struct Node : public Attributes { protected: TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph public: + // each node but Return/Param + // is associated with exactly one place in the node list... + // of the graph_ + // this circular is a doubly-linked list, the Return node is used as the + // sentinel for the beginning and end of the list such that the list never has + // null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev + // pointer using an array to allow the same iterator class for forward and + // reverse node lists This list represents a topological sort + Node* next_in_graph[2] = {nullptr, nullptr}; + Node*& next() { return next_in_graph[kNextDirection]; } @@ -481,7 +475,7 @@ struct Node : public Attributes { // Move 'this' (already in the graph) after 'n' in the topological order. // // NOTE: Does not check that value dependencies are preserved, see - // moveAfterTopologicallyValid + // AliasDb::moveAfterTopologicallyValid // // Given: %2 = f(%1) // %3 = g(%1) @@ -491,26 +485,11 @@ struct Node : public Attributes { // TORCH_API void moveAfter(Node* n); - // Move 'this' (already in the graph) after 'n' in the topological order. - // - // Tries to preserve value dependencies, so other nodes might be moved. We - // make two gurantees about the postcondition of the node list: - // - `this` is directly after `n`. 
- // - only nodes between `this` and `n` have been moved - // - // Returns `false` if it's impossible to move `this` after `n` without - // violating dependencies, otherwise executes the move and returns `true` - TORCH_API bool moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb); - - // Like moveAfterTopologicallyValid, but only returns if the move is - // possible, without actually performing it. - TORCH_API bool couldMoveAfterTopologically(Node* n, const AliasDb& aliasdb); - // Move a node 'n' (already in the graph) before 'this' in the topological // order. // // NOTE: Does not check that value dependencies are preserved, see - // moveBeforeTopologicallyValid + // AliasDb::moveBeforeTopologicallyValid // // Given: %2 = f(%1) // %3 = g(%1) @@ -519,21 +498,6 @@ struct Node : public Attributes { // %2 = f(%1) TORCH_API void moveBefore(Node* n); - // Move 'this' (already in the graph) before 'n' in the topological order. - // - // Tries to preserve value dependencies, so other nodes might be moved. We - // make two gurantees about the postcondition of the node list: - // - `this` is directly before `n`. - // - only nodes between `this` and `n` have been moved - // - // Returns `false` if it's impossible to move `this` after `n` without - // violating dependencies, otherwise executes the move and returns `true` - TORCH_API bool moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb); - - // Like moveBeforeTopologicallyValid, but only returns if the move is - // possible, without actually performing it. - TORCH_API bool couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb); - // Remove the input at 'i' from this node. 
// // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling @@ -617,12 +581,6 @@ struct Node : public Attributes { private: enum class MoveSide { BEFORE, AFTER }; - bool tryMove( - Node* movePoint, - MoveSide moveSide, - const AliasDb& aliasDb, - bool dryRun); - void move(Node* movePoint, MoveSide moveSide); bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const; std::pair findInput(Symbol name); diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 384fc2b..56d7fce 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -95,6 +95,26 @@ bool AliasDb::writesTo(Node* n, const Value* v) const { return writers.count(n) != 0; } +bool AliasDb::hasWriters(const Node* n) const { + if (hasWildcard(n)) { + // If `n` has a wildcard, any write in the graph may write to it. + // So the only way we know there are no writers is if there are no writes + // at all. + return !aliasToWrites_.empty(); + } + return getWriters(n).size() != 0; +} + +bool AliasDb::hasWritersBefore(const Node* n) const { + if (hasWildcard(n)) { + return true; + } + const auto writers = getWriters(n); + return std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) { + return writer->isBefore(n); + }); +} + bool AliasDb::hasWrites(Node* n) const { for (const auto input : n->inputs()) { if (writesTo(n, input)) { @@ -603,5 +623,324 @@ void AliasDb::giveFreshAlias(const Value* value) { } addAlias(value, getFreshAlias()); } + +bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) { + return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false); +} + +bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) { + return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/true); +} + +bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) { + // We have to distinguish the move side (instead of just moving after + // n->prev()). 
Consider the following example: + // If the dependency graph looks like + // n -> movePoint -> o + // then moveBefore(o) will end up with + // n, o, movePoint + // but moveAfter(n) will return false. + return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/false); +} + +bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) { + return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/true); +} + +// Helper for topologically-safe node moves. See `tryMove()` for details. +class AliasDb::WorkingSet { + public: + explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) { + add(mover); + } + + // Add `n` to the working set + void add(Node* n) { + nodes_.push_back(n); + for (const auto user : getUsersSameBlock(n)) { + users_[user]++; + } + + for (const auto& writer : getWritersSameBlock(n)) { + writers_[writer]++; + } + if (aliasDb_.hasWildcard(n)) { + numWildcards_++; + } + if (aliasDb_.hasWrites(n)) { + numWriterNodes_++; + } + } + + void eraseMover() { + auto mover = nodes_.front(); + for (const auto user : getUsersSameBlock(mover)) { + // If this user node only uses the mover, we can remove it + if (users_[user] == 1) { + users_.erase(user); + } + } + + for (const auto& writer : getWritersSameBlock(mover)) { + if (writers_[writer] == 1) { + writers_.erase(writer); + } + } + if (aliasDb_.hasWildcard(mover)) { + numWildcards_--; + } + if (aliasDb_.hasWrites(mover)) { + numWriterNodes_--; + } + nodes_.pop_front(); + } + + const std::list& nodes() { + return nodes_; + } + + // Does the working set depend on `n`? + bool dependsOn(Node* n) const { + if (nodes_.empty()) { + return false; + } + + return hasDataDependency(n) || hasMutabilityDependency(n); + } + + private: + bool hasDataDependency(Node* n) const { + if (n->isAfter(nodes_.front())) { + return producesFor(n); + } else { + return consumesFrom(n); + } + } + + bool hasMutabilityDependency(Node* n) const { + // 1. 
Handle wildcard dependencies: + // If the working set has a wildcard, `n` can't write to anything. + if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) { + return true; + } + + // If `n` has a wildcard, the working set can't write to anything. + if (aliasDb_.hasWildcard(n) && numWriterNodes_ > 0) { + return true; + } + + // 2. Handle regular mutable dependencies + // Check that this node does not write to anything used by the working set + if (writers_.count(n) != 0) { + return true; + } + + // Check that the working set does not write to anything used by this node + const auto writersToNode = getWritersSameBlock(n); + return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) { + return writersToNode.count(node) != 0; + }); + } + + // Does the working set produce any values consumed by `n`? + bool producesFor(Node* n) const { + // This is equivalent to asking: does the total use-set of all the nodes in + // the working set include `n`? + return users_.count(n) != 0; + } + + // Does the working set consume any values produced by `n`? + bool consumesFrom(Node* n) const { + const auto users = getUsersSameBlock(n); + return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) { + return users.count(node) != 0; + }); + } + + // Get all users of outputs of `n`, in the same block as `n`. + // This means if there is an `if` node that uses an output of `n` in some + // inner sub-block, we will consider the whole `if` node a user of `n`.
+ std::unordered_set getUsersSameBlock(Node* n) const { + std::unordered_set users; + for (const auto output : n->outputs()) { + for (const auto& use : output->uses()) { + if (auto sameBlock = findSameBlock(use.user, n)) { + users.insert(sameBlock); + } + } + } + return users; + } + + std::unordered_set getWritersSameBlock(Node* n) const { + std::unordered_set writers; + for (const auto writer : aliasDb_.getWriters(n)) { + if (auto sameBlock = findSameBlock(writer, n)) { + writers.insert(sameBlock); + } + } + return writers; + } + + // Traverse `target`'s blockchain upward until we find a node that shares a + // block with `n`. + // + // If one can't be found (say, because `n` is an inner block and target is + // outside), then return nullptr. Since we can only reorder nodes within a + // block, `target` would be irrelevant. + static Node* findSameBlock(Node* target, Node* n) { + if (target->owningBlock() == n->owningBlock()) { + return target; + } else { + // This user is in a sub-block. Traverse the blockchain upward until + // we arrive at a node that shares a block with `this` + auto curNode = target; + while (curNode->owningBlock() != n->owningBlock()) { + curNode = curNode->owningBlock()->owningNode(); + if (curNode == nullptr) { + return curNode; + } + } + return curNode; + } + } + + const AliasDb& aliasDb_; + std::list nodes_; + // users => # of working set nodes it uses + std::unordered_map users_; + std::unordered_map writers_; + size_t numWildcards_ = 0; + size_t numWriterNodes_ = 0; +}; + +// Try to move `toMove` before/after `movePoint` while preserving value +// dependencies. Returns false iff such a move could not be made. +// +// If `dryRun` is set, don't actually execute the move, just check if the move +// is possible +// +// The basic approach is: have a "working set" that we are moving forward, one +// node at a time. 
When we can't move past a node (because it depends on the +// working set), then add it to the working set and keep moving until we hit +// `movePoint`. +bool AliasDb::tryMove( + Node* toMove, + Node* movePoint, + MoveSide moveSide, + bool dryRun) { + JIT_ASSERT(toMove->owningBlock() == movePoint->owningBlock()); + if (toMove == movePoint) { + return true; + } + + // 1. Move from `toMove` toward movePoint, building up the working set of + // dependencies + WorkingSet workingSet(toMove, *this); + + int direction; + if (toMove->isAfter(movePoint)) { + direction = kPrevDirection; + } else { + direction = kNextDirection; + } + + auto curNode = toMove->next_in_graph[direction]; + // Move forward one node at a time + while (curNode != movePoint) { + if (workingSet.dependsOn(curNode)) { + // If we can't move past this node, add it to the working set + workingSet.add(curNode); + } + curNode = curNode->next_in_graph[direction]; + } + + // 2. Decide whether we can move it all to `movePoint`. + + // Say we are moving directly before movePoint and `toMove` starts before + // movePoint in the graph. The move looks like + // + // `toMove` `toMove` | + // -> `movePoint` | `toMove` and deps are split + // `movePoint` | + // + // Contrast with the case where `toMove` starts AFTER movePoint: + // + // `movePoint` | + // -> `toMove` | `toMove` and deps are together + // `toMove` `movePoint` | + // + // In the first case, we need to split `toMove` off from its dependencies, so + // we can move the dependencies below `movePoint` and keep `toMove` above.
+ const bool splitToMoveAndDeps = + (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) || + (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint)); + + if (splitToMoveAndDeps) { + // remove `toMove` from dependencies to be moved past `movePoint` + workingSet.eraseMover(); + } + + // Check if we can move the working set past the move point + if (workingSet.dependsOn(movePoint)) { + // if we can't, then there are intermediate dependencies between + // `toMove` and `movePoint`, so we can't do the move + return false; + } + + if (dryRun) { + return true; + } + + // 3. Execute the move + JIT_ASSERT(curNode == movePoint); + if (splitToMoveAndDeps) { + // Move `toMove` + move(toMove, movePoint, moveSide); + + // Then move all of its dependencies on the other side of `movePoint` + const auto reversed = + moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE; + for (auto n : workingSet.nodes()) { + move(n, curNode, reversed); + curNode = n; + } + } else { + // Just append/prepend everything to `movePoint` + for (auto n : workingSet.nodes()) { + move(n, curNode, moveSide); + curNode = n; + } + } + return true; +} + +// Helper function so we can generalize `tryMove` +void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) { + switch (moveSide) { + case MoveSide::BEFORE: + toMove->moveBefore(movePoint); + break; + case MoveSide::AFTER: + toMove->moveAfter(movePoint); + break; + } +} + +bool AliasDb::hasUntrackedEffects(Node* node) const { + bool touchesWildcard = false; + if (!wildcardNodes_.empty()) { + auto lastWildcard = *wildcardNodes_.begin(); + for (const auto wildcard : wildcardNodes_) { + if (wildcard->isAfter(lastWildcard)) { + lastWildcard = wildcard; + } + } + touchesWildcard = hasWrites(node) && + (node->isBefore(lastWildcard) || node == lastWildcard); + } + + return writesToInputAlias(node) || touchesWildcard; +} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/alias_analysis.h
b/torch/csrc/jit/passes/alias_analysis.h index dca904e..a399a02 100644 --- a/torch/csrc/jit/passes/alias_analysis.h +++ b/torch/csrc/jit/passes/alias_analysis.h @@ -10,9 +10,9 @@ namespace jit { * Alias analysis pass. * * This pass produces an AliasDb that contains aliasing and mutation - * information about the graph. Callers (right now moveAfterTopologicallyValid) - * can use this information to determine whether mutations to the graph are - * safe, in that they don't reorder/change nodes in a way that affects output. + * information about the graph. Users can use this information to determine + * whether mutations to the graph are safe, i.e. they don't reorder/change + * nodes in a way that affects output. * * Every value with a mutable type (Tensors, Lists, Tuples, etc.) will be * associated with one or more "alias sets". If two values share an alias set, @@ -28,18 +28,18 @@ class AliasDb { public: explicit AliasDb(std::shared_ptr graph); - // Does `n` use or write to any wildcard aliases? - bool hasWildcard(const Node* n) const; - - const std::unordered_set& getWildcardNodes() const { - return wildcardNodes_; - } - // Does `n` write to any alias sets? bool hasWrites(Node* n) const; - // Does `n` write to a value that may alias one of the graph inputs? - bool writesToInputAlias(Node* n) const; + // There are limitations to what effects the alias analysis can track. Two + // kinds of nodes may have untracked effects: + // 1. Nodes that write to a value that may alias the graph inputs (since + // the inputs can be used outside the graph). + // 2. Nodes that write to something in the wildcard set. + // + // These nodes are considered not safe to eliminate or mutate under any + // circumstances. + bool hasUntrackedEffects(Node* n) const; // Get all nodes that write to any alias set inputed/outputed by `n` std::unordered_set getWriters(const Node* n) const; @@ -50,15 +50,44 @@ class AliasDb { // Get all values that may alias to `v`. 
std::unordered_set getAliases(const Value* v) const; - // Do any nodes write to an alias set inputed/outputed by `n`? - bool hasWriters(const Node* n) const { - return getWriters(n).size() != 0; - } + // Do any nodes write to an alias set inputed/outputed by `n`? + bool hasWriters(const Node* n) const; + + // Same as hasWriters() but ignores writes after `n`. + bool hasWritersBefore(const Node* n) const; + + // Move 'n' (already in the graph) after 'movePoint' in the topological order. + // + // Tries to preserve value dependencies, so other nodes might be moved. We + // make two guarantees about the postcondition of the node list: + // - `n` is directly after `movePoint`. + // - only nodes between `n` and `movePoint` have been moved. + // + // Returns `false` if it's impossible to move `n` after `movePoint` without + // violating dependencies, otherwise executes the move and returns `true` + bool moveAfterTopologicallyValid(Node* n, Node* movePoint); + bool moveBeforeTopologicallyValid(Node* n, Node* movePoint); + + bool couldMoveAfterTopologically(Node* n, Node* movePoint); + bool couldMoveBeforeTopologically(Node* n, Node* movePoint); // For debugging: print alias db state to stdout void dump() const; private: + // Helper for topologically-safe node moves. + class WorkingSet; + enum class MoveSide { BEFORE, AFTER }; + bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun); + void move(Node* toMove, Node* movePoint, MoveSide moveSide); + bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const; + + // Does `n` use or write to any wildcard aliases? + bool hasWildcard(const Node* n) const; + + // Does `n` write to a value that may alias one of the graph inputs?
+ bool writesToInputAlias(Node* n) const; + void analyze(const std::shared_ptr& graph); void analyze(Block* block); void analyze(Node* node); @@ -73,6 +102,7 @@ class AliasDb { Symbol getFreshAlias(bool isGraphInput = false); void addAlias(const Value* value, AliasInfo alias); + void addAlias(const Value* value, Symbol alias); void addAlias(const Value* value, const Value* from); void mapAliases(at::ArrayRef to, at::ArrayRef from); diff --git a/torch/csrc/jit/passes/batch_mm.cpp b/torch/csrc/jit/passes/batch_mm.cpp index 2c99d89..3f1693b 100644 --- a/torch/csrc/jit/passes/batch_mm.cpp +++ b/torch/csrc/jit/passes/batch_mm.cpp @@ -329,7 +329,7 @@ RegisterOperators mm_batch_side_reg( std::pair, std::vector> gatherIndependentMMUses( Value* value, - const AliasDb& alias_db) { + AliasDb& alias_db) { const auto postprocess = [&](std::vector mms) { if (mms.size() == 0) { return mms; @@ -346,7 +346,7 @@ std::pair, std::vector> gatherIndependentMMUses( for (size_t j = i + 1; j < mms.size(); ++j) { if (mms[j] == nullptr) continue; - if (!mms[j]->couldMoveBeforeTopologically(mms[i], alias_db)) { + if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) { mms[j] = nullptr; } } @@ -370,13 +370,13 @@ std::pair, std::vector> gatherIndependentMMUses( return std::make_pair(postprocess(lhses), postprocess(rhses)); } -void BatchMMSide(Block* block, const AliasDb& alias_db) { +void BatchMMSide(Block* block, AliasDb& alias_db) { // NB: 8 is the current loop unrolling factor static constexpr size_t how_many_is_many = 8; const auto batch_side = [&](std::vector& mms, Side side) { JIT_ASSERT(!mms.empty()); for (int64_t i = static_cast(mms.size()) - 2; i >= 0; --i) { - bool move_ok = mms[i]->moveBeforeTopologicallyValid(mms[i + 1], alias_db); + bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]); JIT_ASSERT(move_ok); } WithInsertPoint insert_guard{mms[0]}; @@ -435,7 +435,7 @@ void BatchMM(std::shared_ptr& graph) { // TODO(suo): make BatchMM mutability-safe return; } 
- const auto alias_db = AliasAnalysis(graph); + auto alias_db = AliasAnalysis(graph); BatchMMTreeReduce(graph->block()); BatchMMSide(graph->block(), alias_db); EliminateDeadCode(graph); diff --git a/torch/csrc/jit/passes/common_subexpression_elimination.cpp b/torch/csrc/jit/passes/common_subexpression_elimination.cpp index 7a87c1f..f2be647 100644 --- a/torch/csrc/jit/passes/common_subexpression_elimination.cpp +++ b/torch/csrc/jit/passes/common_subexpression_elimination.cpp @@ -24,7 +24,7 @@ void EliminateCommonSubexpression( for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { auto node = *it; if (node->hasSideEffects() || node->isNondeterministic() || - aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) { + aliasDb.hasWriters(node)) { // Do NOT have enough information to do CSE on these nodes. continue; } diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index 1fa33f6..b8d6f2b 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -123,7 +123,7 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) { }); bool supported_node = !n->kind().is_onnx() && skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && - !n->hasSideEffects() && !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n); + !n->hasSideEffects() && !aliasDb.hasWriters(n); auto run_blocks = [&]() { if (recurse) { for (Block* block : n->blocks()) { diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 198323a..292a491 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -123,7 +123,7 @@ class SubgraphSlicer { std::pair scanNode( Node* consumer, - const AliasDb& aliasDb) { + AliasDb& aliasDb) { if (shouldConsiderForMerge(consumer)) { if (consumer->kind() != prim::DifferentiableGraph) { consumer = 
SubgraphUtils::createSingletonSubgraph( @@ -147,10 +147,10 @@ class SubgraphSlicer { c10::optional tryMerge( Node* consumer, Node* producer, - const AliasDb& aliasDb) { + AliasDb& aliasDb) { JIT_ASSERT(consumer->kind() == prim::DifferentiableGraph); bool canMerge = shouldConsiderForMerge(producer) && - producer->moveBeforeTopologicallyValid(consumer, aliasDb); + aliasDb.moveBeforeTopologicallyValid(producer, consumer); if (!canMerge) { return c10::nullopt; diff --git a/torch/csrc/jit/passes/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp index 3389dce..d43694b 100644 --- a/torch/csrc/jit/passes/dead_code_elimination.cpp +++ b/torch/csrc/jit/passes/dead_code_elimination.cpp @@ -18,11 +18,6 @@ class DeadCodeEliminator { // we mark "live" nodes that are necessary for the output. Nodes that have // side effects are also marked. void run(Block* block, bool recurse) { - // Find the last wildcard in the block. We cannot eliminate any mutable ops - // that precede the last wildcard (since they may have written to the - // wildcard alias set) - setLastWildcard(); - // Initialize by marking the return node and all its consumed values as live mark(block->return_node()); @@ -40,24 +35,6 @@ class DeadCodeEliminator { } private: - void setLastWildcard() { - if (!aliasDb_) { - return; - } - - const auto& wildcards = aliasDb_->getWildcardNodes(); - if (wildcards.empty()) { - return; - } - - lastWildcard_ = *wildcards.begin(); - for (const auto wildcard : wildcards) { - if (wildcard->isAfter(*lastWildcard_)) { - lastWildcard_ = wildcard; - } - } - } - // Special handling for block return nodes. Unlike other nodes, the block // return node doesn't really "use" its inputs. Consider: // @@ -227,16 +204,7 @@ class DeadCodeEliminator { auto schema = node->maybeSchema(); return schema && schema->is_mutable(); } else { - // Otherwise, there are two kinds of nodes with untracked effects: - // 1. 
Nodes that write to a value that may alias the graph inputs (since - // the inputs can be used outside the graph). - // 2. Anything that could clobber a wildcard value. - bool touchesWildcard = false; - if (lastWildcard_) { - touchesWildcard = aliasDb_->hasWrites(node) && - (node->isBefore(*lastWildcard_) || node == *lastWildcard_); - } - return aliasDb_->writesToInputAlias(node) || touchesWildcard; + return aliasDb_->hasUntrackedEffects(node); } } @@ -300,7 +268,6 @@ class DeadCodeEliminator { std::unordered_set marked_; std::unordered_set liveValues_; std::unordered_set liveAliases_; - c10::optional lastWildcard_; std::function&)> deleteCallback_ = [](const std::unordered_set&) {}; }; diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index fd7e3e1..d59ce74 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -439,8 +439,7 @@ struct GraphFuser { // consumer. Fusion will rewrite those later uses to use the version of // producer generated by the fused blob. In this case, producer becomes // an output of the fusion group. 
- producer->node()->moveBeforeTopologicallyValid( - consumer, aliasDb_.value()); + aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer); if (!shouldFuse) { return at::nullopt; @@ -1012,8 +1011,8 @@ struct GraphFuser { } // NB: it is important that this check happens after isFusable, which checks // that the blocks match, and it's not a special node like prim::Param - if (!producer->node()->couldMoveBeforeTopologically( - before_check, aliasDb_.value())) { + if (!aliasDb_->couldMoveBeforeTopologically( + producer->node(), before_check)) { return false; } // Fusion groups can be merged with concat's group if and only if diff --git a/torch/csrc/jit/passes/shape_analysis.cpp b/torch/csrc/jit/passes/shape_analysis.cpp index 69ed1ef..4a0fa77 100644 --- a/torch/csrc/jit/passes/shape_analysis.cpp +++ b/torch/csrc/jit/passes/shape_analysis.cpp @@ -204,12 +204,7 @@ class ShapePropagator { return dependsOnMutationMemo_[node]; } - const auto writers = aliasDb_.getWriters(node); - const auto hasWritersBefore = - std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) { - return writer->isBefore(node); - }); - if (hasWritersBefore || aliasDb_.hasWildcard(node)) { + if (aliasDb_.hasWritersBefore(node)) { // If something could have written to a value used by this node, we can't // guarantee the result is the same when running it in isolation. dependsOnMutationMemo_[node] = true; -- 2.7.4