const std::string& insertPoint) {
std::function<bool(Node*, Node*)> func =
[this](Node* toInsert, Node* insertPoint) {
- return toInsert->moveBeforeTopologicallyValid(insertPoint, *aliasDb);
+ return aliasDb->moveBeforeTopologicallyValid(toInsert, insertPoint);
};
return moveWithChecks(toInsert, insertPoint, func);
}
const std::string& insertPoint) {
std::function<bool(Node*, Node*)> func =
[this](Node* toInsert, Node* insertPoint) {
- return toInsert->moveAfterTopologicallyValid(insertPoint, *aliasDb);
+ return aliasDb->moveAfterTopologicallyValid(toInsert, insertPoint);
};
return moveWithChecks(toInsert, insertPoint, func);
}
graph->lint();
- const auto aliasDb = AliasAnalysis(graph);
+ auto aliasDb = AliasAnalysis(graph);
// Can't move past a mutation of a used value
- JIT_ASSERT(!c->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb));
- JIT_ASSERT(d->node()->moveAfterTopologicallyValid(c->node(), aliasDb));
+ JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid(c->node(), aMut->node()));
+ JIT_ASSERT(aliasDb.moveAfterTopologicallyValid(d->node(), c->node()));
// b should alias to a (since they are both inputs)
JIT_ASSERT(
- !addsB->node()->moveAfterTopologicallyValid(aMut->node(), aliasDb));
- JIT_ASSERT(addsB->node()->moveAfterTopologicallyValid(c->node(), aliasDb));
+ !aliasDb.moveAfterTopologicallyValid(addsB->node(), aMut->node()));
+ JIT_ASSERT(aliasDb.moveAfterTopologicallyValid(addsB->node(), c->node()));
graph->lint();
}
auto c = graph->insert(aten::add, {fresh, aliasesB});
graph->lint();
- const auto aliasDb = AliasAnalysis(graph);
-
- JIT_ASSERT(!aliasesB->node()->moveAfterTopologicallyValid(
- mutatesAliasOfB->node(), aliasDb));
- JIT_ASSERT(!usesB->node()->moveAfterTopologicallyValid(
- mutatesAliasOfB->node(), aliasDb));
+ auto aliasDb = AliasAnalysis(graph);
+ JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
+ aliasesB->node(), mutatesAliasOfB->node()));
+ JIT_ASSERT(!aliasDb.moveAfterTopologicallyValid(
+ usesB->node(), mutatesAliasOfB->node()));
}
}
} // namespace
return this;
}
-bool Node::moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb) {
- return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/false);
-}
-
-bool Node::couldMoveAfterTopologically(Node* n, const AliasDb& aliasDb) {
- return tryMove(n, MoveSide::AFTER, aliasDb, /*dryRun=*/true);
-}
-
-bool Node::moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb) {
- // We have to distinguish the move side (instead of just moving after
- // n->prev()). Consider the following example:
- // If the dependency graph looks like this -> n -> o then moveBefore(o) will
- // end up with [this, o, n], but moveAfter(n) will return false.
- return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/false);
-}
-
-bool Node::couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb) {
- return tryMove(n, MoveSide::BEFORE, aliasDb, /*dryRun=*/true);
-}
-
-// Helper for topologically-safe node moves. See `tryMove()` for details.
-namespace {
-struct WorkingSet {
- public:
- explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
- add(mover);
- }
-
- // Add `n` to the working set
- void add(Node* n) {
- nodes_.push_back(n);
- for (const auto user : getUsersSameBlock(n)) {
- users_[user]++;
- }
-
- for (const auto& writer : getWritersSameBlock(n)) {
- writers_[writer]++;
- }
- if (aliasDb_.hasWildcard(n)) {
- numWildcards_++;
- }
- if (aliasDb_.hasWrites(n)) {
- numWriterNodes_++;
- }
- }
-
- void eraseMover() {
- auto mover = nodes_.front();
- for (const auto user : getUsersSameBlock(mover)) {
- // If this user node only uses the mover, we can remove it
- if (users_[user] == 1) {
- users_.erase(user);
- }
- }
-
- for (const auto& writer : getWritersSameBlock(mover)) {
- if (writers_[writer] == 1) {
- writers_.erase(writer);
- }
- }
- if (aliasDb_.hasWildcard(mover)) {
- numWildcards_--;
- }
- if (aliasDb_.hasWrites(mover)) {
- numWriterNodes_--;
- }
- nodes_.pop_front();
- }
-
- const std::list<Node*>& nodes() {
- return nodes_;
- }
-
- // Does the working set depend on `n`?
- bool dependsOn(Node* n) const {
- if (nodes_.empty()) {
- return false;
- }
-
- return hasDataDependency(n) || hasMutabilityDependency(n);
- }
-
- private:
- bool hasDataDependency(Node* n) const {
- if (n->isAfter(nodes_.front())) {
- return producesFor(n);
- } else {
- return consumesFrom(n);
- }
- }
-
- bool hasMutabilityDependency(Node* n) const {
- // 1. Handle wildcard dependencies:
- // If the working set has a wildcard, `n` can't write to anything.
- if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) {
- return true;
- }
-
- // If `n` has a wildcard, the working set can't write to anything.
- if (aliasDb_.hasWildcard(n) && numWriterNodes_ > 0) {
- return true;
- }
-
- // 2. Handle regular mutable dependencies
- // Check that this node does not write to anything used by the working set
- if (writers_.count(n) != 0) {
- return true;
- }
-
- // Check that the working set does not write to anything used by this node
- const auto writersToNode = getWritersSameBlock(n);
- return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
- return writersToNode.count(node) != 0;
- });
- }
-
- // Does the working set produce any values consumed by `n`?
- bool producesFor(Node* n) const {
- // This equivalent to asking: does the total use-set of all the nodes in the
- // working set include `n`?
- return users_.count(n) != 0;
- }
-
- // Does the working set consume any values produced by `n`?
- bool consumesFrom(Node* n) const {
- const auto users = getUsersSameBlock(n);
- return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
- return users.count(node) != 0;
- });
- }
-
- // Get all users of outputs of `n`, in the same block as `n`.
- // This means if there is an `if` node that uses an output of `n` in some
- // inner sub-block, we will consider the whole `if` node a user of `n`.
- std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
- std::unordered_set<Node*> users;
- for (const auto output : n->outputs()) {
- for (const auto& use : output->uses()) {
- if (auto sameBlock = findSameBlock(use.user, n)) {
- users.insert(sameBlock);
- }
- }
- }
- return users;
- }
-
- std::unordered_set<Node*> getWritersSameBlock(Node* n) const {
- std::unordered_set<Node*> writers;
- for (const auto writer : aliasDb_.getWriters(n)) {
- if (auto sameBlock = findSameBlock(writer, n)) {
- writers.insert(sameBlock);
- }
- }
- return writers;
- }
-
- // Traverse `target`'s blockchain upward until we find a node that shares a
- // block with `n`.
- //
- // If one can't be found (say, because `n` is an inner block and target is
- // outside), then return nullptr. Since we can only reorder nodes within a
- // block, `target` would be irrelevant.
- static Node* findSameBlock(Node* target, Node* n) {
- if (target->owningBlock() == n->owningBlock()) {
- return target;
- } else {
- // This user is in a sub-block. Traverse the blockchain upward until
- // we arrive at a node that shares a block with `this`
- auto curNode = target;
- while (curNode->owningBlock() != n->owningBlock()) {
- curNode = curNode->owningBlock()->owningNode();
- if (curNode == nullptr) {
- return curNode;
- }
- }
- return curNode;
- }
- }
-
- const AliasDb& aliasDb_;
- std::list<Node*> nodes_;
- // users => # of working set nodes it uses
- std::unordered_map<Node*, size_t> users_;
- std::unordered_map<Node*, size_t> writers_;
- size_t numWildcards_ = 0;
- size_t numWriterNodes_ = 0;
-};
-} // namespace
-
-// Try to move `this` before/after `movePoint` while preserving value
-// dependencies. Returns false iff such a move could not be made
-//
-// The basic approach is: have a "working set" that we are moving forward, one
-// node at a time. When we can't move past a node (because it depends on the
-// working set), then add it to the working set and keep moving until we hit
-// `moveAfter`.
-bool Node::tryMove(
- Node* movePoint,
- MoveSide moveSide,
- const AliasDb& aliasDb,
- bool dryRun) {
- JIT_ASSERT(this->inBlockList() && movePoint->inBlockList());
- JIT_ASSERT(this->owningBlock() == movePoint->owningBlock());
- if (this == movePoint) {
- return true;
- }
-
- // 1. Move from `this` toward movePoint, building up the working set of
- // dependencies
- WorkingSet workingSet(this, aliasDb);
-
- int direction;
- if (this->isAfter(movePoint)) {
- direction = kPrevDirection;
- } else {
- direction = kNextDirection;
- }
-
- auto curNode = this->next_in_graph[direction];
- // Move forward one node at a time
- while (curNode != movePoint) {
- if (workingSet.dependsOn(curNode)) {
- // If we can't move past this node, add it to the working set
- workingSet.add(curNode);
- }
- curNode = curNode->next_in_graph[direction];
- }
-
- // 2. Decide whether we can move it all to `movePoint`.
-
- // Say we are moving directly before movePoint and `this` starts before
- // movePoint in the graph. The move looks like
- //
- // `this` `this` |
- // <dependencies> -> `movePoint` | `this` and deps are split
- // `movePoint` <dependencies> |
- //
- // Contrast with the case where `this` starts AFTER movePoint:
- //
- // `movePoint` <dependencies> |
- // <dependencies> -> `this` | `this` and deps are together
- // `this` `movePoint` |
- //
- // In the first case, we need to split `this` off from its dependencies, so we
- // can move the dependencies below `movePoint` and keep `this` above.
- const bool splitThisAndDeps =
- (moveSide == MoveSide::BEFORE && this->isBefore(movePoint)) ||
- (moveSide == MoveSide::AFTER && this->isAfter(movePoint));
-
- if (splitThisAndDeps) {
- // remove `this` from dependencies to be moved past `movePoint`
- workingSet.eraseMover();
- }
-
- // Check if we can move the working set past the move point
- if (workingSet.dependsOn(movePoint)) {
- // if we can't, then there are intermediate dependencies between the
- // `this` and `movePoint`, so we can't do the move
- return false;
- }
-
- if (dryRun) {
- return true;
- }
-
- // 3. Execute the move
- JIT_ASSERT(curNode == movePoint);
- if (splitThisAndDeps) {
- // Move `this`
- this->move(movePoint, moveSide);
-
- // Then move all of its dependencies on the other side of `movePoint`
- const auto reversed =
- moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
- for (auto toMove : workingSet.nodes()) {
- toMove->move(curNode, reversed);
- curNode = toMove;
- }
- } else {
- // Just append/prepend everything to `movePoint`
- for (auto toMove : workingSet.nodes()) {
- toMove->move(curNode, moveSide);
- curNode = toMove;
- }
- }
- return true;
-}
-
-// Helper function so we can generalize `tryMove`
-void Node::move(Node* movePoint, MoveSide moveSide) {
- switch (moveSide) {
- case MoveSide::BEFORE:
- this->moveBefore(movePoint);
- break;
- case MoveSide::AFTER:
- this->moveAfter(movePoint);
- break;
- }
-}
-
void Node::moveAfter(Node* n) {
removeFromList();
insertAfter(n);
friend const_graph_node_list_iterator;
private:
- // each node but Return/Param
- // is associated with exactly one place in the node list...
- // of the graph_
- // this circular is a doubly-linked list. The Return node is used as the
- // sentinel for the beginning and end of the list such that the list never has
- // null pointers.
- // - next_in_graph[0] is next pointer
- // - next_in_graph[1] is prev pointer
- //
- // Using an array to allow the same iterator class for forward and
- // reverse node lists
- //
- // This list represents a topological sort
-
- Node* next_in_graph[2] = {nullptr, nullptr};
-
const NodeKind kind_;
std::vector<Value*> inputs_;
std::vector<Value*> outputs_;
protected:
TORCH_API Node(Graph* graph_, NodeKind kind_); // defined after graph
public:
+ // each node but Return/Param
+ // is associated with exactly one place in the node list...
+ // of the graph_
+ // this circular is a doubly-linked list, the Return node is used as the
+ // sentinel for the beginning and end of the list such that the list never has
+ // null pointers next_in_graph[0] is next pointer next_in_graph[1] is prev
+ // pointer using an array to allow the same iterator class for forward and
+ // reverse node lists This list represents a topological sort
+ Node* next_in_graph[2] = {nullptr, nullptr};
+
Node*& next() {
return next_in_graph[kNextDirection];
}
// Move 'this' (already in the graph) after 'n' in the topological order.
//
// NOTE: Does not check that value dependencies are preserved, see
- // moveAfterTopologicallyValid
+ // AliasDb::moveAfterTopologicallyValid
//
// Given: %2 = f(%1)
// %3 = g(%1)
//
TORCH_API void moveAfter(Node* n);
- // Move 'this' (already in the graph) after 'n' in the topological order.
- //
- // Tries to preserve value dependencies, so other nodes might be moved. We
- // make two gurantees about the postcondition of the node list:
- // - `this` is directly after `n`.
- // - only nodes between `this` and `n` have been moved
- //
- // Returns `false` if it's impossible to move `this` after `n` without
- // violating dependencies, otherwise executes the move and returns `true`
- TORCH_API bool moveAfterTopologicallyValid(Node* n, const AliasDb& aliasDb);
-
- // Like moveAfterTopologicallyValid, but only returns if the move is
- // possible, without actually performing it.
- TORCH_API bool couldMoveAfterTopologically(Node* n, const AliasDb& aliasdb);
-
// Move a node 'n' (already in the graph) before 'this' in the topological
// order.
//
// NOTE: Does not check that value dependencies are preserved, see
- // moveBeforeTopologicallyValid
+ // AliasDb::moveBeforeTopologicallyValid
//
// Given: %2 = f(%1)
// %3 = g(%1)
// %2 = f(%1)
TORCH_API void moveBefore(Node* n);
- // Move 'this' (already in the graph) before 'n' in the topological order.
- //
- // Tries to preserve value dependencies, so other nodes might be moved. We
- // make two gurantees about the postcondition of the node list:
- // - `this` is directly before `n`.
- // - only nodes between `this` and `n` have been moved
- //
- // Returns `false` if it's impossible to move `this` after `n` without
- // violating dependencies, otherwise executes the move and returns `true`
- TORCH_API bool moveBeforeTopologicallyValid(Node* n, const AliasDb& aliasDb);
-
- // Like moveBeforeTopologicallyValid, but only returns if the move is
- // possible, without actually performing it.
- TORCH_API bool couldMoveBeforeTopologically(Node* n, const AliasDb& aliasDb);
-
// Remove the input at 'i' from this node.
//
// WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
private:
enum class MoveSide { BEFORE, AFTER };
- bool tryMove(
- Node* movePoint,
- MoveSide moveSide,
- const AliasDb& aliasDb,
- bool dryRun);
- void move(Node* movePoint, MoveSide moveSide);
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
std::pair<Value*, const Argument&> findInput(Symbol name);
return writers.count(n) != 0;
}
+bool AliasDb::hasWriters(const Node* n) const {
+ if (hasWildcard(n)) {
+ // If `n` has a wildcard, any write in the graph may write to it.
+ // So the only way we know there are no writers is if there are no writes
+ // at all.
+ return !aliasToWrites_.empty();
+ }
+ return getWriters(n).size() != 0;
+}
+
+bool AliasDb::hasWritersBefore(const Node* n) const {
+ if (hasWildcard(n)) {
+ return true;
+ }
+ const auto writers = getWriters(n);
+ return std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
+ return writer->isBefore(n);
+ });
+}
+
bool AliasDb::hasWrites(Node* n) const {
for (const auto input : n->inputs()) {
if (writesTo(n, input)) {
}
addAlias(value, getFreshAlias());
}
+
+bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
+ return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
+}
+
+bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) {
+ return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/true);
+}
+
+bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) {
+ // We have to distinguish the move side (instead of just moving after
+ // n->prev()). Consider the following example:
+ // If the dependency graph looks like
+ // n -> movePoint -> o
+ // then moveBefore(o) will end up with
+ // n, o, movePoint
+ // but moveAfter(n) will return false.
+ return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/false);
+}
+
+bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) {
+ return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/true);
+}
+
+// Helper for topologically-safe node moves. See `tryMove()` for details.
+class AliasDb::WorkingSet {
+ public:
+ explicit WorkingSet(Node* mover, const AliasDb& aliasDb) : aliasDb_(aliasDb) {
+ add(mover);
+ }
+
+ // Add `n` to the working set
+ void add(Node* n) {
+ nodes_.push_back(n);
+ for (const auto user : getUsersSameBlock(n)) {
+ users_[user]++;
+ }
+
+ for (const auto& writer : getWritersSameBlock(n)) {
+ writers_[writer]++;
+ }
+ if (aliasDb_.hasWildcard(n)) {
+ numWildcards_++;
+ }
+ if (aliasDb_.hasWrites(n)) {
+ numWriterNodes_++;
+ }
+ }
+
+ void eraseMover() {
+ auto mover = nodes_.front();
+ for (const auto user : getUsersSameBlock(mover)) {
+ // If this user node only uses the mover, we can remove it
+ if (users_[user] == 1) {
+ users_.erase(user);
+ }
+ }
+
+ for (const auto& writer : getWritersSameBlock(mover)) {
+ if (writers_[writer] == 1) {
+ writers_.erase(writer);
+ }
+ }
+ if (aliasDb_.hasWildcard(mover)) {
+ numWildcards_--;
+ }
+ if (aliasDb_.hasWrites(mover)) {
+ numWriterNodes_--;
+ }
+ nodes_.pop_front();
+ }
+
+ const std::list<Node*>& nodes() {
+ return nodes_;
+ }
+
+ // Does the working set depend on `n`?
+ bool dependsOn(Node* n) const {
+ if (nodes_.empty()) {
+ return false;
+ }
+
+ return hasDataDependency(n) || hasMutabilityDependency(n);
+ }
+
+ private:
+ bool hasDataDependency(Node* n) const {
+ if (n->isAfter(nodes_.front())) {
+ return producesFor(n);
+ } else {
+ return consumesFrom(n);
+ }
+ }
+
+ bool hasMutabilityDependency(Node* n) const {
+ // 1. Handle wildcard dependencies:
+ // If the working set has a wildcard, `n` can't write to anything.
+ if (numWildcards_ > 0 && aliasDb_.hasWrites(n)) {
+ return true;
+ }
+
+ // If `n` has a wildcard, the working set can't write to anything.
+ if (aliasDb_.hasWildcard(n) && numWriterNodes_ > 0) {
+ return true;
+ }
+
+ // 2. Handle regular mutable dependencies
+ // Check that this node does not write to anything used by the working set
+ if (writers_.count(n) != 0) {
+ return true;
+ }
+
+ // Check that the working set does not write to anything used by this node
+ const auto writersToNode = getWritersSameBlock(n);
+ return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
+ return writersToNode.count(node) != 0;
+ });
+ }
+
+ // Does the working set produce any values consumed by `n`?
+ bool producesFor(Node* n) const {
+ // This equivalent to asking: does the total use-set of all the nodes in the
+ // working set include `n`?
+ return users_.count(n) != 0;
+ }
+
+ // Does the working set consume any values produced by `n`?
+ bool consumesFrom(Node* n) const {
+ const auto users = getUsersSameBlock(n);
+ return std::any_of(nodes_.begin(), nodes_.end(), [&](Node* node) {
+ return users.count(node) != 0;
+ });
+ }
+
+ // Get all users of outputs of `n`, in the same block as `n`.
+ // This means if there is an `if` node that uses an output of `n` in some
+ // inner sub-block, we will consider the whole `if` node a user of `n`.
+ std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
+ std::unordered_set<Node*> users;
+ for (const auto output : n->outputs()) {
+ for (const auto& use : output->uses()) {
+ if (auto sameBlock = findSameBlock(use.user, n)) {
+ users.insert(sameBlock);
+ }
+ }
+ }
+ return users;
+ }
+
+ std::unordered_set<Node*> getWritersSameBlock(Node* n) const {
+ std::unordered_set<Node*> writers;
+ for (const auto writer : aliasDb_.getWriters(n)) {
+ if (auto sameBlock = findSameBlock(writer, n)) {
+ writers.insert(sameBlock);
+ }
+ }
+ return writers;
+ }
+
+ // Traverse `target`'s blockchain upward until we find a node that shares a
+ // block with `n`.
+ //
+ // If one can't be found (say, because `n` is an inner block and target is
+ // outside), then return nullptr. Since we can only reorder nodes within a
+ // block, `target` would be irrelevant.
+ static Node* findSameBlock(Node* target, Node* n) {
+ if (target->owningBlock() == n->owningBlock()) {
+ return target;
+ } else {
+ // This user is in a sub-block. Traverse the blockchain upward until
+ // we arrive at a node that shares a block with `this`
+ auto curNode = target;
+ while (curNode->owningBlock() != n->owningBlock()) {
+ curNode = curNode->owningBlock()->owningNode();
+ if (curNode == nullptr) {
+ return curNode;
+ }
+ }
+ return curNode;
+ }
+ }
+
+ const AliasDb& aliasDb_;
+ std::list<Node*> nodes_;
+ // users => # of working set nodes it uses
+ std::unordered_map<Node*, size_t> users_;
+ std::unordered_map<Node*, size_t> writers_;
+ size_t numWildcards_ = 0;
+ size_t numWriterNodes_ = 0;
+};
+
+// Try to move `toMove` before/after `movePoint` while preserving value
+// dependencies. Returns false iff such a move could not be made.
+//
+// If `dryRun` is set, don't actually execute the move, just check if the move
+// is possible
+//
+// The basic approach is: have a "working set" that we are moving forward, one
+// node at a time. When we can't move past a node (because it depends on the
+// working set), then add it to the working set and keep moving until we hit
+// `moveAfter`.
+bool AliasDb::tryMove(
+ Node* toMove,
+ Node* movePoint,
+ MoveSide moveSide,
+ bool dryRun) {
+ JIT_ASSERT(toMove->owningBlock() == movePoint->owningBlock());
+ if (toMove == movePoint) {
+ return true;
+ }
+
+ // 1. Move from `this` toward movePoint, building up the working set of
+ // dependencies
+ WorkingSet workingSet(toMove, *this);
+
+ int direction;
+ if (toMove->isAfter(movePoint)) {
+ direction = kPrevDirection;
+ } else {
+ direction = kNextDirection;
+ }
+
+ auto curNode = toMove->next_in_graph[direction];
+ // Move forward one node at a time
+ while (curNode != movePoint) {
+ if (workingSet.dependsOn(curNode)) {
+ // If we can't move past this node, add it to the working set
+ workingSet.add(curNode);
+ }
+ curNode = curNode->next_in_graph[direction];
+ }
+
+ // 2. Decide whether we can move it all to `movePoint`.
+
+ // Say we are moving directly before movePoint and `toMove` starts before
+ // movePoint in the graph. The move looks like
+ //
+ // `toMove` `toMove` |
+ // <dependencies> -> `movePoint` | `toMove` and deps are split
+ // `movePoint` <dependencies> |
+ //
+ // Contrast with the case where `toMove` starts AFTER movePoint:
+ //
+ // `movePoint` <dependencies> |
+ // <dependencies> -> `toMove` | `toMove` and deps are together
+ // `toMove` `movePoint` |
+ //
+ // In the first case, we need to split `this` off from its dependencies, so we
+ // can move the dependencies below `movePoint` and keep `toMove` above.
+ const bool splitToMoveAndDeps =
+ (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
+ (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));
+
+ if (splitToMoveAndDeps) {
+ // remove `this` from dependencies to be moved past `movePoint`
+ workingSet.eraseMover();
+ }
+
+ // Check if we can move the working set past the move point
+ if (workingSet.dependsOn(movePoint)) {
+ // if we can't, then there are intermediate dependencies between the
+ // `this` and `movePoint`, so we can't do the move
+ return false;
+ }
+
+ if (dryRun) {
+ return true;
+ }
+
+ // 3. Execute the move
+ JIT_ASSERT(curNode == movePoint);
+ if (splitToMoveAndDeps) {
+ // Move `toMove`
+ move(toMove, movePoint, moveSide);
+
+ // Then move all of its dependencies on the other side of `movePoint`
+ const auto reversed =
+ moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
+ for (auto n : workingSet.nodes()) {
+ move(n, curNode, reversed);
+ curNode = n;
+ }
+ } else {
+ // Just append/prepend everything to `movePoint`
+ for (auto n : workingSet.nodes()) {
+ move(n, curNode, moveSide);
+ curNode = n;
+ }
+ }
+ return true;
+}
+
+// Helper function so we can generalize `tryMove`
+void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
+ switch (moveSide) {
+ case MoveSide::BEFORE:
+ toMove->moveBefore(movePoint);
+ break;
+ case MoveSide::AFTER:
+ toMove->moveAfter(movePoint);
+ break;
+ }
+}
+
+bool AliasDb::hasUntrackedEffects(Node* node) const {
+ bool touchesWildcard = false;
+ if (!wildcardNodes_.empty()) {
+ auto lastWildcard = *wildcardNodes_.begin();
+ for (const auto wildcard : wildcardNodes_) {
+ if (wildcard->isAfter(lastWildcard)) {
+ lastWildcard = wildcard;
+ }
+ }
+ touchesWildcard = hasWrites(node) &&
+ (node->isBefore(lastWildcard) || node == lastWildcard);
+ }
+
+ return writesToInputAlias(node) || touchesWildcard;
+}
} // namespace jit
} // namespace torch
* Alias analysis pass.
*
* This pass produces an AliasDb that contains aliasing and mutation
- * information about the graph. Callers (right now moveAfterTopologicallyValid)
- * can use this information to determine whether mutations to the graph are
- * safe, in that they don't reorder/change nodes in a way that affects output.
+ * information about the graph. Users can use this information to determine
+ * whether mutations to the graph are safe, i.e. they don't reorder/change
+ * nodes in a way that affects output.
*
* Every value with a mutable type (Tensors, Lists, Tuples, etc.) will be
* associated with one or more "alias sets". If two values share an alias set,
public:
explicit AliasDb(std::shared_ptr<Graph> graph);
- // Does `n` use or write to any wildcard aliases?
- bool hasWildcard(const Node* n) const;
-
- const std::unordered_set<const Node*>& getWildcardNodes() const {
- return wildcardNodes_;
- }
-
// Does `n` write to any alias sets?
bool hasWrites(Node* n) const;
- // Does `n` write to a value that may alias one of the graph inputs?
- bool writesToInputAlias(Node* n) const;
+ // There are limitations to what effects the alias analysis can track. Two
+ // kinds of nodes may have untracked effects:
+ // 1. Nodes that write to a value that may alias the graph inputs (since
+ // the inputs can be used outside the graph).
+ // 2. Nodes that write to something in the wildcard set.
+ //
+ // These nodes are considered not safe to eliminate or mutate under any
+ // circumstances.
+ bool hasUntrackedEffects(Node* n) const;
// Get all nodes that write to any alias set inputed/outputed by `n`
std::unordered_set<Node*> getWriters(const Node* n) const;
// Get all values that may alias to `v`.
std::unordered_set<const Value*> getAliases(const Value* v) const;
- // Do any nodes write to an alias set inputed/outputed by `n`?
- bool hasWriters(const Node* n) const {
- return getWriters(n).size() != 0;
- }
+ // Do any nodes write to an alias set inputed/outputed by `n`?
+ bool hasWriters(const Node* n) const;
+
+ // Same as hasWriters() but ignores writes after `n`.
+ bool hasWritersBefore(const Node* n) const;
+
+ // Move 'n' (already in the graph) after 'movePoint' in the topological order.
+ //
+ // Tries to preserve value dependencies, so other nodes might be moved. We
+ // make two gurantees about the postcondition of the node list:
+ // - `n` is directly after `movePoint`.
+ // - only nodes between `n` and `movePoint` have been moved.
+ //
+ // Returns `false` if it's impossible to move `n` after `MovePoint` without
+ // violating dependencies, otherwise executes the move and returns `true`
+ bool moveAfterTopologicallyValid(Node* n, Node* movePoint);
+ bool moveBeforeTopologicallyValid(Node* n, Node* movePoint);
+
+ bool couldMoveAfterTopologically(Node* n, Node* movePoint);
+ bool couldMoveBeforeTopologically(Node* n, Node* movePoint);
// For debugging: print alias db state to stdout
void dump() const;
private:
+ // Helper for topologically-safe node moves.
+ class WorkingSet;
+ enum class MoveSide { BEFORE, AFTER };
+ bool tryMove(Node* toMove, Node* movePoint, MoveSide moveSide, bool dryRun);
+ void move(Node* toMove, Node* movePoint, MoveSide moveSide);
+ bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
+
+ // Does `n` use or write to any wildcard aliases?
+ bool hasWildcard(const Node* n) const;
+
+ // Does `n` write to a value that may alias one of the graph inputs?
+ bool writesToInputAlias(Node* n) const;
+
void analyze(const std::shared_ptr<Graph>& graph);
void analyze(Block* block);
void analyze(Node* node);
Symbol getFreshAlias(bool isGraphInput = false);
void addAlias(const Value* value, AliasInfo alias);
+
void addAlias(const Value* value, Symbol alias);
void addAlias(const Value* value, const Value* from);
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
std::pair<std::vector<Node*>, std::vector<Node*>> gatherIndependentMMUses(
Value* value,
- const AliasDb& alias_db) {
+ AliasDb& alias_db) {
const auto postprocess = [&](std::vector<Node*> mms) {
if (mms.size() == 0) {
return mms;
for (size_t j = i + 1; j < mms.size(); ++j) {
if (mms[j] == nullptr)
continue;
- if (!mms[j]->couldMoveBeforeTopologically(mms[i], alias_db)) {
+ if (!alias_db.couldMoveBeforeTopologically(mms[j], mms[i])) {
mms[j] = nullptr;
}
}
return std::make_pair(postprocess(lhses), postprocess(rhses));
}
-void BatchMMSide(Block* block, const AliasDb& alias_db) {
+void BatchMMSide(Block* block, AliasDb& alias_db) {
// NB: 8 is the current loop unrolling factor
static constexpr size_t how_many_is_many = 8;
const auto batch_side = [&](std::vector<Node*>& mms, Side side) {
JIT_ASSERT(!mms.empty());
for (int64_t i = static_cast<int64_t>(mms.size()) - 2; i >= 0; --i) {
- bool move_ok = mms[i]->moveBeforeTopologicallyValid(mms[i + 1], alias_db);
+ bool move_ok = alias_db.moveBeforeTopologicallyValid(mms[i], mms[i + 1]);
JIT_ASSERT(move_ok);
}
WithInsertPoint insert_guard{mms[0]};
// TODO(suo): make BatchMM mutability-safe
return;
}
- const auto alias_db = AliasAnalysis(graph);
+ auto alias_db = AliasAnalysis(graph);
BatchMMTreeReduce(graph->block());
BatchMMSide(graph->block(), alias_db);
EliminateDeadCode(graph);
for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
auto node = *it;
if (node->hasSideEffects() || node->isNondeterministic() ||
- aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
+ aliasDb.hasWriters(node)) {
// Do NOT have enough information to do CSE on these nodes.
continue;
}
});
bool supported_node = !n->kind().is_onnx() &&
skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
- !n->hasSideEffects() && !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
+ !n->hasSideEffects() && !aliasDb.hasWriters(n);
auto run_blocks = [&]() {
if (recurse) {
for (Block* block : n->blocks()) {
std::pair<graph_node_list::iterator, bool> scanNode(
Node* consumer,
- const AliasDb& aliasDb) {
+ AliasDb& aliasDb) {
if (shouldConsiderForMerge(consumer)) {
if (consumer->kind() != prim::DifferentiableGraph) {
consumer = SubgraphUtils::createSingletonSubgraph(
c10::optional<Node*> tryMerge(
Node* consumer,
Node* producer,
- const AliasDb& aliasDb) {
+ AliasDb& aliasDb) {
JIT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
bool canMerge = shouldConsiderForMerge(producer) &&
- producer->moveBeforeTopologicallyValid(consumer, aliasDb);
+ aliasDb.moveBeforeTopologicallyValid(producer, consumer);
if (!canMerge) {
return c10::nullopt;
// we mark "live" nodes that are necessary for the output. Nodes that have
// side effects are also marked.
void run(Block* block, bool recurse) {
- // Find the last wildcard in the block. We cannot eliminate any mutable ops
- // that precede the last wildcard (since they may have written to the
- // wildcard alias set)
- setLastWildcard();
-
// Initialize by marking the return node and all its consumed values as live
mark(block->return_node());
}
private:
- void setLastWildcard() {
- if (!aliasDb_) {
- return;
- }
-
- const auto& wildcards = aliasDb_->getWildcardNodes();
- if (wildcards.empty()) {
- return;
- }
-
- lastWildcard_ = *wildcards.begin();
- for (const auto wildcard : wildcards) {
- if (wildcard->isAfter(*lastWildcard_)) {
- lastWildcard_ = wildcard;
- }
- }
- }
-
// Special handling for block return nodes. Unlike other nodes, the block
// return node doesn't really "use" its inputs. Consider:
//
auto schema = node->maybeSchema();
return schema && schema->is_mutable();
} else {
- // Otherwise, there are two kinds of nodes with untracked effects:
- // 1. Nodes that write to a value that may alias the graph inputs (since
- // the inputs can be used outside the graph).
- // 2. Anything that could clobber a wildcard value.
- bool touchesWildcard = false;
- if (lastWildcard_) {
- touchesWildcard = aliasDb_->hasWrites(node) &&
- (node->isBefore(*lastWildcard_) || node == *lastWildcard_);
- }
- return aliasDb_->writesToInputAlias(node) || touchesWildcard;
+ return aliasDb_->hasUntrackedEffects(node);
}
}
std::unordered_set<Node*> marked_;
std::unordered_set<const Value*> liveValues_;
std::unordered_set<const Value*> liveAliases_;
- c10::optional<const Node*> lastWildcard_;
std::function<void(const std::unordered_set<const Value*>&)> deleteCallback_ =
[](const std::unordered_set<const Value*>&) {};
};
// consumer. Fusion will rewrite those later uses to use the version of
// producer generated by the fused blob. In this case, producer becomes
// an output of the fusion group.
- producer->node()->moveBeforeTopologicallyValid(
- consumer, aliasDb_.value());
+ aliasDb_->moveBeforeTopologicallyValid(producer->node(), consumer);
if (!shouldFuse) {
return at::nullopt;
}
// NB: it is important that this check happens after isFusable, which checks
// that the blocks match, and it's not a special node like prim::Param
- if (!producer->node()->couldMoveBeforeTopologically(
- before_check, aliasDb_.value())) {
+ if (!aliasDb_->couldMoveBeforeTopologically(
+ producer->node(), before_check)) {
return false;
}
// Fusion groups can be merged with concat's group if and only if
return dependsOnMutationMemo_[node];
}
- const auto writers = aliasDb_.getWriters(node);
- const auto hasWritersBefore =
- std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
- return writer->isBefore(node);
- });
- if (hasWritersBefore || aliasDb_.hasWildcard(node)) {
+ if (aliasDb_.hasWritersBefore(node)) {
// If something could have written to a value used by this node, we can't
// guarantee the result is the same when running it in isolation.
dependsOnMutationMemo_[node] = true;