simplify aliasdb interface (#17453)
authorMichael Suo <suo@fb.com>
Mon, 25 Feb 2019 21:27:43 +0000 (13:27 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Feb 2019 21:34:51 +0000 (13:34 -0800)
Summary:
Stack:
&nbsp;&nbsp;&nbsp;&nbsp;:black_circle:&nbsp; **#17453 [jit] simplify aliasdb interface**&nbsp;&nbsp;[:yellow_heart:](https://our.intern.facebook.com/intern/diff/D14205209/)

The previous "getWrites" API relies on the user to do alias checking, which is confusing and inconsistent with the rest of the interface. So replace it with a higher-level call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17453

Differential Revision: D14209942

Pulled By: suo

fbshipit-source-id: d4aff2af6062ab8465ee006fc6dc603296bcb7ab

torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h
torch/csrc/jit/passes/dead_code_elimination.cpp

index 899bff7..07375c3 100644 (file)
@@ -143,6 +143,14 @@ ValueSet AliasDb::getWrites(Block* b) const {
   return writes;
 }
 
+
+// Does `n` write to an alias of one of the values in `vs`?
+bool AliasDb::writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks)
+    const {
+  const auto writtenTo = getWrites(n, recurseBlocks);
+  return mayAlias(vs, writtenTo);
+}
+
 std::unordered_set<const Value*> AliasDb::getWrites(Node* n, bool recurseBlocks)
     const {
   ValueSet writes;
index 24212b7..5228472 100644 (file)
@@ -41,11 +41,10 @@ class AliasDb {
   // circumstances.
   bool hasUntrackedEffects(Node* n) const;
 
-  // Get all the values that `n` writes to.
-  // NOTE: this only returns values directly written to, not aliases thereof
-  //
-  // if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
-  ValueSet getWrites(Node* n, bool recurseBlocks = false) const;
+  // Does `n` write to an alias of one of the values in `vs`?
+  // if `recurseBlocks` is true, consider writes on the nodes in `n`s sub-blocks
+  bool writesToAlias(Node* n, const ValueSet& vs, bool recurseBlocks = false)
+      const;
 
   // Do any values in group `a` potentially share a memory location with any
   // value in group `b`?
@@ -80,6 +79,12 @@ class AliasDb {
   void move(Node* toMove, Node* movePoint, MoveSide moveSide);
   bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
 
+
+  // Get all the values that `n` writes to.
+  // NOTE: this only returns values directly written to, not aliases thereof
+  //
+  // if `recurseBlocks` is true, gather writes on the nodes in `n`s sub-blocks
+  ValueSet getWrites(Node* n, bool recurseBlocks = false) const;
   ValueSet getWrites(Block* b) const;
   void getWritesImpl(Block* b, ValueSet& ret, bool recurseBlocks = false) const;
   void getWritesImpl(Node* n, ValueSet& ret, bool recurseBlocks = false) const;
index fbc3408..e559b84 100644 (file)
@@ -130,8 +130,7 @@ class DeadCodeEliminator {
     }
 
     if (aliasDb_) {
-      const auto writes = aliasDb_->getWrites(node);
-      if (aliasDb_->mayAlias(writes, liveValues_)) {
+      if (aliasDb_->writesToAlias(node, liveValues_, /*recurseBlocks=*/false)) {
         return mark(node);
       }
     }