From: Michael Suo
Date: Thu, 28 Feb 2019 19:28:16 +0000 (-0800)
Subject: alias_analysis refactor (#17511)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1047
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=54c5b109345120bd78fb19e52f1527246fe8a790;p=platform%2Fupstream%2Fpytorch.git

alias_analysis refactor (#17511)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17511

AliasTracker was doing bookkeeping for three concepts: the points-to graph,
writes, and wildcards. This PR makes AliasTracker's job clearer: it keeps
track of the points-to graph only, and has been renamed MemoryDAG accordingly.
Write and wildcard information were pulled back into AliasDb as part of this;
I may later pull them out into their own small modules, since I don't want the
alias analysis code to get too bloated.

This refactor is necessary because we want to start tracking aliasing
information for elements that _aren't_ first-class IR Values (e.g. the "stuff"
inside a list), so MemoryDAG can't know too much about Values.

Reviewed By: houseroad

Differential Revision: D14231251

fbshipit-source-id: 6cd98ae6fced8d6c1522c2454da77c3c1b2b0504
---

diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp
index 65bc780..8f6c2a8 100644
--- a/test/cpp/jit/gtest.cpp
+++ b/test/cpp/jit/gtest.cpp
@@ -34,7 +34,9 @@ JIT_TEST(TopologicalIndex)
 JIT_TEST(TopologicalMove)
 JIT_TEST(SubgraphUtils)
 JIT_TEST(AliasAnalysis)
-JIT_TEST(AliasTracker)
+JIT_TEST(WriteTracking)
+JIT_TEST(Wildcards)
+JIT_TEST(MemoryDAG)
 JIT_TEST(IRParser)
 
 JIT_TEST(NetDefConverter)
diff --git a/test/cpp/jit/no-gtest.cpp b/test/cpp/jit/no-gtest.cpp
index 8072ef8..00b6892 100644
--- a/test/cpp/jit/no-gtest.cpp
+++ b/test/cpp/jit/no-gtest.cpp
@@ -39,7 +39,9 @@ std::string runJITCPPTests() {
   testATenNativeBatchNorm();
   testRegisterFusionCachesKernel();
   testAliasAnalysis();
-  testAliasTracker();
+  testWriteTracking();
+  testWildcards();
+  testMemoryDAG();
   testNetDefConverter(out);
   testIRParser(out);
   return out.str();
diff --git a/test/cpp/jit/test_alias_analysis.h b/test/cpp/jit/test_alias_analysis.h
index 840d3ac..c1d2b57 100644
--- a/test/cpp/jit/test_alias_analysis.h
+++ b/test/cpp/jit/test_alias_analysis.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "test/cpp/jit/test_base.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/jit/passes/alias_analysis.h"
 #include "torch/csrc/jit/script/compiler.h"
 #include "torch/csrc/utils/memory.h"
@@ -452,41 +453,110 @@ void testAliasAnalysis() {
   }
 }
 
-void testAliasTracker() {
-  auto graph = std::make_shared<Graph>();
-  const Value* a = graph->addInput();
-  const Value* b = graph->addInput();
-  const Value* c = graph->addInput();
-  const Value* d = graph->addInput();
-  const Value* e = graph->addInput();
-  const Value* f = graph->addInput();
-  const Value* g = graph->addInput();
-  const Value* wc = graph->addInput();
-
+void testWriteTracking() {
+  RegisterOperators reg({createOperator(
+      "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
+      [](at::Tensor a) { return a; })});
+  const auto creates_alias = Symbol::fromQualString("foo::creates_alias");
+  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
   {
-    // test contains()
-    AliasTracker t;
-    t.makeFreshValue(a);
-    ASSERT_TRUE(t.contains(a));
-    ASSERT_FALSE(t.contains(b));
+    auto graph = std::make_shared<Graph>();
+    auto a = graph->addInput();
+    auto b = graph->addInput();
+
+    // aten::add(%b, %b)
+    // aten::add_(%a, %b)
+    // foo::creates_alias(%a)
+    auto pureNode = graph->insert(aten::add, {b, b})->node();
+    auto writingNode =
graph->insert(aten::add_, {a, b})->node(); + auto node3 = graph->insert(creates_alias, {a})->node(); + auto aAlias = node3->output(); + + graph->lint(); + + AliasDb aliasDb(graph); + ASSERT_TRUE(aliasDb.mayAlias(aAlias, a)); + ASSERT_TRUE(aliasDb.mayAlias(a, b)); + ASSERT_FALSE( + aliasDb.writesToAlias(pureNode, std::unordered_set{a})); + ASSERT_FALSE( + aliasDb.writesToAlias(pureNode, std::unordered_set{b})); + ASSERT_TRUE(aliasDb.writesToAlias( + writingNode, std::unordered_set{a})); + ASSERT_TRUE(aliasDb.writesToAlias( + writingNode, std::unordered_set{a, b})); + ASSERT_TRUE(aliasDb.writesToAlias( + writingNode, std::unordered_set{aAlias})); } +} + +void testWildcards() { + RegisterOperators reg({createOperator( + "foo::returns_wildcard(Tensor a) -> Tensor(*)", + [](at::Tensor a) { return a; }), + createOperator( + "foo::writes(Tensor(z!) a) -> Tensor(a)", + [](at::Tensor a) { return a; })}); + const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard"); + const auto writes = Symbol::fromQualString("foo::writes"); + + auto graph = std::make_shared(); + const auto a = graph->addInput(); + + const auto constant = graph->insertConstant(1); + const auto fresh = graph->insert(aten::rand, {constant}); + const auto fresh2 = graph->insert(aten::rand, {constant}); + const auto wildcard = graph->insert(returns_wildcard, {fresh}); + const auto wildcardWrite = graph->insert(writes, {wildcard})->node(); + + graph->lint(); + AliasDb aliasDb(graph); + + ASSERT_FALSE(aliasDb.mayAlias(a, fresh)); + ASSERT_TRUE(aliasDb.mayAlias(wildcard, fresh)); + ASSERT_TRUE(aliasDb.mayAlias(wildcard, a)); + ASSERT_FALSE(aliasDb.mayAlias( + std::unordered_set({wildcard}), + std::unordered_set())); + + // Test writes to wildcards + ASSERT_TRUE(aliasDb.writesToAlias( + wildcardWrite, std::unordered_set{fresh})); + ASSERT_TRUE(aliasDb.writesToAlias( + wildcardWrite, std::unordered_set{fresh2})); + ASSERT_TRUE(aliasDb.writesToAlias( + wildcardWrite, std::unordered_set{a})); +} + +void testMemoryDAG() { + auto graph = std::make_shared(); + const Value* aValue = graph->addInput(); + const Value* bValue = graph->addInput(); + const Value* cValue = graph->addInput(); + const Value* dValue = graph->addInput(); + const Value* eValue = graph->addInput(); + const Value* fValue = graph->addInput(); + const Value* gValue = graph->addInput(); + { // a <- b <- c // b <- d // a <- e // f <- e // g is by itself - // wc is a wildcard value - AliasTracker t; - t.makeFreshValue(a); - t.makeFreshValue(f); - t.makeFreshValue(g); + MemoryDAG t; + auto a = t.makeFreshValue(aValue); + auto b = t.makeFreshValue(bValue); + auto c = t.makeFreshValue(cValue); + auto d = t.makeFreshValue(dValue); + auto e = t.makeFreshValue(eValue); + auto f = t.makeFreshValue(fValue); + auto g = t.makeFreshValue(gValue); t.makePointerTo(b, a); t.makePointerTo(c, b); t.makePointerTo(d, b); t.makePointerTo(e, a); t.makePointerTo(e, f); - t.setWildcard(wc); /** * Test mayAlias() @@ -506,59 +576,15 @@ void testAliasTracker() { // But a and f don't alias ASSERT_FALSE(t.mayAlias(a, f)); - // Wildcards should alias everything - ASSERT_TRUE(t.mayAlias(wc, a)); - ASSERT_TRUE(t.mayAlias(wc, b)); - ASSERT_TRUE(t.mayAlias(wc, f)); - ASSERT_TRUE(t.mayAlias(wc, g)); - /** * Test mayAlias() set interface */ - std::multiset foo{c, c, d}; - std::multiset bar{e, f}; - std::unordered_set baz{f, g}; - std::set containsWildcard{wc}; + std::multiset foo{c, c, d}; + std::multiset bar{e, f}; + std::unordered_set baz{f, g}; ASSERT_TRUE(t.mayAlias(foo, bar)); 
ASSERT_TRUE(t.mayAlias(bar, baz)); ASSERT_FALSE(t.mayAlias(foo, baz)); - // wildcard stuff aliases everything - ASSERT_TRUE(t.mayAlias(containsWildcard, foo)); - ASSERT_TRUE(t.mayAlias(containsWildcard, bar)); - ASSERT_TRUE(t.mayAlias(containsWildcard, baz)); - - /** - * Test writer tracking - */ - auto n1 = graph->appendNode(graph->create(prim::Undefined)); - auto n2 = graph->appendNode(graph->create(prim::Undefined)); - auto n3 = graph->appendNode(graph->create(prim::Undefined)); - t.registerWrite(a, n1); - t.registerWrite(f, n2); - // We should report those writes accurately - ASSERT_TRUE(t.writesTo(n1, a)); - ASSERT_TRUE(t.writesTo(n2, f)); - ASSERT_FALSE(t.writesTo(n1, f)); - ASSERT_FALSE(t.writesTo(n2, a)); - // We should correctly report writes to aliases as well - ASSERT_TRUE(t.writesTo(n1, c)); - - // Check hasWriters() - ASSERT_TRUE(t.hasWriters(a)); - // Aliases of written-to values should have writers - ASSERT_TRUE(t.hasWriters(b)); - ASSERT_TRUE(t.hasWriters(d)); - ASSERT_TRUE(t.hasWriters(e)); - // Unique values not registered should be unaffected - ASSERT_FALSE(t.hasWriters(g)); - - // create a write to the wildcard set - t.registerWrite(wc, n3); - // Now everything may be written to - ASSERT_TRUE(t.hasWriters(g)); - const auto& wildcardWriters = t.getWildcardWriters(); - ASSERT_EQ(wildcardWriters.size(), 1); - ASSERT_EQ(*wildcardWriters.begin(), n3); } } } // namespace jit diff --git a/tools/build_variables.py b/tools/build_variables.py index b73fac2..46e59cc 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -87,7 +87,7 @@ libtorch_sources = [ "torch/csrc/jit/passes/shape_analysis.cpp", "torch/csrc/jit/passes/specialize_undef.cpp", "torch/csrc/jit/passes/utils/subgraph_utils.cpp", - "torch/csrc/jit/passes/utils/alias_tracker.cpp", + "torch/csrc/jit/passes/utils/memory_dag.cpp", "torch/csrc/jit/register_prim_ops.cpp", "torch/csrc/jit/register_special_ops.cpp", "torch/csrc/jit/scope.cpp", diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 62be18c..8e80e60 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -165,7 +165,7 @@ set(TORCH_SRCS ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp - ${TORCH_SRC_DIR}/csrc/jit/passes/utils/alias_tracker.cpp + ${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 850862d..11c4539 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -5,8 +5,8 @@ namespace torch { namespace jit { -namespace { -bool shouldAnnotate(const TypePtr& type) { + +bool AliasDb::shouldAnnotate(const TypePtr& type) { return type->isSubtypeOf(TensorType::get()) || type->kind() == TypeKind::ListType || type->kind() == TypeKind::TupleType || @@ -19,56 +19,103 @@ bool shouldAnnotate(const TypePtr& type) { // We only need to annotate values that either are mutable or could contain // mutable types. 
-bool shouldAnnotate(const Value* v) { +bool AliasDb::shouldAnnotate(const Value* v) { return shouldAnnotate(v->type()); } -} // namespace AliasDb::~AliasDb() = default; AliasDb::AliasDb(std::shared_ptr graph) : graph_(std::move(graph)) { - aliasTracker_ = torch::make_unique(); + memoryDAG_ = torch::make_unique(); analyze(graph_); } // Does `n` use or write to any wildcard aliases? bool AliasDb::hasWildcard(const Node* n) const { for (const auto input : n->inputs()) { - if (aliasTracker_->isWildcard(input)) { + if (isWildcard(input)) { return true; } } - for (const auto output : n->outputs()) { - if (aliasTracker_->isWildcard(output)) { + if (isWildcard(output)) { return true; } } return false; } +bool AliasDb::isWildcard(const Value* v) const { + return wildcards_.count(v); +} + bool AliasDb::writesTo(Node* n, const Value* v) const { if (!shouldAnnotate(v)) { // This is a primitive type return false; } - return aliasTracker_->writesTo(n, v); + if (isWildcard(v)) { + return wildcardWriters_.count(n); + } + + if (!elementMap_.count(v) || !writeIndex_.count(n)) { + return false; + } + + // Can short-circuit if we know this node writes directly to `v` + if (writeIndex_.at(n).count(v)) { + return true; + } + + // Otherwise, check if `v` may alias any of written-to values in `n` + const auto vSet = ValueSet{v}; + return mayAlias(vSet, writeIndex_.at(n)); } bool AliasDb::hasWriters(const Node* n) const { for (const auto input : n->inputs()) { - if (aliasTracker_->hasWriters(input)) { + if (hasWriters(input)) { return true; } } for (const auto output : n->outputs()) { - if (aliasTracker_->hasWriters(output)) { + if (hasWriters(output)) { return true; } } return false; } +bool AliasDb::hasWriters(const Value* v) const { + if (isWildcard(v)) { + // If `n` has a wildcard, any write in the graph may write to it. + // So the only way we know there are no writers is if there are no writes + // at all. + return numWrites_ == 0; + } + + if (!elementMap_.count(v)) { + return false; + } + + if (wildcardWriters_.size() > 0) { + // A write to the wildcard may be a write to any value. + return true; + } + + if (isWriteCacheStale_) { + rebuildWriteCache(); + } + + for (const auto loc : elementMap_.at(v)->getMemoryLocations()) { + if (writeCache_.count(loc)) { + return true; + } + } + + return false; +} + bool AliasDb::hasWrites(Node* n) const { for (const auto input : n->inputs()) { if (writesTo(n, input)) { @@ -102,16 +149,11 @@ bool AliasDb::writesToInputAlias(Node* n) const { graph_->inputs().cbegin(), graph_->inputs().cend(), [&](const Value* graphInput) { - return shouldAnnotate(graphInput) && - aliasTracker_->mayAlias(graphInput, v); + return shouldAnnotate(graphInput) && mayAlias(graphInput, v); }); }); } -bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const { - return aliasTracker_->mayAlias(a, b); -} - void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const { for (auto node : b->nodes()) { getWritesImpl(node, ret, recurseBlocks); @@ -144,7 +186,6 @@ ValueSet AliasDb::getWrites(Block* b) const { return writes; } - // Does `n` write to an alias of one of the values in `vs`? bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks) const { @@ -186,20 +227,47 @@ void AliasDb::dump() const { std::cout << "\n===1. GRAPH===\n"; graph_->dump(); - aliasTracker_->dump(); + std::cout << "\n===2. 
ALIAS DB===\n"; + for (const auto& ptrPair : elementMap_) { + const auto element = ptrPair.second; + if (element->pointsTo.size() > 0) { + std::cout << element->value->uniqueName() << " points to: "; + for (const auto pointedTo : element->pointsTo) { + std::cout << pointedTo->value->uniqueName() << ", "; + } + std::cout << "\n"; + } + } + + std::cout << "\n===3. WILDCARDS===\n"; + for (const auto wildcard : wildcards_) { + std::cout << wildcard->uniqueName() << ", "; + } + std::cout << "\n"; + + std::cout << "\n===4. Writes===\n"; + for (const auto& pr : writeIndex_) { + const auto node = pr.first; + const auto& values = pr.second; + std::cout << *node; + std::cout << " "; + for (const auto value : values) { + std::cout << value->uniqueName() << ", "; + } + std::cout << "\n"; + } + std::cout << "\n"; } -// TODO: need to create a dummy "graph input alias" value in setTracker for all +// TODO: need to create a dummy "graph input alias" value in MemoryDAG for all // inputs of the same type to point to. Currently they all point to the first // element, which is technically wrong. -static void makeAllAlias( - const std::vector values, - AliasTracker& setTracker) { +void AliasDb::makeAllAlias(const std::vector& values) { if (values.size() > 0) { - setTracker.makeFreshValue(values[0]); + giveFreshAlias(values[0]); } for (const auto value : values) { - setTracker.makePointerTo(value, values[0]); + makePointerTo(value, values[0]); } } @@ -247,18 +315,18 @@ void AliasDb::analyze(const std::shared_ptr& graph) { // 2. Make all partitions alias each other for (const auto& pr : listTypes) { - makeAllAlias(pr.second, *aliasTracker_); + makeAllAlias(pr.second); } for (const auto& pr : tupleTypes) { - makeAllAlias(pr.second, *aliasTracker_); + makeAllAlias(pr.second); } for (const auto& pr : dictTypes) { - makeAllAlias(pr.second, *aliasTracker_); + makeAllAlias(pr.second); } for (const auto& pr : userTypes) { - makeAllAlias(pr.second, *aliasTracker_); + makeAllAlias(pr.second); } - makeAllAlias(tensors, *aliasTracker_); + makeAllAlias(tensors); analyze(graph->block()); } @@ -392,7 +460,7 @@ void AliasDb::analyzeImpl(Node* node) { // Record writes if (formal->isWrite()) { - aliasTracker_->registerWrite(actualValue, node); + registerWrite(actualValue, node); } } @@ -415,7 +483,7 @@ void AliasDb::analyzeImpl(Node* node) { AT_ASSERT(formal->containedTypes().size() == 0); if (formal->isWildcard()) { - aliasTracker_->setWildcard(actual); + setWildcard(actual); continue; } @@ -437,15 +505,27 @@ void AliasDb::analyzeImpl(Node* node) { } auto toAlias = formalToActual.at(formalAlias); - makeAliasOf(actual, toAlias); + makePointerTo(actual, toAlias); } // Record writes if (formal->isWrite()) { - aliasTracker_->registerWrite(actual, node); + registerWrite(actual, node); } } } +// Register the fact that `n` writes to `v`. 
+void AliasDb::registerWrite(const Value* v, Node* n) { + numWrites_++; + + if (isWildcard(v)) { + wildcardWriters_.insert(n); + return; + } + + AT_ASSERT(elementMap_.count(v)); + writeIndex_[n].insert(v); +} void AliasDb::analyzeIf(Node* node) { // For if statements, the alias set of an output is the union of the @@ -461,8 +541,8 @@ void AliasDb::analyzeIf(Node* node) { const auto trueOutput = trueBlock->outputs().at(i); const auto falseOutput = falseBlock->outputs().at(i); - makeAliasOf(nodeOutput, trueOutput); - makeAliasOf(nodeOutput, falseOutput); + makePointerTo(nodeOutput, trueOutput); + makePointerTo(nodeOutput, falseOutput); } } @@ -501,7 +581,7 @@ void AliasDb::analyzeSubgraph(Node* node) { // subgraph block. AT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size()); for (size_t i = 0; i < node->outputs().size(); i++) { - makeAliasOf(node->outputs()[i], subgraphBlock->outputs()[i]); + makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]); } } @@ -517,7 +597,7 @@ void AliasDb::analyzeCreator(Node* node) { void AliasDb::analyzeExtractor(Node* node) { for (const auto output : node->outputs()) { if (shouldAnnotate(output)) { - aliasTracker_->setWildcard(output); + setWildcard(output); } } } @@ -525,7 +605,7 @@ void AliasDb::analyzeExtractor(Node* node) { // For torch.chunk(), all returned tensors may alias the input tensor void AliasDb::analyzeChunk(Node* node) { for (auto output : node->outputs()) { - makeAliasOf(output, node->input()); + makePointerTo(output, node->input()); } } @@ -550,14 +630,14 @@ void AliasDb::analyzeWait(Node* node) { const auto fut = node->input(); AT_ASSERT(fut->type()->kind() == TypeKind::FutureType); - if (aliasTracker_->isWildcard(fut)) { + if (isWildcard(fut)) { for (const auto output : node->outputs()) { - aliasTracker_->setWildcard(output); + setWildcard(output); } return; } - const auto originFuts = aliasTracker_->getMemoryLocations(fut); + const auto originFuts = getMemoryLocations(fut); for (const auto originFut : originFuts) { const auto subgraphNode = originFut->node(); @@ -590,7 +670,7 @@ void AliasDb::analyzeWait(Node* node) { // since the writes may or may not have been executed yet. But we'll let // users do that and shoot themselves in the foot for now. for (const auto write : subgraphWrites) { - aliasTracker_->registerWrite(write, node); + registerWrite(write, node); } } } @@ -599,7 +679,7 @@ void AliasDb::analyzeWait(Node* node) { void AliasDb::analyzeSetAttr(Node* node) { const auto self = node->inputs().at(0); AT_ASSERT(self->type()->kind() == TypeKind::UserType); - aliasTracker_->registerWrite(self, node); + registerWrite(self, node); } // BroadcastingChunk: all inputs are broadcasted, and then individually chunked. 
@@ -613,25 +693,57 @@ void AliasDb::analyzeBroadcastingChunk(Node* node) { // inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks) auto output_begin = outputs.begin() + index * nchunks; for (auto it = output_begin; it != output_begin + nchunks; ++it) { - makeAliasOf(*it, inputs.at(index)); + makePointerTo(*it, inputs.at(index)); } } } // Register the fact that `value` is a pointer to `to` -void AliasDb::makeAliasOf(const Value* value, const Value* to) { - if (!shouldAnnotate(value)) { +void AliasDb::makePointerTo(const Value* from, const Value* to) { + if (!shouldAnnotate(from)) { AT_ASSERT(!shouldAnnotate(to)); return; } - aliasTracker_->makePointerTo(value, to); + + if (from == to) { + return; + } + + // If either value is a wildcard, don't insert anything into the graph; + // wildcards are tracked separately since they have different aliasing rules. + if (isWildcard(to) || isWildcard(from)) { + setWildcard(from); + return; + } + + if (!isTracked(from)) { + giveFreshAlias(from); + } + if (!isTracked(to)) { + giveFreshAlias(to); + } + auto fromEl = elementMap_.at(from); + auto toEl = elementMap_.at(to); + memoryDAG_->makePointerTo(fromEl, toEl); +} + +bool AliasDb::mayAlias(const Value* a, const Value* b) const { + if (isWildcard(a) || isWildcard(b)) { + return true; + } + + if (!elementMap_.count(a) || !elementMap_.count(b)) { + return false; + } + + return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b)); } // Make each value in the `from` list point to its partner in the `to` list void AliasDb::mapAliases(at::ArrayRef from, at::ArrayRef to) { AT_ASSERT(to.size() == from.size()); for (size_t i = 0; i < to.size(); i++) { - makeAliasOf(from[i], to[i]); + makePointerTo(from[i], to[i]); } } @@ -640,13 +752,17 @@ void AliasDb::giveFreshAlias(const Value* value) { return; } - if (aliasTracker_->contains(value)) { + if (isTracked(value)) { // Inside a loop, we may have given a fresh alias to this value already, so // skip return; } - aliasTracker_->makeFreshValue(value); + elementMap_[value] = memoryDAG_->makeFreshValue(value); +} + +bool AliasDb::isTracked(const Value* v) const { + return isWildcard(v) || elementMap_.count(v); } bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) { @@ -762,13 +878,13 @@ class AliasDb::WorkingSet { // 2. Handle regular mutable dependencies // Check that `n` does not write to anything used by the working set const auto nWrites = aliasDb_.getWrites(n, /*recurseBlocks=*/true); - if (aliasDb_.aliasTracker_->mayAlias(nWrites, reads_)) { + if (aliasDb_.mayAlias(nWrites, reads_)) { return true; } // Check that the working set doesn't write to anything that `n` uses. const auto nReads = aliasDb_.getReads(n, /*recurseBlocks=*/true); - if (aliasDb_.aliasTracker_->mayAlias(writes_, nReads)) { + if (aliasDb_.mayAlias(writes_, nReads)) { return true; } return false; @@ -1054,5 +1170,34 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { return handled.count(symbol) || purposefully_not_handled.count(symbol); } +// Register `v` as a wildcard value. 
+void AliasDb::setWildcard(const Value* v) { + wildcards_.insert(v); +} + +void AliasDb::rebuildWriteCache() const { + for (const auto& pr : writeIndex_) { + const auto& writtenValues = pr.second; + + for (const auto value : writtenValues) { + for (const auto loc : elementMap_.at(value)->getMemoryLocations()) { + writeCache_.insert(loc); + } + } + } + isWriteCacheStale_ = false; +} + +ValueSet AliasDb::getMemoryLocations(const Value* v) const { + ValueSet ret; + if (!elementMap_.count(v)) { + return ret; + } + + for (const auto el : elementMap_.at(v)->getMemoryLocations()) { + ret.insert(el->value); + } + return ret; +} } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/alias_analysis.h b/torch/csrc/jit/passes/alias_analysis.h index 508de79..bed1d63 100644 --- a/torch/csrc/jit/passes/alias_analysis.h +++ b/torch/csrc/jit/passes/alias_analysis.h @@ -2,7 +2,7 @@ #include #include -#include +#include namespace torch { namespace jit { @@ -25,7 +25,6 @@ namespace jit { * we're not sure what this value may alias. To be conservative, we consider * the wildcard alias set as potentially aliasing any value. */ - class AliasDb { public: TORCH_API explicit AliasDb(std::shared_ptr graph); @@ -46,9 +45,53 @@ class AliasDb { bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false) const; + // Do `a` and `b` potentially share a memory location? + bool mayAlias(const Value* a, const Value* b) const; // Do any values in group `a` potentially share a memory location with any - // value in group `b`? - bool mayAlias(const ValueSet& a, const ValueSet& b) const; + // value in group `b`? i.e. may they overlap? + // + // NOTE: Bit of ugly templating, but this is just to make sure we can + // transform an arbitrary container of `Values` to the same container of + // `Elements`. + template < + typename... Other1, + template class T, + typename... Other2, + template class U> + bool mayAlias( + const T& a, + const U& b) const { + if (a.empty() || b.empty()) { + return false; + } + // Short-circuit for special case: if any value is a wildcard, the two sets + // may alias + if (std::any_of( + a.cbegin(), + a.cend(), + [this](const Value* v) { return isWildcard(v); }) || + std::any_of(b.cbegin(), b.cend(), [this](const Value* v) { + return isWildcard(v); + })) { + return true; + } + + T aElements; + for (const Value* v : a) { + if (elementMap_.count(v)) { + aElements.insert(elementMap_.at(v)); + } + } + + U bElements; + for (const Value* v : b) { + if (elementMap_.count(v)) { + bElements.insert(elementMap_.at(v)); + } + } + + return memoryDAG_->mayAlias(aElements, bElements); + } // Do any nodes write to an alias set inputed/outputed by `n`? bool hasWriters(const Node* n) const; @@ -79,7 +122,11 @@ class AliasDb { void move(Node* toMove, Node* movePoint, MoveSide moveSide); bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const; - + /** + * Write and read internal API + */ + // Does `n` write to any alias sets? + bool hasWrites(Node* n) const; // Get all the values that `n` writes to. // NOTE: this only returns values directly written to, not aliases thereof // @@ -88,28 +135,43 @@ class AliasDb { ValueSet getWrites(Block* b) const; void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const; void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const; - + // Do any nodes write to `v`s memory location? + bool hasWriters(const Value* v) const; + // Register the fact that `n` writes to `v`. 
+ void registerWrite(const Value* v, Node* n); // Get all the values that `n` reads from. // if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks ValueSet getReads(Node* n, bool recurseBlocks = false) const; - void getReadsImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const; - // Does `n` write to any alias sets? - bool hasWrites(Node* n) const; + // Does `n` write to a value that may alias one of the graph inputs? + bool writesToInputAlias(Node* n) const; + // Does `n` write to `v` or any aliases of `v`? + bool writesTo(Node* n, const Value* v) const; + + /** + * Wildcard methods + */ + // is `v` a wildcard? + bool isWildcard(const Value* v) const; + // Register `v` as a wildcard value. + void setWildcard(const Value* v); + // Get all nodes that write to a wildcard value. + const std::unordered_set& getWildcardWriters() const { + return wildcardWriters_; + } // Does `n` use or write to any wildcard aliases? bool hasWildcard(const Node* n) const; // Returns nullopt if there are no wildcard nodes c10::optional getLastWildcard() const; - // Does `n` write to a value that may alias one of the graph inputs? - bool writesToInputAlias(Node* n) const; - + /** + * Special analysis methods + */ void analyze(const std::shared_ptr& graph); void analyze(Block* block); void analyze(Node* node); void analyzeImpl(Node* node); - void analyzeIf(Node* node); void analyzeLoop(Node* node); void analyzeSubgraph(Node* node); @@ -121,20 +183,51 @@ class AliasDb { void analyzeWait(Node* node); void analyzeSetAttr(Node* node); - void makeAliasOf(const Value* value, const Value* to); + /** + * Alias manipulation methods + */ + void makeAllAlias(const std::vector& values); + void makePointerTo(const Value* value, const Value* to); void mapAliases(at::ArrayRef to, at::ArrayRef from); void giveFreshAlias(const Value* value); + static bool shouldAnnotate(const Value* v); + static bool shouldAnnotate(const TypePtr& type); bool hasUsesAfter(Symbol alias, const Node* n) const; - bool writesTo(Node* n, const Value* v) const; bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const; + // Returns true iff `v` is part of the alias tracker/is a wildcard + bool isTracked(const Value* v) const; + + // Get the values that represent the memory locations that `v` may point to. + // Return values are guaranteed to be "fresh" tensors--they do not point to + // anything else. + ValueSet getMemoryLocations(const Value* v) const; + std::shared_ptr graph_; std::unordered_map subgraphToOwner_; + + // The points-to graph that stores aliasing relationships + std::unique_ptr memoryDAG_; + // Mapping of values to MemoryDAG elements + std::unordered_map elementMap_; + + // All values that may point to a wildcard value. 
+ ValueSet wildcards_; + // All nodes that write to a wildcard + std::unordered_set wildcardWriters_; + // All nodes that contain a wildcard std::unordered_set wildcardNodes_; - std::unique_ptr aliasTracker_; + + // State for tracking write info + size_t numWrites_ = 0; + std::unordered_map writeIndex_; + mutable std::unordered_set writeCache_; + mutable bool isWriteCacheStale_ = true; + void rebuildWriteCache() const; }; +// Used to assert that unschematized operators have an analysis method written TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/utils/alias_tracker.cpp b/torch/csrc/jit/passes/utils/alias_tracker.cpp deleted file mode 100644 index e6cd358..0000000 --- a/torch/csrc/jit/passes/utils/alias_tracker.cpp +++ /dev/null @@ -1,266 +0,0 @@ -#include "alias_tracker.h" - -#include -#include - -namespace torch { -namespace jit { - -// Returns true iff `v` is present in the alias set tracker. -bool AliasTracker::contains(const Value* v) const { - return isWildcard(v) || map_.count(v); -} - -bool AliasTracker::mayAlias(const Value* a, const Value* b) const { - if (isWildcard(a) || isWildcard(b)) { - return true; - } - - if (!map_.count(a) || !map_.count(b)) { - return false; - } - - const auto aEl = map_.at(a); - const auto bEl = map_.at(b); - - const auto aMemLoc = aEl->getMemoryLocations(); - const auto bMemLoc = bEl->getMemoryLocations(); - - // XXX: This could be more efficiently done as a bitwise AND on two bitfields - // that represent memory location membership. If these comparisons end up - // being a bottleneck, consider implementing it that way. - for (const auto aLoc : aMemLoc) { - for (const auto bLoc : bMemLoc) { - if (aLoc == bLoc) { - return true; - } - } - } - return false; -} - -bool AliasTracker::writesTo(Node* n, const Value* v) const { - if (isWildcard(v)) { - return wildcardWriters_.count(n); - } - - if (!map_.count(v) || !writeIndex_.count(n)) { - return false; - } - - // Can short-circuit if we know this node writes directly to `v` - if (writeIndex_.at(n).count(v)) { - return true; - } - - // Otherwise, check if `v` may alias any of written-to values in `n` - const auto vSet = ValueSet{v}; - return mayAlias(vSet, writeIndex_.at(n)); -} - -// Make `v` point at `to`. -void AliasTracker::makePointerTo(const Value* v, const Value* to) { - if (v == to) { - return; - } - - // If `to` is a wildcard, don't insert anything into the graph; wildcards - // are tracked separately since they have different aliasing rules. - if (isWildcard(to)) { - setWildcard(v); - return; - } - - if (!map_.count(to)) { - makeFreshValue(to); - } - - if (!map_.count(v)) { - makeFreshValue(v); - } - - auto vEl = map_.at(v); - auto toEl = map_.at(to); - - vEl->pointsTo.insert(toEl); - toEl->pointedFrom.insert(vEl); -} - -// Give `v` a fresh alias (i.e. it does not point to any value) -void AliasTracker::makeFreshValue(const Value* v) { - auto el = torch::make_unique(); - el->value = v; - - auto rawPtr = el.get(); - elements_.emplace(rawPtr, std::move(el)); - map_.emplace(v, rawPtr); -} - -// Register `v` as a wildcard value. -void AliasTracker::setWildcard(const Value* v) { - wildcards_.insert(v); -} - -// is `v` a wildcard? -bool AliasTracker::isWildcard(const Value* v) const { - return wildcards_.count(v); -} - -// Register the fact that `n` writes to `v`. 
-void AliasTracker::registerWrite(const Value* v, Node* n) { - numWrites_++; - - if (isWildcard(v)) { - wildcardWriters_.insert(n); - return; - } - - AT_ASSERT(map_.count(v)); - writeIndex_[n].insert(v); -} - -bool AliasTracker::hasWriters(const Value* v) const { - if (!map_.count(v)) { - return false; - } - - if (isWildcard(v)) { - // If `n` has a wildcard, any write in the graph may write to it. - // So the only way we know there are no writers is if there are no writes - // at all. - return numWrites_ == 0; - } - - if (wildcardWriters_.size() > 0) { - // A write to the wildcard may be a write to any value. - return true; - } - - if (isWriteCacheStale_) { - rebuildWriteCache(); - } - - for (const auto loc : map_.at(v)->getMemoryLocations()) { - if (writeCache_.count(loc)) { - return true; - } - } - - return false; -} - -void AliasTracker::rebuildWriteCache() const { - for (const auto& pr : writeIndex_) { - const auto& writtenValues = pr.second; - - for (const auto value : writtenValues) { - for (const auto loc : map_.at(value)->getMemoryLocations()) { - writeCache_.insert(loc); - } - } - } - isWriteCacheStale_ = false; -} - -void AliasTracker::dump() const { - std::cout << "\n===2. ALIAS DB===\n"; - for (const auto& ptrPair : elements_) { - const auto element = ptrPair.first; - if (element->pointsTo.size() > 0) { - std::cout << element->value->uniqueName() << " points to: "; - for (const auto pointedTo : element->pointsTo) { - std::cout << pointedTo->value->uniqueName() << ", "; - } - std::cout << "\n"; - } - } - - std::cout << "\n===3. WILDCARDS===\n"; - for (const auto wildcard : wildcards_) { - std::cout << wildcard->uniqueName() << ", "; - } - std::cout << "\n"; - - std::cout << "\n===4. Writes===\n"; - for (const auto& pr : writeIndex_) { - const auto node = pr.first; - const auto& values = pr.second; - std::cout << *node; - std::cout << " "; - for (const auto value : values) { - std::cout << value->uniqueName() << ", "; - } - std::cout << "\n"; - } - std::cout << "\n"; -} - -std::unordered_set AliasTracker::Element:: - getMemoryLocations() const { - if (!cachedMemoryLocations_.empty()) { - return cachedMemoryLocations_; - } - - // Do a BFS in the `points-to` direction, collecting all memory locations - std::unordered_set ret; - this->bfs( - [&](const Element* el) { - if (el->pointsTo.empty()) { - ret.insert(el); - } - }, - BfsDirection::POINTS_TO); - - cachedMemoryLocations_ = ret; - return ret; -} - -// Do a breadth-first search over the graph, starting at `this` and -// traversing in the direction `dir`.`fn` will be run on each element. 
-template -bool AliasTracker::Element::bfs(Fn fn, BfsDirection dir) const { - std::queue queue; - std::unordered_set seen; - - queue.push(this); - while (!queue.empty()) { - const auto el = queue.front(); - queue.pop(); - seen.insert(el); - - fn(el); - - switch (dir) { - case BfsDirection::POINTS_TO: { - for (auto ptr : el->pointsTo) { - if (!seen.count(ptr)) { - queue.push(ptr); - } - } - } break; - - case BfsDirection::POINTED_FROM: { - for (auto ptr : el->pointedFrom) { - if (!seen.count(ptr)) { - queue.push(ptr); - } - } - } break; - } - } - return false; -} - -ValueSet AliasTracker::getMemoryLocations(const Value* v) const { - ValueSet ret; - if (!map_.count(v)) { - return ret; - } - - for (const auto el : map_.at(v)->getMemoryLocations()) { - ret.insert(el->value); - } - return ret; -} -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/utils/alias_tracker.h b/torch/csrc/jit/passes/utils/alias_tracker.h deleted file mode 100644 index afeae89..0000000 --- a/torch/csrc/jit/passes/utils/alias_tracker.h +++ /dev/null @@ -1,159 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace jit { - -// class AliasTracker -// -// This class tracks the "A points to B" graph for all values, as well as -// wildcards and writes. It is used by AliasDb to provide a higher-level API. -// -// We maintain a DAG where: -// - Vertices (called "elements") represent values and -// other aliasing entities (e.g. like the stuff inside a list) -// - Edges represent a "points-to" relationship. -// -// Leaves in this DAG are entities that don't point to anything, and thus -// correspond to unique "memory locations". -// -// So, by traversing the "points-to" graph to the leaves, you can determine -// which memory locations an element may point to. -class AliasTracker { - public: - // Returns true iff `v` is present in the alias set tracker. - bool contains(const Value* v) const; - - // Does `n` write to a memory location that `v` may point to? - bool writesTo(Node* n, const Value* v) const; - - // Make `v` point at `to`. - void makePointerTo(const Value* v, const Value* to); - - // Give `v` a fresh alias (i.e. it does not point to any value) - void makeFreshValue(const Value* v); - - // Register `v` as a wildcard value. - void setWildcard(const Value* v); - - // is `v` a wildcard? - bool isWildcard(const Value* v) const; - - // Register the fact that `n` writes to `v`. - void registerWrite(const Value* v, Node* n); - - // Does anything write to the memory locations that `v` may point to? - bool hasWriters(const Value* v) const; - - // Get all nodes that write to a wildcard value. - const std::unordered_set& getWildcardWriters() const { - return wildcardWriters_; - } - - // Get the values that represent the memory locations that `v` may point to. - // Return values are guaranteed to be "fresh" tensors--they do not point to - // anything else. - ValueSet getMemoryLocations(const Value* v) const; - - // Do `a` and `b` potentially share a memory location? - bool mayAlias(const Value* a, const Value* b) const; - - // Do any values in group `a` potentially share a memory location with any - // value in group `b`? 
- // - // This is written so that either of the inputs could be a multiset - template - bool mayAlias(const T& a, const U& b) const { - if (a.empty() || b.empty()) { - return false; - } - - // Record all memory locations from group `a` - std::unordered_set memoryLocations; - for (auto it = a.cbegin(); it != a.cend();) { - const auto value = *it; - if (isWildcard(value)) { - return true; - } - - if (map_.count(value)) { - for (const auto loc : map_.at(value)->getMemoryLocations()) { - memoryLocations.insert(loc); - } - } - - const auto cnt = a.count(*it); - std::advance(it, cnt); - } - - // If any of group `b`s memory locations overlap, return true. - for (auto it = b.cbegin(); it != b.cend();) { - const auto value = *it; - if (isWildcard(value)) { - return true; - } - - if (map_.count(value)) { - for (const auto loc : map_.at(value)->getMemoryLocations()) { - if (memoryLocations.count(loc)) { - return true; - } - } - } - - const auto cnt = b.count(*it); - std::advance(it, cnt); - } - // No overlap, so group `a` and `b` do not share a memory location - return false; - } - - void dump() const; - - private: - enum class BfsDirection { - POINTS_TO, - POINTED_FROM, - }; - // `Element` represents the vertex in the points-to graph. It has a 1:1 - // relationship with IR `Value`s. - struct Element { - const Value* value = nullptr; - // All elements that this element *may* point to. It's possible to have - // multiple elements that you might point to due to control flow/complex ops - std::unordered_set pointsTo; - // Backreference for points-to. - std::unordered_set pointedFrom; - - std::unordered_set getMemoryLocations() const; - // We do path compression to make repeated memory location queries faster. - // An empty cache means it is invalidated (it can never be empty otherwise, - // since every element must point to at least one memory location). - mutable std::unordered_set cachedMemoryLocations_; - - // Do a breadth-first search over the graph, starting at `this` and - // traversing in the direction `dir`.`fn` will be run on each element. - template - bool bfs(Fn fn, BfsDirection dir) const; - }; - - // Structure that owns all the element pointers. It's a map of - // raw pointer -> unique_ptr to facilitate easy queries - std::unordered_map> elements_; - // Index to look up whatever element corresponds to that value. - std::unordered_map map_; - // All values that may point to a wildcard value. 
- ValueSet wildcards_; - // All nodes that write to a wildcard - std::unordered_set wildcardWriters_; - size_t numWrites_ = 0; - - std::unordered_map writeIndex_; - mutable std::unordered_set writeCache_; - mutable bool isWriteCacheStale_ = true; - void rebuildWriteCache() const; -}; - -} // namespace jit -} // namespace torch diff --git a/torch/csrc/jit/passes/utils/memory_dag.cpp b/torch/csrc/jit/passes/utils/memory_dag.cpp new file mode 100644 index 0000000..4f5b75a --- /dev/null +++ b/torch/csrc/jit/passes/utils/memory_dag.cpp @@ -0,0 +1,105 @@ +#include "memory_dag.h" + +#include +#include + +namespace torch { +namespace jit { + +bool MemoryDAG::mayAlias(Element* a, Element* b) const { + return mayAliasImpl(a, b); +} + +bool MemoryDAG::mayAlias(const Element* a, const Element* b) const { + return mayAliasImpl(a, b); +} + +bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const { + const auto aMemLoc = a->getMemoryLocations(); + const auto bMemLoc = b->getMemoryLocations(); + + // XXX: This could be more efficiently done as a bitwise AND on two bitfields + // that represent memory location membership. If these comparisons end up + // being a bottleneck, consider implementing it that way. + for (const auto aLoc : aMemLoc) { + for (const auto bLoc : bMemLoc) { + if (aLoc == bLoc) { + return true; + } + } + } + return false; +} + +// Make `v` point at `to`. +void MemoryDAG::makePointerTo(Element* from, Element* to) { + from->pointsTo.insert(to); + to->pointedFrom.insert(from); +} + +// Give `v` a fresh alias (i.e. it does not point to any value) +Element* MemoryDAG::makeFreshValue(const Value* v) { + auto el = torch::make_unique(); + el->value = v; + + auto rawPtr = el.get(); + elements_.emplace(rawPtr, std::move(el)); + return rawPtr; +} + +std::unordered_set Element::getMemoryLocations() const { + if (!cachedMemoryLocations_.empty()) { + return cachedMemoryLocations_; + } + + // Do a BFS in the `points-to` direction, collecting all memory locations + std::unordered_set ret; + this->bfs( + [&](const Element* el) { + if (el->pointsTo.empty()) { + ret.insert(el); + } + }, + BfsDirection::POINTS_TO); + + cachedMemoryLocations_ = ret; + return ret; +} + +// Do a breadth-first search over the graph, starting at `this` and +// traversing in the direction `dir`.`fn` will be run on each element. +template +bool Element::bfs(Fn fn, BfsDirection dir) const { + std::queue queue; + std::unordered_set seen; + + queue.push(this); + while (!queue.empty()) { + const auto el = queue.front(); + queue.pop(); + seen.insert(el); + + fn(el); + + switch (dir) { + case BfsDirection::POINTS_TO: { + for (auto ptr : el->pointsTo) { + if (!seen.count(ptr)) { + queue.push(ptr); + } + } + } break; + + case BfsDirection::POINTED_FROM: { + for (auto ptr : el->pointedFrom) { + if (!seen.count(ptr)) { + queue.push(ptr); + } + } + } break; + } + } + return false; +} +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h new file mode 100644 index 0000000..6142481 --- /dev/null +++ b/torch/csrc/jit/passes/utils/memory_dag.h @@ -0,0 +1,121 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { + +struct Element; +struct Value; + +// class MemoryDAG +// +// This class tracks the "A points to B" graph for all values. It is used by +// AliasDb to provide a higher-level API. +// +// We maintain a DAG where: +// - Vertices (called "elements") represent values and +// other aliasing entities (e.g. 
like the stuff inside a list) +// - Edges represent a "points-to" relationship. +// +// Leaves in this DAG are entities that don't point to anything, and thus +// correspond to unique "memory locations". +// +// So, by traversing the "points-to" graph to the leaves, you can determine +// which memory locations an element may point to. +class MemoryDAG { + public: + // Make `from` point at `to`. + void makePointerTo(Element* from, Element* to); + + // Make a fresh element (i.e. an element that doesn't point to anything) and + // return it. + Element* makeFreshValue(const Value* v); + + // Do `a` and `b` potentially share a memory location? + bool mayAlias(const Element* a, const Element* b) const; + bool mayAlias(Element* a, Element* b) const; + + // Do any values in group `a` potentially share a memory location with any + // value in group `b`? + // + // This is written so that either of the inputs could be a multiset + template + bool mayAlias(const T& a, const U& b) const { + if (a.empty() || b.empty()) { + return false; + } + + // Record all memory locations from group `a` + std::unordered_set memoryLocations; + for (auto it = a.cbegin(); it != a.cend();) { + const auto element = *it; + + for (const auto loc : element->getMemoryLocations()) { + memoryLocations.insert(loc); + } + + const auto cnt = a.count(*it); + std::advance(it, cnt); + } + + // If any of group `b`s memory locations overlap, return true. + for (auto it = b.cbegin(); it != b.cend();) { + const auto element = *it; + + for (const auto loc : element->getMemoryLocations()) { + if (memoryLocations.count(loc)) { + return true; + } + } + + const auto cnt = b.count(*it); + std::advance(it, cnt); + } + // No overlap, so group `a` and `b` do not share a memory location + return false; + } + + private: + bool mayAliasImpl(const Element* a, const Element* b) const; + // Structure that owns all the element pointers. It's a map of + // raw pointer -> unique_ptr to facilitate easy queries + std::unordered_map> elements_; +}; + +enum class BfsDirection { + POINTS_TO, + POINTED_FROM, +}; + +// `Element` represents the vertex in the points-to graph. It represents +// anything that could have an aliasing relationship, mostly IR `Value`s, but +// also the "inside of a list", or wildcards. +struct Element { + // The value that this element corresponds to. May be null if this element + // doesn't represent a first-class value. + const Value* value = nullptr; + + // All elements that this element *may* point to. It's possible to have + // multiple elements that you might point to due to control flow/complex ops + std::unordered_set pointsTo; + // Backreference for points-to. + std::unordered_set pointedFrom; + + // Return the unique memory locations that `Element` might represent. + std::unordered_set getMemoryLocations() const; + // We do path compression to make repeated memory location queries faster. + // An empty cache means it is invalidated (it can never be empty otherwise, + // since every element must point to at least one memory location). + mutable std::unordered_set cachedMemoryLocations_; + + // Do a breadth-first search over the graph, starting at `this` and + // traversing in the direction `dir`.`fn` will be run on each element. + template + bool bfs(Fn fn, BfsDirection dir) const; +}; + +} // namespace jit +} // namespace torch
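
Editor's note on the motivation stated in the summary (not part of the patch): because MemoryDAG reasons only about Elements, a client such as AliasDb can create Elements that have no IR Value behind them, which is what tracking "the stuff inside a list" will need. The sketch below is hypothetical; the helper name markListContentsAlias and the idea of passing a null Value* to makeFreshValue are illustrative assumptions (the Element comment in memory_dag.h does say the value field may be null for non-first-class entities), not code introduced by this change.

#include "torch/csrc/jit/passes/utils/memory_dag.h"

namespace torch {
namespace jit {

// Hypothetical sketch: model a list's storage as a single Element with no
// backing Value. `inserted` is the element for a value appended to the list;
// `read` is the element for a value later taken out of it.
inline bool markListContentsAlias(MemoryDAG& dag, Element* inserted, Element* read) {
  // One fresh element stands in for every slot of the list. It corresponds to
  // no first-class IR Value, so we hand makeFreshValue a null Value*.
  Element* containedStuff = dag.makeFreshValue(/*v=*/nullptr);

  // Both the inserted value and the value read back may occupy that storage.
  dag.makePointerTo(inserted, containedStuff);
  dag.makePointerTo(read, containedStuff);

  // `inserted` and `read` now share the containedStuff memory location, so
  // MemoryDAG reports a possible alias between them.
  return dag.mayAlias(inserted, read);
}

} // namespace jit
} // namespace torch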