aliasToValue_[aliasSet].insert(value);
}
}
- // - Set of all nodes with a wildcard
- buildWildcardIndex(graph_->block());
-}
-
-void AliasDb::buildWildcardIndex(const Block* b) {
- for (const auto node : b->nodes()) {
- for (const auto block : node->blocks()) {
- buildWildcardIndex(block);
- }
-
- if (hasWildcardImpl(node)) {
- wildcardNodes_.insert(node);
- }
- }
}
bool AliasDb::hasWildcard(const Node* n) const {
}
const auto writers = getWriters(n);
return std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
- return writer->isBefore(n);
+ return isBeforeSameGraph(writer, n);
});
}
}
}
}
+
+ // A write to the wildcard set should be considered a write to `n`
+ if (aliasToWrites_.count(AliasInfo::wildcardSet())) {
+ const auto& wildcardWriters = aliasToWrites_.at(AliasInfo::wildcardSet());
+ for (auto writer : wildcardWriters) {
+ writers.insert(writer);
+ }
+ }
+
return writers;
}
addAlias(actual, outputAlias);
}
+ // Keep the wildcard index up to date.
+ if (hasWildcardImpl(node)) {
+ wildcardNodes_.insert(node);
+ }
}
void AliasDb::analyzeIf(Node* node) {
}
void AliasDb::analyzeSubgraph(Node* node) {
- const auto subgraphBlock = node->g(attr::Subgraph)->block();
+ const auto subgraph = node->g(attr::Subgraph).get();
+
+ subgraphToOwner_.insert({subgraph, node});
+
+ const auto subgraphBlock = subgraph->block();
mapAliases(subgraphBlock->inputs(), node->inputs());
analyze(subgraphBlock);
// outside), then return nullptr. Since we can only reorder nodes within a
// block, `target` would be irrelevant.
static Node* findSameBlock(Node* target, Node* n) {
+ JIT_ASSERT(target->owningGraph() == n->owningGraph());
if (target->owningBlock() == n->owningBlock()) {
return target;
} else {
}
}
+c10::optional<const Node*> AliasDb::getLastWildcard() const {
+ auto it = std::max_element(
+ wildcardNodes_.cbegin(),
+ wildcardNodes_.cend(),
+ [this](const Node* a, const Node* b) { return isBeforeSameGraph(a, b); });
+ if (it != wildcardNodes_.end()) {
+ return *it;
+ } else {
+ return c10::nullopt;
+ }
+}
+
bool AliasDb::hasUntrackedEffects(Node* node) const {
bool touchesWildcard = false;
- if (!wildcardNodes_.empty()) {
- auto lastWildcard = *wildcardNodes_.begin();
- for (const auto wildcard : wildcardNodes_) {
- if (wildcard->isAfter(lastWildcard)) {
- lastWildcard = wildcard;
- }
- }
+ if (const auto lastWildcard = getLastWildcard()) {
touchesWildcard = hasWrites(node) &&
- (node->isBefore(lastWildcard) || node == lastWildcard);
+ (isBeforeSameGraph(node, *lastWildcard) || node == *lastWildcard);
}
return writesToInputAlias(node) || touchesWildcard;
}
+
+// Nodes must be in the same graph in order to do `isBefore` or `isAfter`. This
+// traverses the subgraph "chain" upward until we find two nodes that share an
+// owning graph.
+//
+// NOTE: this is n^2 in subgraph depth. Right now the maximum depth is like 2,
+// but if we ever do huge nested subgraphs we'll need to reconsider this.
+bool AliasDb::isBeforeSameGraph(const Node* a, const Node* b) const {
+ auto lhs = a;
+ while (true) {
+ auto rhs = b;
+ while (true) {
+ if (lhs->owningGraph() == rhs->owningGraph()) {
+ return lhs->isBefore(rhs);
+ }
+ if (!subgraphToOwner_.count(rhs->owningGraph())) {
+ break;
+ }
+ rhs = subgraphToOwner_.at(rhs->owningGraph());
+ }
+ if (!subgraphToOwner_.count(lhs->owningGraph())) {
+ break;
+ }
+ lhs = subgraphToOwner_.at(lhs->owningGraph());
+ }
+ JIT_ASSERT(false);
+}
} // namespace jit
} // namespace torch
// Does `n` use or write to any wildcard aliases?
bool hasWildcard(const Node* n) const;
+ // Returns nullopt if there are no wildcard nodes
+ c10::optional<const Node*> getLastWildcard() const;
// Does `n` write to a value that may alias one of the graph inputs?
bool writesToInputAlias(Node* n) const;
bool hasWildcardImpl(const Node* n) const;
bool writesTo(Node* n, const Value* v) const;
+ bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;
+
std::shared_ptr<Graph> graph_;
Symbol latestSymbol_ = Symbol::fromQualString("alias::0");
std::unordered_map<const Value*, AliasInfo> valueToAlias_;
std::unordered_map<Symbol, std::unordered_set<Node*>> aliasToWrites_;
std::unordered_set<const Node*> wildcardNodes_;
std::unordered_set<Symbol> graphInputAliases_;
+ std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
};
inline TORCH_API AliasDb AliasAnalysis(std::shared_ptr<Graph> graph) {