alias_analysis refactor (#17511)
authorMichael Suo <suo@fb.com>
Thu, 28 Feb 2019 19:28:16 +0000 (11:28 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 28 Feb 2019 20:00:36 +0000 (12:00 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17511

AliasTracker was doing bookkeeping for three concepts: the points-to graph,
writes, and wildcards.

This PR makes AliasTracker's job clearer: it keeps track of the points-to
graph. Thus it has been renamed MemoryDAG. Write and wildcard information were
pulled back into AliasDb as part of this—I may decide to pull them into their
own little modules since I don't want the alias analysis stuff to get too
bloated.

This refactor is necessary because we want to start tracking information for
aliasing elements that _aren't_ first-class IR Values (e.g. the "stuff" inside
a list). So MemoryDAG can't know too much about Values

Reviewed By: houseroad

Differential Revision: D14231251

fbshipit-source-id: 6cd98ae6fced8d6c1522c2454da77c3c1b2b0504

test/cpp/jit/gtest.cpp
test/cpp/jit/no-gtest.cpp
test/cpp/jit/test_alias_analysis.h
tools/build_variables.py
torch/CMakeLists.txt
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h
torch/csrc/jit/passes/utils/alias_tracker.cpp [deleted file]
torch/csrc/jit/passes/utils/alias_tracker.h [deleted file]
torch/csrc/jit/passes/utils/memory_dag.cpp [new file with mode: 0644]
torch/csrc/jit/passes/utils/memory_dag.h [new file with mode: 0644]

index 65bc780..8f6c2a8 100644 (file)
@@ -34,7 +34,9 @@ JIT_TEST(TopologicalIndex)
 JIT_TEST(TopologicalMove)
 JIT_TEST(SubgraphUtils)
 JIT_TEST(AliasAnalysis)
-JIT_TEST(AliasTracker)
+JIT_TEST(WriteTracking)
+JIT_TEST(Wildcards)
+JIT_TEST(MemoryDAG)
 JIT_TEST(IRParser)
 
 JIT_TEST(NetDefConverter)
index 8072ef8..00b6892 100644 (file)
@@ -39,7 +39,9 @@ std::string runJITCPPTests() {
   testATenNativeBatchNorm();
   testRegisterFusionCachesKernel();
   testAliasAnalysis();
-  testAliasTracker();
+  testWriteTracking();
+  testWildcards();
+  testMemoryDAG();
   testNetDefConverter(out);
   testIRParser(out);
   return out.str();
index 840d3ac..c1d2b57 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "test/cpp/jit/test_base.h"
+#include "torch/csrc/jit/custom_operator.h"
 #include "torch/csrc/jit/passes/alias_analysis.h"
 #include "torch/csrc/jit/script/compiler.h"
 #include "torch/csrc/utils/memory.h"
@@ -452,41 +453,110 @@ void testAliasAnalysis() {
   }
 }
 
-void testAliasTracker() {
-  auto graph = std::make_shared<Graph>();
-  const Value* a = graph->addInput();
-  const Value* b = graph->addInput();
-  const Value* c = graph->addInput();
-  const Value* d = graph->addInput();
-  const Value* e = graph->addInput();
-  const Value* f = graph->addInput();
-  const Value* g = graph->addInput();
-  const Value* wc = graph->addInput();
-
+void testWriteTracking() {
+  RegisterOperators reg({createOperator(
+      "foo::creates_alias(Tensor(a) x) -> Tensor(a)",
+      [](at::Tensor a) { return a; })});
+  const auto creates_alias = Symbol::fromQualString("foo::creates_alias");
+  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
   {
-    // test contains()
-    AliasTracker t;
-    t.makeFreshValue(a);
-    ASSERT_TRUE(t.contains(a));
-    ASSERT_FALSE(t.contains(b));
+    auto graph = std::make_shared<Graph>();
+    auto a = graph->addInput();
+    auto b = graph->addInput();
+
+    // aten::add(%b, %b)
+    // aten::add_(%a, %b)
+    // foo::creates_alias(%a)
+    auto pureNode = graph->insert(aten::add, {b, b})->node();
+    auto writingNode = graph->insert(aten::add_, {a, b})->node();
+    auto node3 = graph->insert(creates_alias, {a})->node();
+    auto aAlias = node3->output();
+
+    graph->lint();
+
+    AliasDb aliasDb(graph);
+    ASSERT_TRUE(aliasDb.mayAlias(aAlias, a));
+    ASSERT_TRUE(aliasDb.mayAlias(a, b));
+    ASSERT_FALSE(
+        aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{a}));
+    ASSERT_FALSE(
+        aliasDb.writesToAlias(pureNode, std::unordered_set<const Value*>{b}));
+    ASSERT_TRUE(aliasDb.writesToAlias(
+        writingNode, std::unordered_set<const Value*>{a}));
+    ASSERT_TRUE(aliasDb.writesToAlias(
+        writingNode, std::unordered_set<const Value*>{a, b}));
+    ASSERT_TRUE(aliasDb.writesToAlias(
+        writingNode, std::unordered_set<const Value*>{aAlias}));
   }
+}
+
+void testWildcards() {
+  RegisterOperators reg({createOperator(
+                             "foo::returns_wildcard(Tensor a) -> Tensor(*)",
+                             [](at::Tensor a) { return a; }),
+                         createOperator(
+                             "foo::writes(Tensor(z!) a) -> Tensor(a)",
+                             [](at::Tensor a) { return a; })});
+  const auto returns_wildcard = Symbol::fromQualString("foo::returns_wildcard");
+  const auto writes = Symbol::fromQualString("foo::writes");
+
+  auto graph = std::make_shared<Graph>();
+  const auto a = graph->addInput();
+
+  const auto constant = graph->insertConstant(1);
+  const auto fresh = graph->insert(aten::rand, {constant});
+  const auto fresh2 = graph->insert(aten::rand, {constant});
+  const auto wildcard = graph->insert(returns_wildcard, {fresh});
+  const auto wildcardWrite = graph->insert(writes, {wildcard})->node();
+
+  graph->lint();
+  AliasDb aliasDb(graph);
+
+  ASSERT_FALSE(aliasDb.mayAlias(a, fresh));
+  ASSERT_TRUE(aliasDb.mayAlias(wildcard, fresh));
+  ASSERT_TRUE(aliasDb.mayAlias(wildcard, a));
+  ASSERT_FALSE(aliasDb.mayAlias(
+      std::unordered_set<const Value*>({wildcard}),
+      std::unordered_set<const Value*>()));
+
+  // Test writes to wildcards
+  ASSERT_TRUE(aliasDb.writesToAlias(
+      wildcardWrite, std::unordered_set<const Value*>{fresh}));
+  ASSERT_TRUE(aliasDb.writesToAlias(
+      wildcardWrite, std::unordered_set<const Value*>{fresh2}));
+  ASSERT_TRUE(aliasDb.writesToAlias(
+      wildcardWrite, std::unordered_set<const Value*>{a}));
+}
+
+void testMemoryDAG() {
+  auto graph = std::make_shared<Graph>();
+  const Value* aValue = graph->addInput();
+  const Value* bValue = graph->addInput();
+  const Value* cValue = graph->addInput();
+  const Value* dValue = graph->addInput();
+  const Value* eValue = graph->addInput();
+  const Value* fValue = graph->addInput();
+  const Value* gValue = graph->addInput();
+
   {
     // a <- b <- c
     //      b <- d
     // a <- e
     // f <- e
     // g is by itself
-    // wc is a wildcard value
-    AliasTracker t;
-    t.makeFreshValue(a);
-    t.makeFreshValue(f);
-    t.makeFreshValue(g);
+    MemoryDAG t;
+    auto a = t.makeFreshValue(aValue);
+    auto b = t.makeFreshValue(bValue);
+    auto c = t.makeFreshValue(cValue);
+    auto d = t.makeFreshValue(dValue);
+    auto e = t.makeFreshValue(eValue);
+    auto f = t.makeFreshValue(fValue);
+    auto g = t.makeFreshValue(gValue);
     t.makePointerTo(b, a);
     t.makePointerTo(c, b);
     t.makePointerTo(d, b);
     t.makePointerTo(e, a);
     t.makePointerTo(e, f);
-    t.setWildcard(wc);
 
     /**
      * Test mayAlias()
@@ -506,59 +576,15 @@ void testAliasTracker() {
     // But a and f don't alias
     ASSERT_FALSE(t.mayAlias(a, f));
 
-    // Wildcards should alias everything
-    ASSERT_TRUE(t.mayAlias(wc, a));
-    ASSERT_TRUE(t.mayAlias(wc, b));
-    ASSERT_TRUE(t.mayAlias(wc, f));
-    ASSERT_TRUE(t.mayAlias(wc, g));
-
     /**
      * Test mayAlias() set interface
      */
-    std::multiset<const Value*> foo{c, c, d};
-    std::multiset<const Value*> bar{e, f};
-    std::unordered_set<const Value*> baz{f, g};
-    std::set<const Value*> containsWildcard{wc};
+    std::multiset<const Element*> foo{c, c, d};
+    std::multiset<const Element*> bar{e, f};
+    std::unordered_set<const Element*> baz{f, g};
     ASSERT_TRUE(t.mayAlias(foo, bar));
     ASSERT_TRUE(t.mayAlias(bar, baz));
     ASSERT_FALSE(t.mayAlias(foo, baz));
-    // wildcard stuff aliases everything
-    ASSERT_TRUE(t.mayAlias(containsWildcard, foo));
-    ASSERT_TRUE(t.mayAlias(containsWildcard, bar));
-    ASSERT_TRUE(t.mayAlias(containsWildcard, baz));
-
-    /**
-     * Test writer tracking
-     */
-    auto n1 = graph->appendNode(graph->create(prim::Undefined));
-    auto n2 = graph->appendNode(graph->create(prim::Undefined));
-    auto n3 = graph->appendNode(graph->create(prim::Undefined));
-    t.registerWrite(a, n1);
-    t.registerWrite(f, n2);
-    // We should report those writes accurately
-    ASSERT_TRUE(t.writesTo(n1, a));
-    ASSERT_TRUE(t.writesTo(n2, f));
-    ASSERT_FALSE(t.writesTo(n1, f));
-    ASSERT_FALSE(t.writesTo(n2, a));
-    // We should correctly report writes to aliases as well
-    ASSERT_TRUE(t.writesTo(n1, c));
-
-    // Check hasWriters()
-    ASSERT_TRUE(t.hasWriters(a));
-    // Aliases of written-to values should have writers
-    ASSERT_TRUE(t.hasWriters(b));
-    ASSERT_TRUE(t.hasWriters(d));
-    ASSERT_TRUE(t.hasWriters(e));
-    // Unique values not registered should be unaffected
-    ASSERT_FALSE(t.hasWriters(g));
-
-    // create a write to the wildcard set
-    t.registerWrite(wc, n3);
-    // Now everything may be written to
-    ASSERT_TRUE(t.hasWriters(g));
-    const auto& wildcardWriters = t.getWildcardWriters();
-    ASSERT_EQ(wildcardWriters.size(), 1);
-    ASSERT_EQ(*wildcardWriters.begin(), n3);
   }
 }
 } // namespace jit
index b73fac2..46e59cc 100644 (file)
@@ -87,7 +87,7 @@ libtorch_sources = [
     "torch/csrc/jit/passes/shape_analysis.cpp",
     "torch/csrc/jit/passes/specialize_undef.cpp",
     "torch/csrc/jit/passes/utils/subgraph_utils.cpp",
-    "torch/csrc/jit/passes/utils/alias_tracker.cpp",
+    "torch/csrc/jit/passes/utils/memory_dag.cpp",
     "torch/csrc/jit/register_prim_ops.cpp",
     "torch/csrc/jit/register_special_ops.cpp",
     "torch/csrc/jit/scope.cpp",
index 62be18c..8e80e60 100644 (file)
@@ -165,7 +165,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
-  ${TORCH_SRC_DIR}/csrc/jit/passes/utils/alias_tracker.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
   ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
index 850862d..11c4539 100644 (file)
@@ -5,8 +5,8 @@
 
 namespace torch {
 namespace jit {
-namespace {
-bool shouldAnnotate(const TypePtr& type) {
+
+bool AliasDb::shouldAnnotate(const TypePtr& type) {
   return type->isSubtypeOf(TensorType::get()) ||
       type->kind() == TypeKind::ListType ||
       type->kind() == TypeKind::TupleType ||
@@ -19,56 +19,103 @@ bool shouldAnnotate(const TypePtr& type) {
 
 // We only need to annotate values that either are mutable or could contain
 // mutable types.
-bool shouldAnnotate(const Value* v) {
+bool AliasDb::shouldAnnotate(const Value* v) {
   return shouldAnnotate(v->type());
 }
-} // namespace
 
 AliasDb::~AliasDb() = default;
 
 AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
-  aliasTracker_ = torch::make_unique<AliasTracker>();
+  memoryDAG_ = torch::make_unique<MemoryDAG>();
   analyze(graph_);
 }
 
 // Does `n` use or write to any wildcard aliases?
 bool AliasDb::hasWildcard(const Node* n) const {
   for (const auto input : n->inputs()) {
-    if (aliasTracker_->isWildcard(input)) {
+    if (isWildcard(input)) {
       return true;
     }
   }
-
   for (const auto output : n->outputs()) {
-    if (aliasTracker_->isWildcard(output)) {
+    if (isWildcard(output)) {
       return true;
     }
   }
   return false;
 }
 
+bool AliasDb::isWildcard(const Value* v) const {
+  return wildcards_.count(v);
+}
+
 bool AliasDb::writesTo(Node* n, const Value* v) const {
   if (!shouldAnnotate(v)) {
     // This is a primitive type
     return false;
   }
-  return aliasTracker_->writesTo(n, v);
+  if (isWildcard(v)) {
+    return wildcardWriters_.count(n);
+  }
+
+  if (!elementMap_.count(v) || !writeIndex_.count(n)) {
+    return false;
+  }
+
+  // Can short-circuit if we know this node writes directly to `v`
+  if (writeIndex_.at(n).count(v)) {
+    return true;
+  }
+
+  // Otherwise, check if `v` may alias any of written-to values in `n`
+  const auto vSet = ValueSet{v};
+  return mayAlias(vSet, writeIndex_.at(n));
 }
 
 bool AliasDb::hasWriters(const Node* n) const {
   for (const auto input : n->inputs()) {
-    if (aliasTracker_->hasWriters(input)) {
+    if (hasWriters(input)) {
       return true;
     }
   }
   for (const auto output : n->outputs()) {
-    if (aliasTracker_->hasWriters(output)) {
+    if (hasWriters(output)) {
       return true;
     }
   }
   return false;
 }
 
+bool AliasDb::hasWriters(const Value* v) const {
+  if (isWildcard(v)) {
+    // If `n` has a wildcard, any write in the graph may write to it.
+    // So the only way we know there are no writers is if there are no writes
+    // at all.
+    return numWrites_ == 0;
+  }
+
+  if (!elementMap_.count(v)) {
+    return false;
+  }
+
+  if (wildcardWriters_.size() > 0) {
+    // A write to the wildcard may be a write to any value.
+    return true;
+  }
+
+  if (isWriteCacheStale_) {
+    rebuildWriteCache();
+  }
+
+  for (const auto loc : elementMap_.at(v)->getMemoryLocations()) {
+    if (writeCache_.count(loc)) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 bool AliasDb::hasWrites(Node* n) const {
   for (const auto input : n->inputs()) {
     if (writesTo(n, input)) {
@@ -102,16 +149,11 @@ bool AliasDb::writesToInputAlias(Node* n) const {
         graph_->inputs().cbegin(),
         graph_->inputs().cend(),
         [&](const Value* graphInput) {
-          return shouldAnnotate(graphInput) &&
-              aliasTracker_->mayAlias(graphInput, v);
+          return shouldAnnotate(graphInput) && mayAlias(graphInput, v);
         });
   });
 }
 
-bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
-  return aliasTracker_->mayAlias(a, b);
-}
-
 void AliasDb::getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks) const {
   for (auto node : b->nodes()) {
     getWritesImpl(node, ret, recurseBlocks);
@@ -144,7 +186,6 @@ ValueSet AliasDb::getWrites(Block* b) const {
   return writes;
 }
 
-
 // Does `n` write to an alias of one of the values in `vs`?
 bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
     const {
@@ -186,20 +227,47 @@ void AliasDb::dump() const {
   std::cout << "\n===1. GRAPH===\n";
   graph_->dump();
 
-  aliasTracker_->dump();
+  std::cout << "\n===2. ALIAS DB===\n";
+  for (const auto& ptrPair : elementMap_) {
+    const auto element = ptrPair.second;
+    if (element->pointsTo.size() > 0) {
+      std::cout << element->value->uniqueName() << " points to: ";
+      for (const auto pointedTo : element->pointsTo) {
+        std::cout << pointedTo->value->uniqueName() << ", ";
+      }
+      std::cout << "\n";
+    }
+  }
+
+  std::cout << "\n===3. WILDCARDS===\n";
+  for (const auto wildcard : wildcards_) {
+    std::cout << wildcard->uniqueName() << ", ";
+  }
+  std::cout << "\n";
+
+  std::cout << "\n===4. Writes===\n";
+  for (const auto& pr : writeIndex_) {
+    const auto node = pr.first;
+    const auto& values = pr.second;
+    std::cout << *node;
+    std::cout << "  ";
+    for (const auto value : values) {
+      std::cout << value->uniqueName() << ", ";
+    }
+    std::cout << "\n";
+  }
+  std::cout << "\n";
 }
 
-// TODO: need to create a dummy "graph input alias" value in setTracker for all
+// TODO: need to create a dummy "graph input alias" value in MemoryDAG for all
 // inputs of the same type to point to. Currently they all point to the first
 // element, which is technically wrong.
-static void makeAllAlias(
-    const std::vector<Value*> values,
-    AliasTracker& setTracker) {
+void AliasDb::makeAllAlias(const std::vector<Value*>& values) {
   if (values.size() > 0) {
-    setTracker.makeFreshValue(values[0]);
+    giveFreshAlias(values[0]);
   }
   for (const auto value : values) {
-    setTracker.makePointerTo(value, values[0]);
+    makePointerTo(value, values[0]);
   }
 }
 
@@ -247,18 +315,18 @@ void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
 
   // 2. Make all partitions alias each other
   for (const auto& pr : listTypes) {
-    makeAllAlias(pr.second, *aliasTracker_);
+    makeAllAlias(pr.second);
   }
   for (const auto& pr : tupleTypes) {
-    makeAllAlias(pr.second, *aliasTracker_);
+    makeAllAlias(pr.second);
   }
   for (const auto& pr : dictTypes) {
-    makeAllAlias(pr.second, *aliasTracker_);
+    makeAllAlias(pr.second);
   }
   for (const auto& pr : userTypes) {
-    makeAllAlias(pr.second, *aliasTracker_);
+    makeAllAlias(pr.second);
   }
-  makeAllAlias(tensors, *aliasTracker_);
+  makeAllAlias(tensors);
 
   analyze(graph->block());
 }
@@ -392,7 +460,7 @@ void AliasDb::analyzeImpl(Node* node) {
 
     // Record writes
     if (formal->isWrite()) {
-      aliasTracker_->registerWrite(actualValue, node);
+      registerWrite(actualValue, node);
     }
   }
 
@@ -415,7 +483,7 @@ void AliasDb::analyzeImpl(Node* node) {
     AT_ASSERT(formal->containedTypes().size() == 0);
 
     if (formal->isWildcard()) {
-      aliasTracker_->setWildcard(actual);
+      setWildcard(actual);
       continue;
     }
 
@@ -437,15 +505,27 @@ void AliasDb::analyzeImpl(Node* node) {
       }
 
       auto toAlias = formalToActual.at(formalAlias);
-      makeAliasOf(actual, toAlias);
+      makePointerTo(actual, toAlias);
     }
 
     // Record writes
     if (formal->isWrite()) {
-      aliasTracker_->registerWrite(actual, node);
+      registerWrite(actual, node);
     }
   }
 }
+// Register the fact that `n` writes to `v`.
+void AliasDb::registerWrite(const Value* v, Node* n) {
+  numWrites_++;
+
+  if (isWildcard(v)) {
+    wildcardWriters_.insert(n);
+    return;
+  }
+
+  AT_ASSERT(elementMap_.count(v));
+  writeIndex_[n].insert(v);
+}
 
 void AliasDb::analyzeIf(Node* node) {
   // For if statements, the alias set of an output is the union of the
@@ -461,8 +541,8 @@ void AliasDb::analyzeIf(Node* node) {
     const auto trueOutput = trueBlock->outputs().at(i);
     const auto falseOutput = falseBlock->outputs().at(i);
 
-    makeAliasOf(nodeOutput, trueOutput);
-    makeAliasOf(nodeOutput, falseOutput);
+    makePointerTo(nodeOutput, trueOutput);
+    makePointerTo(nodeOutput, falseOutput);
   }
 }
 
@@ -501,7 +581,7 @@ void AliasDb::analyzeSubgraph(Node* node) {
   // subgraph block.
   AT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size());
   for (size_t i = 0; i < node->outputs().size(); i++) {
-    makeAliasOf(node->outputs()[i], subgraphBlock->outputs()[i]);
+    makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
   }
 }
 
@@ -517,7 +597,7 @@ void AliasDb::analyzeCreator(Node* node) {
 void AliasDb::analyzeExtractor(Node* node) {
   for (const auto output : node->outputs()) {
     if (shouldAnnotate(output)) {
-      aliasTracker_->setWildcard(output);
+      setWildcard(output);
     }
   }
 }
@@ -525,7 +605,7 @@ void AliasDb::analyzeExtractor(Node* node) {
 // For torch.chunk(), all returned tensors may alias the input tensor
 void AliasDb::analyzeChunk(Node* node) {
   for (auto output : node->outputs()) {
-    makeAliasOf(output, node->input());
+    makePointerTo(output, node->input());
   }
 }
 
@@ -550,14 +630,14 @@ void AliasDb::analyzeWait(Node* node) {
   const auto fut = node->input();
   AT_ASSERT(fut->type()->kind() == TypeKind::FutureType);
 
-  if (aliasTracker_->isWildcard(fut)) {
+  if (isWildcard(fut)) {
     for (const auto output : node->outputs()) {
-      aliasTracker_->setWildcard(output);
+      setWildcard(output);
     }
     return;
   }
 
-  const auto originFuts = aliasTracker_->getMemoryLocations(fut);
+  const auto originFuts = getMemoryLocations(fut);
   for (const auto originFut : originFuts) {
     const auto subgraphNode = originFut->node();
 
@@ -590,7 +670,7 @@ void AliasDb::analyzeWait(Node* node) {
     // since the writes may or may not have been executed yet. But we'll let
     // users do that and shoot themselves in the foot for now.
     for (const auto write : subgraphWrites) {
-      aliasTracker_->registerWrite(write, node);
+      registerWrite(write, node);
     }
   }
 }
@@ -599,7 +679,7 @@ void AliasDb::analyzeWait(Node* node) {
 void AliasDb::analyzeSetAttr(Node* node) {
   const auto self = node->inputs().at(0);
   AT_ASSERT(self->type()->kind() == TypeKind::UserType);
-  aliasTracker_->registerWrite(self, node);
+  registerWrite(self, node);
 }
 
 // BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
@@ -613,25 +693,57 @@ void AliasDb::analyzeBroadcastingChunk(Node* node) {
     // inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks)
     auto output_begin = outputs.begin() + index * nchunks;
     for (auto it = output_begin; it != output_begin + nchunks; ++it) {
-      makeAliasOf(*it, inputs.at(index));
+      makePointerTo(*it, inputs.at(index));
     }
   }
 }
 
 // Register the fact that `value` is a pointer to `to`
-void AliasDb::makeAliasOf(const Value* value, const Value* to) {
-  if (!shouldAnnotate(value)) {
+void AliasDb::makePointerTo(const Value* from, const Value* to) {
+  if (!shouldAnnotate(from)) {
     AT_ASSERT(!shouldAnnotate(to));
     return;
   }
-  aliasTracker_->makePointerTo(value, to);
+
+  if (from == to) {
+    return;
+  }
+
+  // If either value is a wildcard, don't insert anything into the graph;
+  // wildcards are tracked separately since they have different aliasing rules.
+  if (isWildcard(to) || isWildcard(from)) {
+    setWildcard(from);
+    return;
+  }
+
+  if (!isTracked(from)) {
+    giveFreshAlias(from);
+  }
+  if (!isTracked(to)) {
+    giveFreshAlias(to);
+  }
+  auto fromEl = elementMap_.at(from);
+  auto toEl = elementMap_.at(to);
+  memoryDAG_->makePointerTo(fromEl, toEl);
+}
+
+bool AliasDb::mayAlias(const Value* a, const Value* b) const {
+  if (isWildcard(a) || isWildcard(b)) {
+    return true;
+  }
+
+  if (!elementMap_.count(a) || !elementMap_.count(b)) {
+    return false;
+  }
+
+  return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
 }
 
 // Make each value in the `from` list point to its partner in the `to` list
 void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
   AT_ASSERT(to.size() == from.size());
   for (size_t i = 0; i < to.size(); i++) {
-    makeAliasOf(from[i], to[i]);
+    makePointerTo(from[i], to[i]);
   }
 }
 
@@ -640,13 +752,17 @@ void AliasDb::giveFreshAlias(const Value* value) {
     return;
   }
 
-  if (aliasTracker_->contains(value)) {
+  if (isTracked(value)) {
     // Inside a loop, we may have given a fresh alias to this value already, so
     // skip
     return;
   }
 
-  aliasTracker_->makeFreshValue(value);
+  elementMap_[value] = memoryDAG_->makeFreshValue(value);
+}
+
+bool AliasDb::isTracked(const Value* v) const {
+  return isWildcard(v) || elementMap_.count(v);
 }
 
 bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
@@ -762,13 +878,13 @@ class AliasDb::WorkingSet {
     // 2. Handle regular mutable dependencies
     // Check that `n` does not write to anything used by the working set
     const auto nWrites = aliasDb_.getWrites(n, /*recurseBlocks=*/true);
-    if (aliasDb_.aliasTracker_->mayAlias(nWrites, reads_)) {
+    if (aliasDb_.mayAlias(nWrites, reads_)) {
       return true;
     }
 
     // Check that the working set doesn't write to anything that `n` uses.
     const auto nReads = aliasDb_.getReads(n, /*recurseBlocks=*/true);
-    if (aliasDb_.aliasTracker_->mayAlias(writes_, nReads)) {
+    if (aliasDb_.mayAlias(writes_, nReads)) {
       return true;
     }
     return false;
@@ -1054,5 +1170,34 @@ TORCH_API bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
   return handled.count(symbol) || purposefully_not_handled.count(symbol);
 }
 
+// Register `v` as a wildcard value.
+void AliasDb::setWildcard(const Value* v) {
+  wildcards_.insert(v);
+}
+
+void AliasDb::rebuildWriteCache() const {
+  for (const auto& pr : writeIndex_) {
+    const auto& writtenValues = pr.second;
+
+    for (const auto value : writtenValues) {
+      for (const auto loc : elementMap_.at(value)->getMemoryLocations()) {
+        writeCache_.insert(loc);
+      }
+    }
+  }
+  isWriteCacheStale_ = false;
+}
+
+ValueSet AliasDb::getMemoryLocations(const Value* v) const {
+  ValueSet ret;
+  if (!elementMap_.count(v)) {
+    return ret;
+  }
+
+  for (const auto el : elementMap_.at(v)->getMemoryLocations()) {
+    ret.insert(el->value);
+  }
+  return ret;
+}
 } // namespace jit
 } // namespace torch
index 508de79..bed1d63 100644 (file)
@@ -2,7 +2,7 @@
 
 #include <torch/csrc/jit/alias_info.h>
 #include <torch/csrc/jit/ir.h>
-#include <torch/csrc/jit/passes/utils/alias_tracker.h>
+#include <torch/csrc/jit/passes/utils/memory_dag.h>
 
 namespace torch {
 namespace jit {
@@ -25,7 +25,6 @@ namespace jit {
  * we're not sure what this value may alias. To be conservative, we consider
  * the wildcard alias set as potentially aliasing any value.
  */
-
 class AliasDb {
  public:
   TORCH_API explicit AliasDb(std::shared_ptr<Graph> graph);
@@ -46,9 +45,53 @@ class AliasDb {
   bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
       const;
 
+  // Do `a` and `b` potentially share a memory location?
+  bool mayAlias(const Value* a, const Value* b) const;
   // Do any values in group `a` potentially share a memory location with any
-  // value in group `b`?
-  bool mayAlias(const ValueSet& a, const ValueSet& b) const;
+  // value in group `b`? i.e. may they overlap?
+  //
+  // NOTE: Bit of ugly templating, but this is just to make sure we can
+  // transform an arbitrary container of `Values` to the same container of
+  // `Elements`.
+  template <
+      typename... Other1,
+      template <typename, typename...> class T,
+      typename... Other2,
+      template <typename, typename...> class U>
+  bool mayAlias(
+      const T<const Value*, Other1...>& a,
+      const U<const Value*, Other2...>& b) const {
+    if (a.empty() || b.empty()) {
+      return false;
+    }
+    // Short-circuit for special case: if any value is a wildcard, the two sets
+    // may alias
+    if (std::any_of(
+            a.cbegin(),
+            a.cend(),
+            [this](const Value* v) { return isWildcard(v); }) ||
+        std::any_of(b.cbegin(), b.cend(), [this](const Value* v) {
+          return isWildcard(v);
+        })) {
+      return true;
+    }
+
+    T<Element*> aElements;
+    for (const Value* v : a) {
+      if (elementMap_.count(v)) {
+        aElements.insert(elementMap_.at(v));
+      }
+    }
+
+    U<Element*> bElements;
+    for (const Value* v : b) {
+      if (elementMap_.count(v)) {
+        bElements.insert(elementMap_.at(v));
+      }
+    }
+
+    return memoryDAG_->mayAlias(aElements, bElements);
+  }
 
   // Do any nodes write to an alias set inputed/outputed by `n`?
   bool hasWriters(const Node* n) const;
@@ -79,7 +122,11 @@ class AliasDb {
   void move(Node* toMove, Node* movePoint, MoveSide moveSide);
   bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
 
-
+  /**
+   * Write and read internal API
+   */
+  // Does `n` write to any alias sets?
+  bool hasWrites(Node* n) const;
   // Get all the values that `n` writes to.
   // NOTE: this only returns values directly written to, not aliases thereof
   //
@@ -88,28 +135,43 @@ class AliasDb {
   ValueSet getWrites(Block* b) const;
   void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
   void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
-
+  // Do any nodes write to `v`s memory location?
+  bool hasWriters(const Value* v) const;
+  // Register the fact that `n` writes to `v`.
+  void registerWrite(const Value* v, Node* n);
   // Get all the values that `n` reads from.
   // if `recurseBlocks` is true, gather reads on the nodes in `n`s sub-blocks
   ValueSet getReads(Node* n, bool recurseBlocks = false) const;
-
   void getReadsImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
-  // Does `n` write to any alias sets?
-  bool hasWrites(Node* n) const;
 
+  // Does `n` write to a value that may alias one of the graph inputs?
+  bool writesToInputAlias(Node* n) const;
+  // Does `n` write to `v` or any aliases of `v`?
+  bool writesTo(Node* n, const Value* v) const;
+
+  /**
+   * Wildcard methods
+   */
+  // is `v` a wildcard?
+  bool isWildcard(const Value* v) const;
+  // Register `v` as a wildcard value.
+  void setWildcard(const Value* v);
+  // Get all nodes that write to a wildcard value.
+  const std::unordered_set<Node*>& getWildcardWriters() const {
+    return wildcardWriters_;
+  }
   // Does `n` use or write to any wildcard aliases?
   bool hasWildcard(const Node* n) const;
   // Returns nullopt if there are no wildcard nodes
   c10::optional<const Node*> getLastWildcard() const;
 
-  // Does `n` write to a value that may alias one of the graph inputs?
-  bool writesToInputAlias(Node* n) const;
-
+  /**
+   * Special analysis methods
+   */
   void analyze(const std::shared_ptr<Graph>& graph);
   void analyze(Block* block);
   void analyze(Node* node);
   void analyzeImpl(Node* node);
-
   void analyzeIf(Node* node);
   void analyzeLoop(Node* node);
   void analyzeSubgraph(Node* node);
@@ -121,20 +183,51 @@ class AliasDb {
   void analyzeWait(Node* node);
   void analyzeSetAttr(Node* node);
 
-  void makeAliasOf(const Value* value, const Value* to);
+  /**
+   * Alias manipulation methods
+   */
+  void makeAllAlias(const std::vector<Value*>& values);
+  void makePointerTo(const Value* value, const Value* to);
   void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
   void giveFreshAlias(const Value* value);
 
+  static bool shouldAnnotate(const Value* v);
+  static bool shouldAnnotate(const TypePtr& type);
   bool hasUsesAfter(Symbol alias, const Node* n) const;
-  bool writesTo(Node* n, const Value* v) const;
   bool isBeforeSameGraph(const Node* lhs, const Node* rhs) const;
 
+  // Returns true iff `v` is part of the alias tracker/is a wildcard
+  bool isTracked(const Value* v) const;
+
+  // Get the values that represent the memory locations that `v` may point to.
+  // Return values are guaranteed to be "fresh" tensors--they do not point to
+  // anything else.
+  ValueSet getMemoryLocations(const Value* v) const;
+
   std::shared_ptr<Graph> graph_;
   std::unordered_map<const Graph*, const Node*> subgraphToOwner_;
+
+  // The points-to graph that stores aliasing relationships
+  std::unique_ptr<MemoryDAG> memoryDAG_;
+  // Mapping of values to MemoryDAG elements
+  std::unordered_map<const Value*, Element*> elementMap_;
+
+  // All values that may point to a wildcard value.
+  ValueSet wildcards_;
+  // All nodes that write to a wildcard
+  std::unordered_set<Node*> wildcardWriters_;
+  // All nodes that contain a wildcard
   std::unordered_set<const Node*> wildcardNodes_;
-  std::unique_ptr<AliasTracker> aliasTracker_;
+
+  // State for tracking write info
+  size_t numWrites_ = 0;
+  std::unordered_map<Node*, ValueSet> writeIndex_;
+  mutable std::unordered_set<const Element*> writeCache_;
+  mutable bool isWriteCacheStale_ = true;
+  void rebuildWriteCache() const;
 };
 
+// Used to assert that unschematized operators have an analysis method written
 TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym);
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/passes/utils/alias_tracker.cpp b/torch/csrc/jit/passes/utils/alias_tracker.cpp
deleted file mode 100644 (file)
index e6cd358..0000000
+++ /dev/null
@@ -1,266 +0,0 @@
-#include "alias_tracker.h"
-
-#include <torch/csrc/utils/memory.h>
-#include <queue>
-
-namespace torch {
-namespace jit {
-
-// Returns true iff `v` is present in the alias set tracker.
-bool AliasTracker::contains(const Value* v) const {
-  return isWildcard(v) || map_.count(v);
-}
-
-bool AliasTracker::mayAlias(const Value* a, const Value* b) const {
-  if (isWildcard(a) || isWildcard(b)) {
-    return true;
-  }
-
-  if (!map_.count(a) || !map_.count(b)) {
-    return false;
-  }
-
-  const auto aEl = map_.at(a);
-  const auto bEl = map_.at(b);
-
-  const auto aMemLoc = aEl->getMemoryLocations();
-  const auto bMemLoc = bEl->getMemoryLocations();
-
-  // XXX: This could be more efficiently done as a bitwise AND on two bitfields
-  // that represent memory location membership. If these comparisons end up
-  // being a bottleneck, consider implementing it that way.
-  for (const auto aLoc : aMemLoc) {
-    for (const auto bLoc : bMemLoc) {
-      if (aLoc == bLoc) {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-bool AliasTracker::writesTo(Node* n, const Value* v) const {
-  if (isWildcard(v)) {
-    return wildcardWriters_.count(n);
-  }
-
-  if (!map_.count(v) || !writeIndex_.count(n)) {
-    return false;
-  }
-
-  // Can short-circuit if we know this node writes directly to `v`
-  if (writeIndex_.at(n).count(v)) {
-    return true;
-  }
-
-  // Otherwise, check if `v` may alias any of written-to values in `n`
-  const auto vSet = ValueSet{v};
-  return mayAlias(vSet, writeIndex_.at(n));
-}
-
-// Make `v` point at `to`.
-void AliasTracker::makePointerTo(const Value* v, const Value* to) {
-  if (v == to) {
-    return;
-  }
-
-  // If `to` is a wildcard, don't insert anything into the graph; wildcards
-  // are tracked separately since they have different aliasing rules.
-  if (isWildcard(to)) {
-    setWildcard(v);
-    return;
-  }
-
-  if (!map_.count(to)) {
-    makeFreshValue(to);
-  }
-
-  if (!map_.count(v)) {
-    makeFreshValue(v);
-  }
-
-  auto vEl = map_.at(v);
-  auto toEl = map_.at(to);
-
-  vEl->pointsTo.insert(toEl);
-  toEl->pointedFrom.insert(vEl);
-}
-
-// Give `v` a fresh alias (i.e. it does not point to any value)
-void AliasTracker::makeFreshValue(const Value* v) {
-  auto el = torch::make_unique<Element>();
-  el->value = v;
-
-  auto rawPtr = el.get();
-  elements_.emplace(rawPtr, std::move(el));
-  map_.emplace(v, rawPtr);
-}
-
-// Register `v` as a wildcard value.
-void AliasTracker::setWildcard(const Value* v) {
-  wildcards_.insert(v);
-}
-
-// is `v` a wildcard?
-bool AliasTracker::isWildcard(const Value* v) const {
-  return wildcards_.count(v);
-}
-
-// Register the fact that `n` writes to `v`.
-void AliasTracker::registerWrite(const Value* v, Node* n) {
-  numWrites_++;
-
-  if (isWildcard(v)) {
-    wildcardWriters_.insert(n);
-    return;
-  }
-
-  AT_ASSERT(map_.count(v));
-  writeIndex_[n].insert(v);
-}
-
-bool AliasTracker::hasWriters(const Value* v) const {
-  if (!map_.count(v)) {
-    return false;
-  }
-
-  if (isWildcard(v)) {
-    // If `n` has a wildcard, any write in the graph may write to it.
-    // So the only way we know there are no writers is if there are no writes
-    // at all.
-    return numWrites_ == 0;
-  }
-
-  if (wildcardWriters_.size() > 0) {
-    // A write to the wildcard may be a write to any value.
-    return true;
-  }
-
-  if (isWriteCacheStale_) {
-    rebuildWriteCache();
-  }
-
-  for (const auto loc : map_.at(v)->getMemoryLocations()) {
-    if (writeCache_.count(loc)) {
-      return true;
-    }
-  }
-
-  return false;
-}
-
-void AliasTracker::rebuildWriteCache() const {
-  for (const auto& pr : writeIndex_) {
-    const auto& writtenValues = pr.second;
-
-    for (const auto value : writtenValues) {
-      for (const auto loc : map_.at(value)->getMemoryLocations()) {
-        writeCache_.insert(loc);
-      }
-    }
-  }
-  isWriteCacheStale_ = false;
-}
-
-void AliasTracker::dump() const {
-  std::cout << "\n===2. ALIAS DB===\n";
-  for (const auto& ptrPair : elements_) {
-    const auto element = ptrPair.first;
-    if (element->pointsTo.size() > 0) {
-      std::cout << element->value->uniqueName() << " points to: ";
-      for (const auto pointedTo : element->pointsTo) {
-        std::cout << pointedTo->value->uniqueName() << ", ";
-      }
-      std::cout << "\n";
-    }
-  }
-
-  std::cout << "\n===3. WILDCARDS===\n";
-  for (const auto wildcard : wildcards_) {
-    std::cout << wildcard->uniqueName() << ", ";
-  }
-  std::cout << "\n";
-
-  std::cout << "\n===4. Writes===\n";
-  for (const auto& pr : writeIndex_) {
-    const auto node = pr.first;
-    const auto& values = pr.second;
-    std::cout << *node;
-    std::cout << "  ";
-    for (const auto value : values) {
-      std::cout << value->uniqueName() << ", ";
-    }
-    std::cout << "\n";
-  }
-  std::cout << "\n";
-}
-
-std::unordered_set<const AliasTracker::Element*> AliasTracker::Element::
-    getMemoryLocations() const {
-  if (!cachedMemoryLocations_.empty()) {
-    return cachedMemoryLocations_;
-  }
-
-  // Do a BFS in the `points-to` direction, collecting all memory locations
-  std::unordered_set<const Element*> ret;
-  this->bfs(
-      [&](const Element* el) {
-        if (el->pointsTo.empty()) {
-          ret.insert(el);
-        }
-      },
-      BfsDirection::POINTS_TO);
-
-  cachedMemoryLocations_ = ret;
-  return ret;
-}
-
-// Do a breadth-first search over the graph, starting at `this` and
-// traversing in the direction `dir`.`fn` will be run on each element.
-template <typename Fn>
-bool AliasTracker::Element::bfs(Fn fn, BfsDirection dir) const {
-  std::queue<const Element*> queue;
-  std::unordered_set<const Element*> seen;
-
-  queue.push(this);
-  while (!queue.empty()) {
-    const auto el = queue.front();
-    queue.pop();
-    seen.insert(el);
-
-    fn(el);
-
-    switch (dir) {
-      case BfsDirection::POINTS_TO: {
-        for (auto ptr : el->pointsTo) {
-          if (!seen.count(ptr)) {
-            queue.push(ptr);
-          }
-        }
-      } break;
-
-      case BfsDirection::POINTED_FROM: {
-        for (auto ptr : el->pointedFrom) {
-          if (!seen.count(ptr)) {
-            queue.push(ptr);
-          }
-        }
-      } break;
-    }
-  }
-  return false;
-}
-
-ValueSet AliasTracker::getMemoryLocations(const Value* v) const {
-  ValueSet ret;
-  if (!map_.count(v)) {
-    return ret;
-  }
-
-  for (const auto el : map_.at(v)->getMemoryLocations()) {
-    ret.insert(el->value);
-  }
-  return ret;
-}
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/passes/utils/alias_tracker.h b/torch/csrc/jit/passes/utils/alias_tracker.h
deleted file mode 100644 (file)
index afeae89..0000000
+++ /dev/null
@@ -1,159 +0,0 @@
-#pragma once
-
-#include <torch/csrc/jit/ir.h>
-
-namespace torch {
-namespace jit {
-
-// class AliasTracker
-//
-// This class tracks the "A points to B" graph for all values, as well as
-// wildcards and writes. It is used by AliasDb to provide a higher-level API.
-//
-// We maintain a DAG where:
-//   - Vertices (called "elements") represent values and
-//     other aliasing entities (e.g. like the stuff inside a list)
-//   - Edges represent a "points-to" relationship.
-//
-// Leaves in this DAG are entities that don't point to anything, and thus
-// correspond to unique "memory locations".
-//
-// So, by traversing the "points-to" graph to the leaves, you can determine
-// which memory locations an element may point to.
-class AliasTracker {
- public:
-  // Returns true iff `v` is present in the alias set tracker.
-  bool contains(const Value* v) const;
-
-  // Does `n` write to a memory location that `v` may point to?
-  bool writesTo(Node* n, const Value* v) const;
-
-  // Make `v` point at `to`.
-  void makePointerTo(const Value* v, const Value* to);
-
-  // Give `v` a fresh alias (i.e. it does not point to any value)
-  void makeFreshValue(const Value* v);
-
-  // Register `v` as a wildcard value.
-  void setWildcard(const Value* v);
-
-  // is `v` a wildcard?
-  bool isWildcard(const Value* v) const;
-
-  // Register the fact that `n` writes to `v`.
-  void registerWrite(const Value* v, Node* n);
-
-  // Does anything write to the memory locations that `v` may point to?
-  bool hasWriters(const Value* v) const;
-
-  // Get all nodes that write to a wildcard value.
-  const std::unordered_set<Node*>& getWildcardWriters() const {
-    return wildcardWriters_;
-  }
-
-  // Get the values that represent the memory locations that `v` may point to.
-  // Return values are guaranteed to be "fresh" tensors--they do not point to
-  // anything else.
-  ValueSet getMemoryLocations(const Value* v) const;
-
-  // Do `a` and `b` potentially share a memory location?
-  bool mayAlias(const Value* a, const Value* b) const;
-
-  // Do any values in group `a` potentially share a memory location with any
-  // value in group `b`?
-  //
-  // This is written so that either of the inputs could be a multiset
-  template <typename T, typename U>
-  bool mayAlias(const T& a, const U& b) const {
-    if (a.empty() || b.empty()) {
-      return false;
-    }
-
-    // Record all memory locations from group `a`
-    std::unordered_set<const Element*> memoryLocations;
-    for (auto it = a.cbegin(); it != a.cend();) {
-      const auto value = *it;
-      if (isWildcard(value)) {
-        return true;
-      }
-
-      if (map_.count(value)) {
-        for (const auto loc : map_.at(value)->getMemoryLocations()) {
-          memoryLocations.insert(loc);
-        }
-      }
-
-      const auto cnt = a.count(*it);
-      std::advance(it, cnt);
-    }
-
-    // If any of group `b`s memory locations overlap, return true.
-    for (auto it = b.cbegin(); it != b.cend();) {
-      const auto value = *it;
-      if (isWildcard(value)) {
-        return true;
-      }
-
-      if (map_.count(value)) {
-        for (const auto loc : map_.at(value)->getMemoryLocations()) {
-          if (memoryLocations.count(loc)) {
-            return true;
-          }
-        }
-      }
-
-      const auto cnt = b.count(*it);
-      std::advance(it, cnt);
-    }
-    // No overlap, so group `a` and `b` do not share a memory location
-    return false;
-  }
-
-  void dump() const;
-
- private:
-  enum class BfsDirection {
-    POINTS_TO,
-    POINTED_FROM,
-  };
-  // `Element` represents the vertex in the points-to graph. It has a 1:1
-  // relationship with IR `Value`s.
-  struct Element {
-    const Value* value = nullptr;
-    // All elements that this element *may* point to. It's possible to have
-    // multiple elements that you might point to due to control flow/complex ops
-    std::unordered_set<Element*> pointsTo;
-    // Backreference for points-to.
-    std::unordered_set<Element*> pointedFrom;
-
-    std::unordered_set<const Element*> getMemoryLocations() const;
-    // We do path compression to make repeated memory location queries faster.
-    // An empty cache means it is invalidated (it can never be empty otherwise,
-    // since every element must point to at least one memory location).
-    mutable std::unordered_set<const Element*> cachedMemoryLocations_;
-
-    // Do a breadth-first search over the graph, starting at `this` and
-    // traversing in the direction `dir`.`fn` will be run on each element.
-    template <typename Fn>
-    bool bfs(Fn fn, BfsDirection dir) const;
-  };
-
-  // Structure that owns all the element pointers. It's a map of
-  //  raw pointer -> unique_ptr to facilitate easy queries
-  std::unordered_map<Element*, std::unique_ptr<Element>> elements_;
-  // Index to look up whatever element corresponds to that value.
-  std::unordered_map<const Value*, Element*> map_;
-  // All values that may point to a wildcard value.
-  ValueSet wildcards_;
-  // All nodes that write to a wildcard
-  std::unordered_set<Node*> wildcardWriters_;
-  size_t numWrites_ = 0;
-
-  std::unordered_map<Node*, ValueSet> writeIndex_;
-  mutable std::unordered_set<const Element*> writeCache_;
-  mutable bool isWriteCacheStale_ = true;
-  void rebuildWriteCache() const;
-};
-
-} // namespace jit
-} // namespace torch
diff --git a/torch/csrc/jit/passes/utils/memory_dag.cpp b/torch/csrc/jit/passes/utils/memory_dag.cpp
new file mode 100644 (file)
index 0000000..4f5b75a
--- /dev/null
@@ -0,0 +1,105 @@
+#include "memory_dag.h"
+
+#include <torch/csrc/utils/memory.h>
+#include <queue>
+
+namespace torch {
+namespace jit {
+
+bool MemoryDAG::mayAlias(Element* a, Element* b) const {
+  return mayAliasImpl(a, b);
+}
+
+bool MemoryDAG::mayAlias(const Element* a, const Element* b) const {
+  return mayAliasImpl(a, b);
+}
+
+bool MemoryDAG::mayAliasImpl(const Element* a, const Element* b) const {
+  const auto aMemLoc = a->getMemoryLocations();
+  const auto bMemLoc = b->getMemoryLocations();
+
+  // XXX: This could be more efficiently done as a bitwise AND on two bitfields
+  // that represent memory location membership. If these comparisons end up
+  // being a bottleneck, consider implementing it that way.
+  for (const auto aLoc : aMemLoc) {
+    for (const auto bLoc : bMemLoc) {
+      if (aLoc == bLoc) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+// Make `v` point at `to`.
+void MemoryDAG::makePointerTo(Element* from, Element* to) {
+  from->pointsTo.insert(to);
+  to->pointedFrom.insert(from);
+}
+
+// Give `v` a fresh alias (i.e. it does not point to any value)
+Element* MemoryDAG::makeFreshValue(const Value* v) {
+  auto el = torch::make_unique<Element>();
+  el->value = v;
+
+  auto rawPtr = el.get();
+  elements_.emplace(rawPtr, std::move(el));
+  return rawPtr;
+}
+
+std::unordered_set<const Element*> Element::getMemoryLocations() const {
+  if (!cachedMemoryLocations_.empty()) {
+    return cachedMemoryLocations_;
+  }
+
+  // Do a BFS in the `points-to` direction, collecting all memory locations
+  std::unordered_set<const Element*> ret;
+  this->bfs(
+      [&](const Element* el) {
+        if (el->pointsTo.empty()) {
+          ret.insert(el);
+        }
+      },
+      BfsDirection::POINTS_TO);
+
+  cachedMemoryLocations_ = ret;
+  return ret;
+}
+
+// Do a breadth-first search over the graph, starting at `this` and
+// traversing in the direction `dir`.`fn` will be run on each element.
+template <typename Fn>
+bool Element::bfs(Fn fn, BfsDirection dir) const {
+  std::queue<const Element*> queue;
+  std::unordered_set<const Element*> seen;
+
+  queue.push(this);
+  while (!queue.empty()) {
+    const auto el = queue.front();
+    queue.pop();
+    seen.insert(el);
+
+    fn(el);
+
+    switch (dir) {
+      case BfsDirection::POINTS_TO: {
+        for (auto ptr : el->pointsTo) {
+          if (!seen.count(ptr)) {
+            queue.push(ptr);
+          }
+        }
+      } break;
+
+      case BfsDirection::POINTED_FROM: {
+        for (auto ptr : el->pointedFrom) {
+          if (!seen.count(ptr)) {
+            queue.push(ptr);
+          }
+        }
+      } break;
+    }
+  }
+  return false;
+}
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/utils/memory_dag.h b/torch/csrc/jit/passes/utils/memory_dag.h
new file mode 100644 (file)
index 0000000..6142481
--- /dev/null
@@ -0,0 +1,121 @@
+#pragma once
+
+#include <unordered_set>
+#include <unordered_map>
+#include <memory>
+
+namespace torch {
+namespace jit {
+
+struct Element;
+struct Value;
+
+// class MemoryDAG
+//
+// This class tracks the "A points to B" graph for all values. It is used by
+// AliasDb to provide a higher-level API.
+//
+// We maintain a DAG where:
+//   - Vertices (called "elements") represent values and
+//     other aliasing entities (e.g. like the stuff inside a list)
+//   - Edges represent a "points-to" relationship.
+//
+// Leaves in this DAG are entities that don't point to anything, and thus
+// correspond to unique "memory locations".
+//
+// So, by traversing the "points-to" graph to the leaves, you can determine
+// which memory locations an element may point to.
+class MemoryDAG {
+ public:
+  // Make `from` point at `to`.
+  void makePointerTo(Element* from, Element* to);
+
+  // Make a fresh element (i.e. an element that doesn't point to anything) and
+  // return it.
+  Element* makeFreshValue(const Value* v);
+
+  // Do `a` and `b` potentially share a memory location?
+  bool mayAlias(const Element* a, const Element* b) const;
+  bool mayAlias(Element* a, Element* b) const;
+
+  // Do any values in group `a` potentially share a memory location with any
+  // value in group `b`?
+  //
+  // This is written so that either of the inputs could be a multiset
+  template <typename T, typename U>
+  bool mayAlias(const T& a, const U& b) const {
+    if (a.empty() || b.empty()) {
+      return false;
+    }
+
+    // Record all memory locations from group `a`
+    std::unordered_set<const Element*> memoryLocations;
+    for (auto it = a.cbegin(); it != a.cend();) {
+      const auto element = *it;
+
+      for (const auto loc : element->getMemoryLocations()) {
+        memoryLocations.insert(loc);
+      }
+
+      const auto cnt = a.count(*it);
+      std::advance(it, cnt);
+    }
+
+    // If any of group `b`s memory locations overlap, return true.
+    for (auto it = b.cbegin(); it != b.cend();) {
+      const auto element = *it;
+
+      for (const auto loc : element->getMemoryLocations()) {
+        if (memoryLocations.count(loc)) {
+          return true;
+        }
+      }
+
+      const auto cnt = b.count(*it);
+      std::advance(it, cnt);
+    }
+    // No overlap, so group `a` and `b` do not share a memory location
+    return false;
+  }
+
+ private:
+   bool mayAliasImpl(const Element* a, const Element* b) const;
+  // Structure that owns all the element pointers. It's a map of
+  //  raw pointer -> unique_ptr to facilitate easy queries
+  std::unordered_map<Element*, std::unique_ptr<Element>> elements_;
+};
+
+enum class BfsDirection {
+  POINTS_TO,
+  POINTED_FROM,
+};
+
+// `Element` represents the vertex in the points-to graph. It represents
+// anything that could have an aliasing relationship, mostly IR `Value`s, but
+// also the "inside of a list", or wildcards.
+struct Element {
+  // The value that this element corresponds to. May be null if this element
+  // doesn't represent a first-class value.
+  const Value* value = nullptr;
+
+  // All elements that this element *may* point to. It's possible to have
+  // multiple elements that you might point to due to control flow/complex ops
+  std::unordered_set<Element*> pointsTo;
+  // Backreference for points-to.
+  std::unordered_set<Element*> pointedFrom;
+
+  // Return the unique memory locations that `Element` might represent.
+  std::unordered_set<const Element*> getMemoryLocations() const;
+  // We do path compression to make repeated memory location queries faster.
+  // An empty cache means it is invalidated (it can never be empty otherwise,
+  // since every element must point to at least one memory location).
+  mutable std::unordered_set<const Element*> cachedMemoryLocations_;
+
+  // Do a breadth-first search over the graph, starting at `this` and
+  // traversing in the direction `dir`.`fn` will be run on each element.
+  template <typename Fn>
+  bool bfs(Fn fn, BfsDirection dir) const;
+};
+
+} // namespace jit
+} // namespace torch