JIT_TEST(TopologicalMove)
JIT_TEST(SubgraphUtils)
JIT_TEST(AliasAnalysis)
-JIT_TEST(AliasTracker)
+JIT_TEST(WriteTracking)
+JIT_TEST(Wildcards)
+JIT_TEST(MemoryDAG)
JIT_TEST(IRParser)
JIT_TEST(NetDefConverter)
testATenNativeBatchNorm();
testRegisterFusionCachesKernel();
testAliasAnalysis();
- testAliasTracker();
+ testWriteTracking();
+ testWildcards();
+ testMemoryDAG();
testNetDefConverter(out);
testIRParser(out);
return out.str();
#pragma once
#include "test/cpp/jit/test_base.h"
+#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/passes/alias_analysis.h"
#include "torch/csrc/jit/script/compiler.h"
#include "torch/csrc/utils/memory.h"
}
}
-void testAliasTracker() {
- auto graph = std::make_shared<Graph>();
- const Value* a = graph->addInput();
- const Value* b = graph->addInput();
- const Value* c = graph->addInput();
- const Value* d = graph->addInput();
- const Value* e = graph->addInput();
- const Value* f = graph->addInput();
- const Value* g = graph->addInput();
- const Value* wc = graph->addInput();
-
+void testWriteTracking() {
+ RegisterOperators reg({createOperator(
+ "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
+ [](at::Tensor a) { return a; })});
+ const auto creates_alias = Symbol::fromQualString("foo::creates_alias");
+ const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
{
- // test contains()
- AliasTracker t;
- t.makeFreshValue(a);
- ASSERT_TRUE(t.contains(a));
- ASSERT_FALSE(t.contains(b));
+ auto graph = std::make_shared<Graph>();
+ auto a = graph->addInput();
+ auto b = graph->addInput();
+
+ // aten::add(%b, %b)
+ // aten::add_(%a, %b)
+ // foo::creates_alias(%a)
+ auto pureNode = graph->insert(aten::add, {b, b})->node();
+ auto writingNode = graph->insert(aten::add_, {a, b})->node();
+ auto node3 = graph->insert(creates_alias, {a})->node();
+ auto aAlias = node3->output();
+
+ graph->lint();
+
+ AliasDb aliasDb(graph);
+ ASSERT_TRUE(aliasDb.mayAlias(aAlias, a));
+ ASSERT_TRUE(aliasDb.mayAlias(a, b));
+ ASSERT_FALSE(
+ aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{a}));
+ ASSERT_FALSE(
+ aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{b}));
+ ASSERT_TRUE(aliasDb.writesToAlias(
+ writingNode, std::unordered_set<const Value*>{a}));
+ ASSERT_TRUE(aliasDb.writesToAlias(
+ writingNode, std::unordered_set<const Value*>{a, b}));
+ ASSERT_TRUE(aliasDb.writesToAlias(
+ writingNode, std::unordered_set<const Value*>{aAlias}));
}
+}
+
+void testWildcards() {
+ RegisterOperators reg({createOperator(
+ "foo::returns_wildcard(Tensor a) -> Tensor(*)",
+ [](at::Tensor a) { return a; }),
+ createOperator(
+ "foo::writes(Tensor(z!) a) -> Tensor(a)",
+ [](at::Tensor a) { return a; })});
+ const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
+ const auto writes = Symbol::fromQualString("foo::writes");
+
+ auto graph = std::make_shared<Graph>();
+ const auto a = graph->addInput();
+
+ const auto constant = graph->insertConstant(1);
+ const auto fresh = graph->insert(aten::rand, {constant});
+ const auto fresh2 = graph->insert(aten::rand, {constant});
+ const auto wildcard = graph->insert(returns_wildcard, {fresh});
+ const auto wildcardWrite = graph->insert(writes, {wildcard})->node();
+
+ graph->lint();
+ AliasDb aliasDb(graph);
+
+ ASSERT_FALSE(aliasDb.mayAlias(a, fresh));
+ ASSERT_TRUE(aliasDb.mayAlias(wildcard, fresh));
+ ASSERT_TRUE(aliasDb.mayAlias(wildcard, a));
+ ASSERT_FALSE(aliasDb.mayAlias(
+ std::unordered_set<const Value*>({wildcard}),
+ std::unordered_set<const Value*>()));
+
+ // Test writes to wildcards
+ ASSERT_TRUE(aliasDb.writesToAlias(
+ wildcardWrite, std::unordered_set<const Value*>{fresh}));
+ ASSERT_TRUE(aliasDb.writesToAlias(
+ wildcardWrite, std::unordered_set<const Value*>{fresh2}));
+ ASSERT_TRUE(aliasDb.writesToAlias(
+ wildcardWrite, std::unordered_set<const Value*>{a}));
+}
+
+void testMemoryDAG() {
+ auto graph = std::make_shared<Graph>();
+ const Value* aValue = graph->addInput();
+ const Value* bValue = graph->addInput();
+ const Value* cValue = graph->addInput();
+ const Value* dValue = graph->addInput();
+ const Value* eValue = graph->addInput();
+ const Value* fValue = graph->addInput();
+ const Value* gValue = graph->addInput();
+
{
// a <- b <- c
// b <- d
// a <- e
// f <- e
// g is by itself
- // wc is a wildcard value
- AliasTracker t;
- t.makeFreshValue(a);
- t.makeFreshValue(f);
- t.makeFreshValue(g);
+ MemoryDAG t;
+ auto a = t.makeFreshValue(aValue);
+ auto b = t.makeFreshValue(bValue);
+ auto c = t.makeFreshValue(cValue);
+ auto d = t.makeFreshValue(dValue);
+ auto e = t.makeFreshValue(eValue);
+ auto f = t.makeFreshValue(fValue);
+ auto g = t.makeFreshValue(gValue);
t.makePointerTo(b, a);
t.makePointerTo(c, b);
t.makePointerTo(d, b);
t.makePointerTo(e, a);
t.makePointerTo(e, f);
- t.setWildcard(wc);
/**
* Test mayAlias()
// But a and f don't alias
ASSERT_FALSE(t.mayAlias(a, f));
- // Wildcards should alias everything
- ASSERT_TRUE(t.mayAlias(wc, a));
- ASSERT_TRUE(t.mayAlias(wc, b));
- ASSERT_TRUE(t.mayAlias(wc, f));
- ASSERT_TRUE(t.mayAlias(wc, g));
-
/**
* Test mayAlias() set interface
*/
- std::multiset<const Value*> foo{c, c, d};
- std::multiset<const Value*> bar{e, f};
- std::unordered_set<const Value*> baz{f, g};
- std::set<const Value*> containsWildcard{wc};
+ std::multiset<const Element*> foo{c, c, d};
+ std::multiset<const Element*> bar{e, f};
+ std::unordered_set<const Element*> baz{f, g};
ASSERT_TRUE(t.mayAlias(foo, bar));
ASSERT_TRUE(t.mayAlias(bar, baz));
ASSERT_FALSE(t.mayAlias(foo, baz));
- // wildcard stuff aliases everything
- ASSERT_TRUE(t.mayAlias(containsWildcard, foo));
- ASSERT_TRUE(t.mayAlias(containsWildcard, bar));
- ASSERT_TRUE(t.mayAlias(containsWildcard, baz));
-
- /**
- * Test writer tracking
- */
- auto n1 = graph->appendNode(graph->create(prim::Undefined));
- auto n2 = graph->appendNode(graph->create(prim::Undefined));
- auto n3 = graph->appendNode(graph->create(prim::Undefined));
- t.registerWrite(a, n1);
- t.registerWrite(f, n2);
- // We should report those writes accurately
- ASSERT_TRUE(t.writesTo(n1, a));
- ASSERT_TRUE(t.writesTo(n2, f));
- ASSERT_FALSE(t.writesTo(n1, f));
- ASSERT_FALSE(t.writesTo(n2, a));
- // We should correctly report writes to aliases as well
- ASSERT_TRUE(t.writesTo(n1, c));
-
- // Check hasWriters()
- ASSERT_TRUE(t.hasWriters(a));
- // Aliases of written-to values should have writers
- ASSERT_TRUE(t.hasWriters(b));
- ASSERT_TRUE(t.hasWriters(d));
- ASSERT_TRUE(t.hasWriters(e));
- // Unique values not registered should be unaffected
- ASSERT_FALSE(t.hasWriters(g));
-
- // create a write to the wildcard set
- t.registerWrite(wc, n3);
- // Now everything may be written to
- ASSERT_TRUE(t.hasWriters(g));
- const auto& wildcardWriters = t.getWildcardWriters();
- ASSERT_EQ(wildcardWriters.size(), 1);
- ASSERT_EQ(*wildcardWriters.begin(), n3);
}
}
} // namespace jit
"torch/csrc/jit/passes/shape_analysis.cpp",
"torch/csrc/jit/passes/specialize_undef.cpp",
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
- "torch/csrc/jit/passes/utils/alias_tracker.cpp",
+ "torch/csrc/jit/passes/utils/memory_dag.cpp",
"torch/csrc/jit/register_prim_ops.cpp",
"torch/csrc/jit/register_special_ops.cpp",
"torch/csrc/jit/scope.cpp",
${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
- ${TORCH_SRC_DIR}/csrc/jit/passes/utils/alias_tracker.cpp
+ ${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
namespace torch {
namespace jit {
-namespace {
-bool shouldAnnotate(const TypePtr& type) {
+
+bool AliasDb::shouldAnnotate(const TypePtr& type) {
return type->isSubtypeOf(TensorType::get()) ||
type->kind() == TypeKind::ListType ||
type->kind() == TypeKind::TupleType ||
// We only need to annotate values that either are mutable or could contain
// mutable types.
-bool shouldAnnotate(const Value* v) {
+bool AliasDb::shouldAnnotate(const Value* v) {
return shouldAnnotate(v->type());
}
-} // namespace
AliasDb::~AliasDb() = default;
AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
- aliasTracker_ = torch::make_unique<AliasTracker>();
+ memoryDAG_ = torch::make_unique<MemoryDAG>();
analyze(graph_);
}
// Does `n` use or write to any wildcard aliases?
bool AliasDb::hasWildcard(const Node* n) const {
for (const auto input : n->inputs()) {
- if (aliasTracker_->isWildcard(input)) {
+ if (isWildcard(input)) {
return true;
}
}
-
for (const auto output : n->outputs()) {
- if (aliasTracker_->isWildcard(output)) {
+ if (isWildcard(output)) {
return true;
}
}
return false;
}
+bool AliasDb::isWildcard(const Value* v) const {
+ return wildcards_.count(v);
+}
+
bool AliasDb::writesTo(Node* n, const Value* v) const {
if (!shouldAnnotate(v)) {
// This is a primitive type
return false;
}
- return aliasTracker_->writesTo(n, v);
+ if (isWildcard(v)) {
+ return wildcardWriters_.count(n);
+ }
+
+ if (!elementMap_.count(v) || !writeIndex_.count(n)) {
+ return false;
+ }
+
+ // Can short-circuit if we know this node writes directly to `v`
+ if (writeIndex_.at(n).count(v)) {
+ return true;
+ }
+
+ // Otherwise, check if `v` may alias any of written-to values in `n`
+ const auto vSet = ValueSet{v};
+ return mayAlias(vSet, writeIndex_.at(n));
}
bool AliasDb::hasWriters(const Node* n) const {
for (const auto input : n->inputs()) {
- if (aliasTracker_->hasWriters(input)) {
+ if (hasWriters(input)) {
return true;
}
}
for (const auto output : n->outputs()) {
- if (aliasTracker_->hasWriters(output)) {
+ if (hasWriters(output)) {
return true;
}
}
return false;
}
+bool AliasDb::hasWriters(const Value* v) const {
+ if (isWildcard(v)) {
+ // If `n` has a wildcard, any write in the graph may write to it.
+ // So the only way we know there are no writers is if there are no writes
+ // at all.
+ return numWrites_ == 0;
+ }
+
+ if (!elementMap_.count(v)) {
+ return false;
+ }
+
+ if (wildcardWriters_.size() > 0) {
+ // A write to the wildcard may be a write to any value.
+ return true;
+ }
+
+ if (isWriteCacheStale_) {
+ rebuildWriteCache();
+ }
+
+ for (const auto loc : elementMap_.at(v)->getMemoryLocations()) {
+ if (writeCache_.count(loc)) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
bool AliasDb::hasWrites(Node* n) const {
for (const auto input : n->inputs()) {
if (writesTo(n, input)) {
graph_->inputs().cbegin(),
graph_->inputs().cend(),
[&](const Value* graphInput) {
- return shouldAnnotate(graphInput) &&
- aliasTracker_->mayAlias(graphInput, v);
+ return shouldAnnotate(graphInput) && mayAlias(graphInput, v);
});
});
}
-bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
- return aliasTracker_->mayAlias(a, b);
-}
-
void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
for (auto node : b->nodes()) {
getWritesImpl(node, ret, recurseBlocks);
return writes;
}
-
// Does `n` write to an alias of one of the values in `vs`?
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
const {
std::cout << "\n===1. GRAPH===\n";
graph_->dump();
- aliasTracker_->dump();
+ std::cout << "\n===2. ALIAS DB===\n";
+ for (const auto& ptrPair : elementMap_) {
+ const auto element = ptrPair.second;
+ if (element->pointsTo.size() > 0) {
+ std::cout << element->value->uniqueName() << " points to: ";
+ for (const auto pointedTo : element->pointsTo) {
+ std::cout << pointedTo->value->uniqueName() << ", ";
+ }
+ std::cout << "\n";
+ }
+ }
+
+ std::cout << "\n===3. WILDCARDS===\n";
+ for (const auto wildcard : wildcards_) {
+ std::cout << wildcard->uniqueName() << ", ";
+ }
+ std::cout << "\n";
+
+ std::cout << "\n===4. Writes===\n";
+ for (const auto& pr : writeIndex_) {
+ const auto node = pr.first;
+ const auto& values = pr.second;
+ std::cout << *node;
+ std::cout << " ";
+ for (const auto value : values) {
+ std::cout << value->uniqueName() << ", ";
+ }
+ std::cout << "\n";
+ }
+ std::cout << "\n";
}
-// TODO: need to create a dummy "graph input alias" value in setTracker for all
+// TODO: need to create a dummy "graph input alias" value in MemoryDAG for all
// inputs of the same type to point to. Currently they all point to the first
// element, which is technically wrong.
-static void makeAllAlias(
- const std::vector<Value*> values,
- AliasTracker& setTracker) {
+void AliasDb::makeAllAlias(const std::vector<Value*>& values) {
if (values.size() > 0) {
- setTracker.makeFreshValue(values[0]);
+ giveFreshAlias(values[0]);
}
for (const auto value : values) {
- setTracker.makePointerTo(value, values[0]);
+ makePointerTo(value, values[0]);
}
}
// 2. Make all partitions alias each other
for (const auto& pr : listTypes) {
- makeAllAlias(pr.second, *aliasTracker_);
+ makeAllAlias(pr.second);
}
for (const auto& pr : tupleTypes) {
- makeAllAlias(pr.second, *aliasTracker_);
+ makeAllAlias(pr.second);
}
for (const auto& pr : dictTypes) {
- makeAllAlias(pr.second, *aliasTracker_);
+ makeAllAlias(pr.second);
}
for (const auto& pr : userTypes) {
- makeAllAlias(pr.second, *aliasTracker_);
+ makeAllAlias(pr.second);
}
- makeAllAlias(tensors, *aliasTracker_);
+ makeAllAlias(tensors);
analyze(graph->block());
}
// Record writes
if (formal->isWrite()) {
- aliasTracker_->registerWrite(actualValue, node);
+ registerWrite(actualValue, node);
}
}
AT_ASSERT(formal->containedTypes().size() == 0);
if (formal->isWildcard()) {
- aliasTracker_->setWildcard(actual);
+ setWildcard(actual);
continue;
}
}
auto toAlias = formalToActual.at(formalAlias);
- makeAliasOf(actual, toAlias);
+ makePointerTo(actual, toAlias);
}
// Record writes
if (formal->isWrite()) {
- aliasTracker_->registerWrite(actual, node);
+ registerWrite(actual, node);
}
}
}
+// Register the fact that `n` writes to `v`.
+void AliasDb::registerWrite(const Value* v, Node* n) {
+ numWrites_++;
+
+ if (isWildcard(v)) {
+ wildcardWriters_.insert(n);
+ return;
+ }
+
+ AT_ASSERT(elementMap_.count(v));
+ writeIndex_[n].insert(v);
+}
void AliasDb::analyzeIf(Node* node) {
// For if statements, the alias set of an output is the union of the
const auto trueOutput = trueBlock->outputs().at(i);
const auto falseOutput = falseBlock->outputs().at(i);
- makeAliasOf(nodeOutput, trueOutput);
- makeAliasOf(nodeOutput, falseOutput);
+ makePointerTo(nodeOutput, trueOutput);
+ makePointerTo(nodeOutput, falseOutput);
}
}
// subgraph block.
AT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size());
for (size_t i = 0; i < node->outputs().size(); i++) {
- makeAliasOf(node->outputs()[i], subgraphBlock->outputs()[i]);
+ makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
}
}
void AliasDb::analyzeExtractor(Node* node) {
for (const auto output : node->outputs()) {
if (shouldAnnotate(output)) {
- aliasTracker_->setWildcard(output);
+ setWildcard(output);
}
}
}
// For torch.chunk(), all returned tensors may alias the input tensor
void AliasDb::analyzeChunk(Node* node) {
for (auto output : node->outputs()) {
- makeAliasOf(output, node->input());
+ makePointerTo(output, node->input());
}
}
const auto fut = node->input();
AT_ASSERT(fut->type()->kind() == TypeKind::FutureType);
- if (aliasTracker_->isWildcard(fut)) {
+ if (isWildcard(fut)) {
for (const auto output : node->outputs()) {
- aliasTracker_->setWildcard(output);
+ setWildcard(output);
}
return;
}
- const auto originFuts = aliasTracker_->getMemoryLocations(fut);
+ const auto originFuts = getMemoryLocations(fut);
for (const auto originFut : originFuts) {
const auto subgraphNode = originFut->node();
// since the writes may or may not have been executed yet. But we'll let
// users do that and shoot themselves in the foot for now.
for (const auto write : subgraphWrites) {
- aliasTracker_->registerWrite(write, node);
+ registerWrite(write, node);
}
}
}
void AliasDb::analyzeSetAttr(Node* node) {
const auto self = node->inputs().at(0);
AT_ASSERT(self->type()->kind() == TypeKind::UserType);
- aliasTracker_->registerWrite(self, node);
+ registerWrite(self, node);
}
// BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
// inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks)
auto output_begin = outputs.begin() + index * nchunks;
for (auto it = output_begin; it != output_begin + nchunks; ++it) {
- makeAliasOf(*it, inputs.at(index));
+ makePointerTo(*it, inputs.at(index));
}
}
}
// Register the fact that `value` is a pointer to `to`
-void AliasDb::makeAliasOf(const Value* value, const Value* to) {
- if (!shouldAnnotate(value)) {
+void AliasDb::makePointerTo(const Value* from, const Value* to) {
+ if (!shouldAnnotate(from)) {
AT_ASSERT(!shouldAnnotate(to));
return;
}
- aliasTracker_->makePointerTo(value, to);
+
+ if (from == to) {
+ return;
+ }
+
+ // If either value is a wildcard, don't insert anything into the graph;
+ // wildcards are tracked separately since they have different aliasing rules.
+ if (isWildcard(to) || isWildcard(from)) {
+ setWildcard(from);
+ return;
+ }
+
+ if (!isTracked(from)) {
+ giveFreshAlias(from);
+ }
+ if (!isTracked(to)) {
+ giveFreshAlias(to);
+ }
+ auto fromEl = elementMap_.at(from);
+ auto toEl = elementMap_.at(to);
+ memoryDAG_->makePointerTo(fromEl, toEl);
+}
+
+bool AliasDb::mayAlias(const Value* a, const Value* b) const {
+ if (isWildcard(a) || isWildcard(b)) {
+ return true;
+ }
+
+ if (!elementMap_.count(a) || !elementMap_.count(b)) {
+ return false;
+ }
+
+ return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
}
// Make each value in the `from` list point to its partner in the `to` list
void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
AT_ASSERT(to.size() == from.size());
for (size_t i = 0; i < to.size(); i++) {
- makeAliasOf(from[i], to[i]);
+ makePointerTo(from[i], to[i]);
}
}
return;
}
- if (aliasTracker_->contains(value)) {
+ if (isTracked(value)) {
// Inside a loop, we may have given a fresh alias to this value already, so
// skip
return;
}
- aliasTracker_->makeFreshValue(value);
+ elementMap_[value] = memoryDAG_->makeFreshValue(value);
+}
+
+bool AliasDb::isTracked(const Value* v) const {
+ return isWildcard(v) || elementMap_.count(v);
}
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
// 2. Handle regular mutable dependencies
// Check that `n` does not write to anything used by the working set
const auto nWrites = aliasDb_.getWrites(n, /*recurseBlocks=*/true);
- if (aliasDb_.aliasTracker_->mayAlias(nWrites, reads_)) {
+ if (aliasDb_.mayAlias(nWrites, reads_)) {
return true;
}
// Check that the working set doesn't write to anything that `n` uses.
const auto nReads = aliasDb_.getReads(n, /*recurseBlocks=*/true);
- if (aliasDb_.aliasTracker_->mayAlias(writes_, nReads)) {
+ if (aliasDb_.mayAlias(writes_, nReads)) {
return true;
}
return false;
return handled.count(symbol) || purposefully_not_handled.count(symbol);
}
+// Register `v` as a wildcard value.
+void AliasDb::setWildcard(const Value* v) {
+ wildcards_.insert(v);
+}
+
+void AliasDb::rebuildWriteCache() const {
+ for (const auto& pr : writeIndex_) {
+ const auto& writtenValues = pr.second;
+
+ for (const auto value : writtenValues) {
+ for (const auto loc : elementMap_.at(value)->getMemoryLocations()) {
+ writeCache_.insert(loc);
+ }
+ }
+ }
+ isWriteCacheStale_ = false;
+}
+
+ValueSet AliasDb::getMemoryLocations(const Value* v) const {
+ ValueSet ret;
+ if (!elementMap_.count(v)) {
+ return ret;
+ }
+
+ for (const auto el : elementMap_.at(v)->getMemoryLocations()) {
+ ret.insert(el->value);
+ }
+ return ret;
+}
} // namespace jit
} // namespace torch
#include <torch/csrc/jit/alias_info.h>
#include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/passes/utils/alias_tracker.h>
+#include <torch/csrc/jit/passes/utils/memory_dag.h>
namespace torch {
namespace jit {
* we're not sure what this value may alias. To be conservative, we consider
* the wildcard alias set as potentially aliasing any value.
*/
-
class AliasDb {
public:
TORCH_API explicit AliasDb(std::shared_ptr<Graph> graph);
bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
const;
+ // Do `a` and `b` potentially share a memory location?
+ bool mayAlias(const Value* a, const Value* b) const;
// Do any values in group `a` potentially share a memory location with any
- // value in group `b`?
- bool mayAlias(const ValueSet& a, const ValueSet& b) const;
+ // value in group `b`? i.e. may they overlap?
+ //
+ // NOTE: Bit of ugly templating, but this is just to make sure we can
+ // transform an arbitrary container of `Values` to the same container of
+ // `Elements`.
+ template <
+ typename... Other1,
+ template <typename, typename...> class T,
+ typename... Other2,
+ template <typename, typename...> class U>
+ bool mayAlias(
+ const T<const Value*, Other1...>& a,
+ const U<const Value*, Other2...>& b) const {
+ if (a.empty() || b.empty()) {
+ return false;
+ }
+ // Short-circuit for special case: if any value is a wildcard, the two sets
+ // may alias
+ if (std::any_of(
+ a.cbegin(),
+ a.cend(),
+ [this](const Value* v) { return isWildcard(v); }) ||
+ std::any_of(b.cbegin(), b.cend(), [this](const Value* v) {
+ return isWildcard(v);
+ })) {
+ return true;
+ }
+
+ T<Element*> aElements;
+ for (const Value* v : a) {
+ if (elementMap_.count(v)) {
+ aElements.insert(elementMap_.at(v));
+ }
+ }
+
+ U<Element*> bElements;
+ for (const Value* v : b) {
+ if (elementMap_.count(v)) {
+ bElements.insert(elementMap_.at(v));
+ }
+ }
+
+ return memoryDAG_->mayAlias(aElements, bElements);
+ }
// Do any nodes write to an alias set inputed/outputed by `n`?
bool hasWriters(const Node* n) const;
void move(Node* toMove, Node* movePoint, MoveSide moveSide);
bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
-
+ /**
+ * Write and read internal API
+ */
+ // Does `n` write to any alias sets?
+ bool hasWrites(Node* n) const;
// Get all the values that `n` writes to.
// NOTE: this only returns values directly written to, not aliases thereof
//
ValueSet getWrites(Block* b) const;
void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
-
+ // Do any nodes write to `v`s memory location?
+ bool hasWriters(const Value* v) const;
+ // Register the fact that `n` writes to `v`.
+ void registerWrite(const Value* v, Node* n);
// Get all the values that `n` reads from.
// if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
ValueSet getReads(Node* n, bool recurseBlocks = false) const;
-
void getReadsImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
- // Does `n` write to any alias sets?
- bool hasWrites(Node* n) const;
+ // Does `n` write to a value that may alias one of the graph inputs?
+ bool writesToInputAlias(Node* n) const;
+ // Does `n` write to `v` or any aliases of `v`?
+ bool writesTo(Node* n, const Value* v) const;
+
+ /**
+ * Wildcard methods
+ */
+ // is `v` a wildcard?
+ bool isWildcard(const Value* v) const;
+ // Register `v` as a wildcard value.
+ void setWildcard(const Value* v);
+ // Get all nodes that write to a wildcard value.
+ const std::unordered_set<Node*>& getWildcardWriters() const {
+ return wildcardWriters_;
+ }
// Does `n` use or write to any wildcard aliases?
bool hasWildcard(const Node* n) const;
// Returns nullopt if there are no wildcard nodes
c10::optional<const Node*> getLastWildcard() const;
- // Does `n` write to a value that may alias one of the graph inputs?
- bool writesToInputAlias(Node* n) const;
-
+ /**
+ * Special analysis methods
+ */
void analyze(const std::shared_ptr<Graph>& graph);
void analyze(Block* block);
void analyze(Node* node);
void analyzeImpl(Node* node);
-
void analyzeIf(Node* node);
void analyzeLoop(Node* node);
void analyzeSubgraph(Node* node);
void analyzeWait(Node* node);
void analyzeSetAttr(Node* node);
- void makeAliasOf(const Value* value, const Value* to);
+ /**
+ * Alias manipulation methods
+ */
+ void makeAllAlias(const std::vector<Value*>& values);
+ void makePointerTo(const Value* value, const Value* to);
void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
void giveFreshAlias(const Value* value);
+ static bool shouldAnnotate(const Value* v);
+ static bool shouldAnnotate(const TypePtr& type);
bool hasUsesAfter(Symbol alias, const Node* n) const;
- bool writesTo(Node* n, const Value* v) const;
bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;
+ // Returns true iff `v` is part of the alias tracker/is a wildcard
+ bool isTracked(const Value* v) const;
+
+ // Get the values that represent the memory locations that `v` may point to.
+ // Return values are guaranteed to be "fresh" tensors--they do not point to
+ // anything else.
+ ValueSet getMemoryLocations(const Value* v) const;
+
std::shared_ptr<Graph> graph_;
std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
+
+ // The points-to graph that stores aliasing relationships
+ std::unique_ptr<MemoryDAG> memoryDAG_;
+ // Mapping of values to MemoryDAG elements
+ std::unordered_map<const Value*, Element*> elementMap_;
+
+ // All values that may point to a wildcard value.
+ ValueSet wildcards_;
+ // All nodes that write to a wildcard
+ std::unordered_set<Node*> wildcardWriters_;
+ // All nodes that contain a wildcard
std::unordered_set<const Node*> wildcardNodes_;
- std::unique_ptr<AliasTracker> aliasTracker_;
+
+ // State for tracking write info
+ size_t numWrites_ = 0;
+ std::unordered_map<Node*, ValueSet> writeIndex_;
+ mutable std::unordered_set<const Element*> writeCache_;
+ mutable bool isWriteCacheStale_ = true;
+ void rebuildWriteCache() const;
};
+// Used to assert that unschematized operators have an analysis method written
TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);
} // namespace jit
} // namespace torch
+++ /dev/null
-#include "alias_tracker.h"
-
-#include <torch/csrc/utils/memory.h>
-#include <queue>
-
-namespace torch {
-namespace jit {
-
-// Returns true iff `v` is present in the alias set tracker.
-bool AliasTracker::contains(const Value* v) const {
- return isWildcard(v) || map_.count(v);
-}
-
-bool AliasTracker::mayAlias(const Value* a, const Value* b) const {
- if (isWildcard(a) || isWildcard(b)) {
- return true;
- }
-
- if (!map_.count(a) || !map_.count(b)) {
- return false;
- }
-
- const auto aEl = map_.at(a);
- const auto bEl = map_.at(b);
-
- const auto aMemLoc = aEl->getMemoryLocations();
- const auto bMemLoc = bEl->getMemoryLocations();
-
- // XXX: This could be more efficiently done as a bitwise AND on two bitfields
- // that represent memory location membership. If these comparisons end up
- // being a bottleneck, consider implementing it that way.
- for (const auto aLoc : aMemLoc) {
- for (const auto bLoc : bMemLoc) {
- if (aLoc == bLoc) {
- return true;
- }
- }
- }
- return false;
-}
-
-bool AliasTracker::writesTo(Node* n, const Value* v) const {
- if (isWildcard(v)) {
- return wildcardWriters_.count(n);
- }
-
- if (!map_.count(v) || !writeIndex_.count(n)) {
- return false;
- }
-
- // Can short-circuit if we know this node writes directly to `v`
- if (writeIndex_.at(n).count(v)) {
- return true;
- }
-
- // Otherwise, check if `v` may alias any of written-to values in `n`
- const auto vSet = ValueSet{v};
- return mayAlias(vSet, writeIndex_.at(n));
-}
-
-// Make `v` point at `to`.
-void AliasTracker::makePointerTo(const Value* v, const Value* to) {
- if (v == to) {
- return;
- }
-
- // If `to` is a wildcard, don't insert anything into the graph; wildcards
- // are tracked separately since they have different aliasing rules.
- if (isWildcard(to)) {
- setWildcard(v);
- return;
- }
-
- if (!map_.count(to)) {
- makeFreshValue(to);
- }
-
- if (!map_.count(v)) {
- makeFreshValue(v);
- }
-
- auto vEl = map_.at(v);
- auto toEl = map_.at(to);
-
- vEl->pointsTo.insert(toEl);
- toEl->pointedFrom.insert(vEl);
-}
-
-// Give `v` a fresh alias (i.e. it does not point to any value)
-void AliasTracker::makeFreshValue(const Value* v) {
- auto el = torch::make_unique<Element>();
- el->value = v;
-
- auto rawPtr = el.get();
- elements_.emplace(rawPtr, std::move(el));
- map_.emplace(v, rawPtr);
-}
-
-// Register `v` as a wildcard value.
-void AliasTracker::setWildcard(const Value* v) {
- wildcards_.insert(v);
-}
-
-// is `v` a wildcard?
-bool AliasTracker::isWildcard(const Value* v) const {
- return wildcards_.count(v);
-}
-
-// Register the fact that `n` writes to `v`.
-void AliasTracker::registerWrite(const Value* v, Node* n) {
- numWrites_++;
-
- if (isWildcard(v)) {
- wildcardWriters_.insert(n);
- return;
- }
-
- AT_ASSERT(map_.count(v));
- writeIndex_[n].insert(v);
-}
-
-bool AliasTracker::hasWriters(const Value* v) const {
- if (!map_.count(v)) {
- return false;
- }
-
- if (isWildcard(v)) {
- // If `n` has a wildcard, any write in the graph may write to it.
- // So the only way we know there are no writers is if there are no writes
- // at all.
- return numWrites_ == 0;
- }
-
- if (wildcardWriters_.size() > 0) {
- // A write to the wildcard may be a write to any value.
- return true;
- }
-
- if (isWriteCacheStale_) {
- rebuildWriteCache();
- }
-
- for (const auto loc : map_.at(v)->getMemoryLocations()) {
- if (writeCache_.count(loc)) {
- return true;
- }
- }
-
- return false;
-}
-
-void AliasTracker::rebuildWriteCache() const {
- for (const auto& pr : writeIndex_) {
- const auto& writtenValues = pr.second;
-
- for (const auto value : writtenValues) {
- for (const auto loc : map_.at(value)->getMemoryLocations()) {
- writeCache_.insert(loc);
- }
- }
- }
- isWriteCacheStale_ = false;
-}
-
-void AliasTracker::dump() const {
- std::cout << "\n===2. ALIAS DB===\n";
- for (const auto& ptrPair : elements_) {
- const auto element = ptrPair.first;
- if (element->pointsTo.size() > 0) {
- std::cout << element->value->uniqueName() << " points to: ";
- for (const auto pointedTo : element->pointsTo) {
- std::cout << pointedTo->value->uniqueName() << ", ";
- }
- std::cout << "\n";
- }
- }
-
- std::cout << "\n===3. WILDCARDS===\n";
- for (const auto wildcard : wildcards_) {
- std::cout << wildcard->uniqueName() << ", ";
- }
- std::cout << "\n";
-
- std::cout << "\n===4. Writes===\n";
- for (const auto& pr : writeIndex_) {
- const auto node = pr.first;
- const auto& values = pr.second;
- std::cout << *node;
- std::cout << " ";
- for (const auto value : values) {
- std::cout << value->uniqueName() << ", ";
- }
- std::cout << "\n";
- }
- std::cout << "\n";
-}
-
-std::unordered_set<const AliasTracker::Element*> AliasTracker::Element::
- getMemoryLocations() const {
- if (!cachedMemoryLocations_.empty()) {
- return cachedMemoryLocations_;
- }
-
- // Do a BFS in the `points-to` direction, collecting all memory locations
- std::unordered_set<const Element*> ret;
- this->bfs(
- [&](const Element* el) {
- if (el->pointsTo.empty()) {
- ret.insert(el);
- }
- },
- BfsDirection::POINTS_TO);
-
- cachedMemoryLocations_ = ret;
- return ret;
-}
-
-// Do a breadth-first search over the graph, starting at `this` and
-// traversing in the direction `dir`.`fn` will be run on each element.
-template <typename Fn>
-bool AliasTracker::Element::bfs(Fn fn, BfsDirection dir) const {
- std::queue<const Element*> queue;
- std::unordered_set<const Element*> seen;
-
- queue.push(this);
- while (!queue.empty()) {
- const auto el = queue.front();
- queue.pop();
- seen.insert(el);
-
- fn(el);
-
- switch (dir) {
- case BfsDirection::POINTS_TO: {
- for (auto ptr : el->pointsTo) {
- if (!seen.count(ptr)) {
- queue.push(ptr);
- }
- }
- } break;
-
- case BfsDirection::POINTED_FROM: {
- for (auto ptr : el->pointedFrom) {
- if (!seen.count(ptr)) {
- queue.push(ptr);
- }
- }
- } break;
- }
- }
- return false;
-}
-
-ValueSet AliasTracker::getMemoryLocations(const Value* v) const {
- ValueSet ret;
- if (!map_.count(v)) {
- return ret;
- }
-
- for (const auto el : map_.at(v)->getMemoryLocations()) {
- ret.insert(el->value);
- }
- return ret;
-}
-} // namespace jit
-} // namespace torch
+++ /dev/null
-#pragma once
-
-#include <torch/csrc/jit/ir.h>
-
-namespace torch {
-namespace jit {
-
-// class AliasTracker
-//
-// This class tracks the "A points to B" graph for all values, as well as
-// wildcards and writes. It is used by AliasDb to provide a higher-level API.
-//
-// We maintain a DAG where:
-// - Vertices (called "elements") represent values and
-// other aliasing entities (e.g. like the stuff inside a list)
-// - Edges represent a "points-to" relationship.
-//
-// Leaves in this DAG are entities that don't point to anything, and thus
-// correspond to unique "memory locations".
-//
-// So, by traversing the "points-to" graph to the leaves, you can determine
-// which memory locations an element may point to.
-class AliasTracker {
- public:
- // Returns true iff `v` is present in the alias set tracker.
- bool contains(const Value* v) const;
-
- // Does `n` write to a memory location that `v` may point to?
- bool writesTo(Node* n, const Value* v) const;
-
- // Make `v` point at `to`.
- void makePointerTo(const Value* v, const Value* to);
-
- // Give `v` a fresh alias (i.e. it does not point to any value)
- void makeFreshValue(const Value* v);
-
- // Register `v` as a wildcard value.
- void setWildcard(const Value* v);
-
- // is `v` a wildcard?
- bool isWildcard(const Value* v) const;
-
- // Register the fact that `n` writes to `v`.
- void registerWrite(const Value* v, Node* n);
-
- // Does anything write to the memory locations that `v` may point to?
- bool hasWriters(const Value* v) const;
-
- // Get all nodes that write to a wildcard value.
- const std::unordered_set<Node*>& getWildcardWriters() const {
- return wildcardWriters_;
- }
-
- // Get the values that represent the memory locations that `v` may point to.
- // Return values are guaranteed to be "fresh" tensors--they do not point to
- // anything else.
- ValueSet getMemoryLocations(const Value* v) const;
-
- // Do `a` and `b` potentially share a memory location?
- bool mayAlias(const Value* a, const Value* b) const;
-
- // Do any values in group `a` potentially share a memory location with any
- // value in group `b`?
- //
- // This is written so that either of the inputs could be a multiset
- template <typename T, typename U>
- bool mayAlias(const T& a, const U& b) const {
- if (a.empty() || b.empty()) {
- return false;
- }
-
- // Record all memory locations from group `a`
- std::unordered_set<const Element*> memoryLocations;
- for (auto it = a.cbegin(); it != a.cend();) {
- const auto value = *it;
- if (isWildcard(value)) {
- return true;
- }
-
- if (map_.count(value)) {
- for (const auto loc : map_.at(value)->getMemoryLocations()) {
- memoryLocations.insert(loc);
- }
- }
-
- const auto cnt = a.count(*it);
- std::advance(it, cnt);
- }
-
- // If any of group `b`s memory locations overlap, return true.
- for (auto it = b.cbegin(); it != b.cend();) {
- const auto value = *it;
- if (isWildcard(value)) {
- return true;
- }
-
- if (map_.count(value)) {
- for (const auto loc : map_.at(value)->getMemoryLocations()) {
- if (memoryLocations.count(loc)) {
- return true;
- }
- }
- }
-
- const auto cnt = b.count(*it);
- std::advance(it, cnt);
- }
- // No overlap, so group `a` and `b` do not share a memory location
- return false;
- }
-
- void dump() const;
-
- private:
- enum class BfsDirection {
- POINTS_TO,
- POINTED_FROM,
- };
- // `Element` represents the vertex in the points-to graph. It has a 1:1
- // relationship with IR `Value`s.
- struct Element {
- const Value* value = nullptr;
- // All elements that this element *may* point to. It's possible to have
- // multiple elements that you might point to due to control flow/complex ops
- std::unordered_set<Element*> pointsTo;
- // Backreference for points-to.
- std::unordered_set<Element*> pointedFrom;
-
- std::unordered_set<const Element*> getMemoryLocations() const;
- // We do path compression to make repeated memory location queries faster.
- // An empty cache means it is invalidated (it can never be empty otherwise,
- // since every element must point to at least one memory location).
- mutable std::unordered_set<const Element*> cachedMemoryLocations_;
-
- // Do a breadth-first search over the graph, starting at `this` and
- // traversing in the direction `dir`.`fn` will be run on each element.
- template <typename Fn>
- bool bfs(Fn fn, BfsDirection dir) const;
- };
-
- // Structure that owns all the element pointers. It's a map of
- // raw pointer -> unique_ptr to facilitate easy queries
- std::unordered_map<Element*, std::unique_ptr<Element>> elements_;
- // Index to look up whatever element corresponds to that value.
- std::unordered_map<const Value*, Element*> map_;
- // All values that may point to a wildcard value.
- ValueSet wildcards_;
- // All nodes that write to a wildcard
- std::unordered_set<Node*> wildcardWriters_;
- size_t numWrites_ = 0;
-
- std::unordered_map<Node*, ValueSet> writeIndex_;
- mutable std::unordered_set<const Element*> writeCache_;
- mutable bool isWriteCacheStale_ = true;
- void rebuildWriteCache() const;
-};
-
-} // namespace jit
-} // namespace torch
--- /dev/null
+#include "memory_dag.h"
+
+#include <torch/csrc/utils/memory.h>
+#include <queue>
+
+namespace torch {
+namespace jit {
+
+bool MemoryDAG::mayAlias(Element* a, Element* b) const {
+ return mayAliasImpl(a, b);
+}
+
+bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
+ return mayAliasImpl(a, b);
+}
+
+bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const {
+ const auto aMemLoc = a->getMemoryLocations();
+ const auto bMemLoc = b->getMemoryLocations();
+
+ // XXX: This could be more efficiently done as a bitwise AND on two bitfields
+ // that represent memory location membership. If these comparisons end up
+ // being a bottleneck, consider implementing it that way.
+ for (const auto aLoc : aMemLoc) {
+ for (const auto bLoc : bMemLoc) {
+ if (aLoc == bLoc) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+// Make `v` point at `to`.
+void MemoryDAG::makePointerTo(Element* from, Element* to) {
+ from->pointsTo.insert(to);
+ to->pointedFrom.insert(from);
+}
+
+// Give `v` a fresh alias (i.e. it does not point to any value)
+Element* MemoryDAG::makeFreshValue(const Value* v) {
+ auto el = torch::make_unique<Element>();
+ el->value = v;
+
+ auto rawPtr = el.get();
+ elements_.emplace(rawPtr, std::move(el));
+ return rawPtr;
+}
+
+std::unordered_set<const Element*> Element::getMemoryLocations() const {
+ if (!cachedMemoryLocations_.empty()) {
+ return cachedMemoryLocations_;
+ }
+
+ // Do a BFS in the `points-to` direction, collecting all memory locations
+ std::unordered_set<const Element*> ret;
+ this->bfs(
+ [&](const Element* el) {
+ if (el->pointsTo.empty()) {
+ ret.insert(el);
+ }
+ },
+ BfsDirection::POINTS_TO);
+
+ cachedMemoryLocations_ = ret;
+ return ret;
+}
+
+// Do a breadth-first search over the graph, starting at `this` and
+// traversing in the direction `dir`.`fn` will be run on each element.
+template <typename Fn>
+bool Element::bfs(Fn fn, BfsDirection dir) const {
+ std::queue<const Element*> queue;
+ std::unordered_set<const Element*> seen;
+
+ queue.push(this);
+ while (!queue.empty()) {
+ const auto el = queue.front();
+ queue.pop();
+ seen.insert(el);
+
+ fn(el);
+
+ switch (dir) {
+ case BfsDirection::POINTS_TO: {
+ for (auto ptr : el->pointsTo) {
+ if (!seen.count(ptr)) {
+ queue.push(ptr);
+ }
+ }
+ } break;
+
+ case BfsDirection::POINTED_FROM: {
+ for (auto ptr : el->pointedFrom) {
+ if (!seen.count(ptr)) {
+ queue.push(ptr);
+ }
+ }
+ } break;
+ }
+ }
+ return false;
+}
+} // namespace jit
+} // namespace torch
--- /dev/null
+#pragma once
+
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+
+namespace torch {
+namespace jit {
+
+struct Element;
+struct Value;
+
+// class MemoryDAG
+//
+// This class tracks the "A points to B" graph for all values. It is used by
+// AliasDb to provide a higher-level API.
+//
+// We maintain a DAG where:
+// - Vertices (called "elements") represent values and
+// other aliasing entities (e.g. like the stuff inside a list)
+// - Edges represent a "points-to" relationship.
+//
+// Leaves in this DAG are entities that don't point to anything, and thus
+// correspond to unique "memory locations".
+//
+// So, by traversing the "points-to" graph to the leaves, you can determine
+// which memory locations an element may point to.
+class MemoryDAG {
+ public:
+ // Make `from` point at `to`.
+ void makePointerTo(Element* from, Element* to);
+
+ // Make a fresh element (i.e. an element that doesn't point to anything) and
+ // return it.
+ Element* makeFreshValue(const Value* v);
+
+ // Do `a` and `b` potentially share a memory location?
+ bool mayAlias(const Element* a, const Element* b) const;
+ bool mayAlias(Element* a, Element* b) const;
+
+ // Do any values in group `a` potentially share a memory location with any
+ // value in group `b`?
+ //
+ // This is written so that either of the inputs could be a multiset
+ template <typename T, typename U>
+ bool mayAlias(const T& a, const U& b) const {
+ if (a.empty() || b.empty()) {
+ return false;
+ }
+
+ // Record all memory locations from group `a`
+ std::unordered_set<const Element*> memoryLocations;
+ for (auto it = a.cbegin(); it != a.cend();) {
+ const auto element = *it;
+
+ for (const auto loc : element->getMemoryLocations()) {
+ memoryLocations.insert(loc);
+ }
+
+ const auto cnt = a.count(*it);
+ std::advance(it, cnt);
+ }
+
+ // If any of group `b`s memory locations overlap, return true.
+ for (auto it = b.cbegin(); it != b.cend();) {
+ const auto element = *it;
+
+ for (const auto loc : element->getMemoryLocations()) {
+ if (memoryLocations.count(loc)) {
+ return true;
+ }
+ }
+
+ const auto cnt = b.count(*it);
+ std::advance(it, cnt);
+ }
+ // No overlap, so group `a` and `b` do not share a memory location
+ return false;
+ }
+
+ private:
+ bool mayAliasImpl(const Element* a, const Element* b) const;
+ // Structure that owns all the element pointers. It's a map of
+ // raw pointer -> unique_ptr to facilitate easy queries
+ std::unordered_map<Element*, std::unique_ptr<Element>> elements_;
+};
+
+enum class BfsDirection {
+ POINTS_TO,
+ POINTED_FROM,
+};
+
+// `Element` represents the vertex in the points-to graph. It represents
+// anything that could have an aliasing relationship, mostly IR `Value`s, but
+// also the "inside of a list", or wildcards.
+struct Element {
+ // The value that this element corresponds to. May be null if this element
+ // doesn't represent a first-class value.
+ const Value* value = nullptr;
+
+ // All elements that this element *may* point to. It's possible to have
+ // multiple elements that you might point to due to control flow/complex ops
+ std::unordered_set<Element*> pointsTo;
+ // Backreference for points-to.
+ std::unordered_set<Element*> pointedFrom;
+
+ // Return the unique memory locations that `Element` might represent.
+ std::unordered_set<const Element*> getMemoryLocations() const;
+ // We do path compression to make repeated memory location queries faster.
+ // An empty cache means it is invalidated (it can never be empty otherwise,
+ // since every element must point to at least one memory location).
+ mutable std::unordered_set<const Element*> cachedMemoryLocations_;
+
+ // Do a breadth-first search over the graph, starting at `this` and
+ // traversing in the direction `dir`.`fn` will be run on each element.
+ template <typename Fn>
+ bool bfs(Fn fn, BfsDirection dir) const;
+};
+
+} // namespace jit
+} // namespace torch