Allow DCE to clean up some mutable ops (#14601)
authorMichael Suo <suo@fb.com>
Mon, 3 Dec 2018 21:27:59 +0000 (13:27 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 21:31:08 +0000 (13:31 -0800)
Summary:
This PR makes DCE a little smarter in the presence of mutable ops. Previously mutable ops could never be cleaned up, now they can be cleaned up if we can prove there are no live uses of any alias sets that the op writes to.

This behavior is optional; if you pass DCE a block instead of a graph, it will do the same thing as before. Also changed `InlineAutographSubgraph` to use the common subgraph utils.

Tested on traced ResNet, and it gets rid of the dead code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14601

Differential Revision: D13309118

Pulled By: suo

fbshipit-source-id: dac2791e7d2ecf219ae717a2759b83c1e927f254

20 files changed:
test/expect/TestScript.test_mutable_dce.expect [new file with mode: 0644]
test/expect/TestScript.test_mutable_dce_block.expect [new file with mode: 0644]
test/expect/TestScript.test_mutable_dce_graph_input.expect [new file with mode: 0644]
test/expect/TestScript.test_mutable_dce_list.expect [new file with mode: 0644]
test/expect/TestScript.test_mutable_dce_loop.expect [new file with mode: 0644]
test/test_jit.py
torch/csrc/jit/export.cpp
torch/csrc/jit/init.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/alias_analysis.cpp
torch/csrc/jit/passes/alias_analysis.h
torch/csrc/jit/passes/common_subexpression_elimination.cpp
torch/csrc/jit/passes/constant_propagation.cpp
torch/csrc/jit/passes/dead_code_elimination.cpp
torch/csrc/jit/passes/dead_code_elimination.h
torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp
torch/csrc/jit/passes/lower_tuples.cpp
torch/csrc/jit/passes/shape_analysis.cpp
torch/csrc/jit/passes/utils/subgraph_utils.cpp

diff --git a/test/expect/TestScript.test_mutable_dce.expect b/test/expect/TestScript.test_mutable_dce.expect
new file mode 100644 (file)
index 0000000..97ad8f0
--- /dev/null
@@ -0,0 +1,14 @@
+graph() {
+  %0 : int = prim::Constant[value=1]()
+  %1 : int[] = prim::Constant[value=[0, -1]]()
+  %2 : int = prim::Constant[value=0]()
+  %3 : int = prim::Constant[value=6]()
+  %4 : int = prim::Constant[value=2]()
+  %5 : int = prim::Constant[value=3]()
+  %6 : int[] = prim::ListConstruct(%4, %5)
+  %a.1 : Tensor = aten::rand(%6, %3, %2, %1)
+  %8 : int[] = prim::ListConstruct(%4, %5)
+  %9 : Tensor = aten::rand(%8, %3, %2, %1)
+  %a : Tensor = aten::add_(%a.1, %9, %0)
+  return (%a);
+}
diff --git a/test/expect/TestScript.test_mutable_dce_block.expect b/test/expect/TestScript.test_mutable_dce_block.expect
new file mode 100644 (file)
index 0000000..5e6c69f
--- /dev/null
@@ -0,0 +1,30 @@
+graph() {
+  %0 : int = prim::Constant[value=1]()
+  %1 : int[] = prim::Constant[value=[0, -1]]()
+  %2 : int = prim::Constant[value=0]()
+  %3 : int = prim::Constant[value=6]()
+  %4 : int = prim::Constant[value=2]()
+  %5 : int = prim::Constant[value=3]()
+  %6 : int[] = prim::ListConstruct(%4, %5)
+  %a.1 : Tensor = aten::rand(%6, %3, %2, %1)
+  %8 : int[] = prim::ListConstruct(%4, %5)
+  %9 : Tensor = aten::rand(%8, %3, %2, %1)
+  %a.2 : Tensor = aten::add_(%a.1, %9, %0)
+  %11 : int[] = prim::ListConstruct(%4, %5)
+  %b.1 : Tensor = aten::rand(%11, %3, %2, %1)
+  %13 : int[] = prim::ListConstruct(%4, %5)
+  %14 : Tensor = aten::zeros(%13, %3, %2, %1)
+  %15 : Tensor = aten::gt(%a.2, %14)
+  %16 : bool = prim::TensorToBool(%15)
+  %b : Tensor = prim::If(%16)
+    block0() {
+      %18 : int[] = prim::ListConstruct(%4, %5)
+      %19 : Tensor = aten::rand(%18, %3, %2, %1)
+      %b.2 : Tensor = aten::add_(%b.1, %19, %0)
+      -> (%b.2)
+    }
+    block1() {
+      -> (%b.1)
+    }
+  return (%b);
+}
diff --git a/test/expect/TestScript.test_mutable_dce_graph_input.expect b/test/expect/TestScript.test_mutable_dce_graph_input.expect
new file mode 100644 (file)
index 0000000..fb4330a
--- /dev/null
@@ -0,0 +1,12 @@
+graph(%a.1 : Tensor) {
+  %1 : int = prim::Constant[value=1]()
+  %2 : int[] = prim::Constant[value=[0, -1]]()
+  %3 : int = prim::Constant[value=0]()
+  %4 : int = prim::Constant[value=6]()
+  %5 : int = prim::Constant[value=2]()
+  %6 : int = prim::Constant[value=3]()
+  %7 : int[] = prim::ListConstruct(%5, %6)
+  %8 : Tensor = aten::rand(%7, %4, %3, %2)
+  %a : Tensor = aten::add_(%a.1, %8, %1)
+  return ();
+}
diff --git a/test/expect/TestScript.test_mutable_dce_list.expect b/test/expect/TestScript.test_mutable_dce_list.expect
new file mode 100644 (file)
index 0000000..e348570
--- /dev/null
@@ -0,0 +1,17 @@
+graph(%a : Tensor) {
+  %1 : int = prim::Constant[value=1]()
+  %2 : int[] = prim::Constant[value=[0, -1]]()
+  %3 : int = prim::Constant[value=6]()
+  %4 : int = prim::Constant[value=0]()
+  %5 : int = prim::Constant[value=2]()
+  %6 : int = prim::Constant[value=3]()
+  %l : Tensor[] = prim::ListConstruct()
+  %8 : Tensor[] = aten::append(%l, %a)
+  %c.1 : Tensor = aten::select(%l, %4)
+  %10 : int[] = prim::ListConstruct(%5, %6)
+  %b : Tensor = aten::rand(%10, %3, %4, %2)
+  %12 : int[] = prim::ListConstruct(%5, %6)
+  %13 : Tensor = aten::rand(%12, %3, %4, %2)
+  %c : Tensor = aten::add_(%c.1, %13, %1)
+  return (%b);
+}
diff --git a/test/expect/TestScript.test_mutable_dce_loop.expect b/test/expect/TestScript.test_mutable_dce_loop.expect
new file mode 100644 (file)
index 0000000..eb2a997
--- /dev/null
@@ -0,0 +1,25 @@
+graph(%a : Tensor) {
+  %1 : int[] = prim::Constant[value=[0, -1]]()
+  %2 : int = prim::Constant[value=6]()
+  %i.1 : int = prim::Constant[value=0]()
+  %4 : int = prim::Constant[value=2]()
+  %5 : int = prim::Constant[value=3]()
+  %6 : int = prim::Constant[value=9223372036854775807]()
+  %7 : int = prim::Constant[value=1]()
+  %l : Tensor[] = prim::ListConstruct()
+  %9 : Tensor[] = aten::append(%l, %a)
+  %10 : int[] = prim::ListConstruct(%4, %5)
+  %b : Tensor = aten::rand(%10, %2, %i.1, %1)
+  %12 : bool = aten::lt(%i.1, %7)
+  %i : int = prim::Loop(%6, %12, %i.1)
+    block0(%14 : int, %15 : int) {
+      %c.1 : Tensor = aten::select(%l, %i.1)
+      %17 : int[] = prim::ListConstruct(%4, %5)
+      %18 : Tensor = aten::rand(%17, %2, %i.1, %1)
+      %c : Tensor = aten::add_(%c.1, %18, %7)
+      %i.2 : int = aten::add(%15, %7)
+      %21 : bool = aten::lt(%i.2, %7)
+      -> (%21, %i.2)
+    }
+  return (%b);
+}
index 208a649..6b3f1c8 100644 (file)
@@ -8881,6 +8881,68 @@ a")
 
         self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
 
+    def test_mutable_dce(self):
+        @torch.jit.script
+        def foo():
+            a = torch.rand(2, 3)
+            a += torch.rand(2, 3)
+            b = torch.rand(2, 3)
+            b += torch.rand(2, 3)
+            # b should be cleaned up but not a
+            return a
+
+        self.assertExpectedGraph(foo.graph)
+
+    def test_mutable_dce_block(self):
+        @torch.jit.script
+        def foo():
+            a = torch.rand(2, 3)
+            a += torch.rand(2, 3)
+            b = torch.rand(2, 3)
+            if bool(a > torch.zeros(2, 3)):
+                b += torch.rand(2, 3)
+                a += torch.rand(2, 3)
+            # a should be cleaned up but not b
+            return b
+
+        self.assertExpectedGraph(foo.graph)
+
+    def test_mutable_dce_graph_input(self):
+        @torch.jit.script
+        def foo(a):
+            a += torch.rand(2, 3)
+            # shouldn't clean up `a` even though it's not used in the output
+
+        self.assertExpectedGraph(foo.graph)
+
+    def test_mutable_dce_list(self):
+        @torch.jit.script
+        def foo(a):
+            l = []
+            l.append(a)
+            c = l[0]
+            b = torch.rand(2, 3)
+            c += torch.rand(2, 3)
+            return b
+
+        self.assertExpectedGraph(foo.graph)
+
+    def test_mutable_dce_loop(self):
+        @torch.jit.script
+        def foo(a):
+            l = []
+            l.append(a)
+            i = 0
+            b = torch.rand(2, 3)
+            while i < 1:
+                dead = torch.rand(2, 3)
+                c = l[0]
+                c += torch.rand(2, 3)
+                i += 1
+            return b
+
+        self.assertExpectedGraph(foo.graph)
+
 
 class MnistNet(nn.Module):
     def __init__(self):
index bb04473..6c8876c 100644 (file)
@@ -96,7 +96,7 @@ void validateBlock(Block *b, onnx_torch::OperatorExportTypes operator_export_typ
 
 void validateGraph(const std::shared_ptr<Graph>& graph, onnx_torch::OperatorExportTypes operator_export_type) {
   validateBlock(graph->block(), operator_export_type);
-  EliminateDeadCode(graph);
+  EliminateDeadCode(graph->block());
 }
 
 class EncoderBase {
index 676042a..3932a7e 100644 (file)
@@ -96,7 +96,7 @@ void initJITBindings(PyObject *module) {
    .def("_jit_pass_onnx_peephole", PeepholeOptimizeONNX)
    .def("_jit_pass_fuse", FuseGraph)
    .def("_jit_pass_dce", [](std::shared_ptr<Graph>& g) {
-     return EliminateDeadCode(g); // overload resolution
+     return EliminateDeadCode(g->block()); // overload resolution
    })
    .def("_jit_pass_cse", [](std::shared_ptr<Graph>& g) {
      return EliminateCommonSubexpression(g); // overload resolution
index 3a674c0..3cd27b3 100644 (file)
@@ -862,18 +862,18 @@ Value* Node::insertOutput(size_t i) {
   return outputs_.at(i);
 }
 
-bool Node::isBefore(const Node * n) const {
-  if (this == n) {
-    return false;
-  }
-  return !isAfter(n);
-}
+bool Node::isBeforeOrAfter(const Node* n, MoveSide moveSide) const {
+  if (this->owningBlock() == n->owningBlock()) {
+    if (moveSide == MoveSide::BEFORE) {
+      return this->topo_position_ < n->topo_position_;
+    }
 
-bool Node::isAfter(const Node * n) const {
-  JIT_ASSERT(this->owningGraph() == n->owningGraph());
+    if (moveSide == MoveSide::AFTER) {
+      return this->topo_position_ > n->topo_position_;
+    }
 
-  if (this->owningBlock() == n->owningBlock()) {
-    return this->topo_position_ > n->topo_position_;
+    JIT_ASSERT(this == n);
+    return false;
   }
 
   // These nodes don't share a common block. Traverse the blockchains upward
@@ -887,7 +887,7 @@ bool Node::isAfter(const Node * n) const {
       JIT_ASSERT(rhs->owningBlock());
 
       if (lhs->owningBlock() == rhs->owningBlock()) {
-        return lhs->isAfter(rhs);
+        return lhs->isBeforeOrAfter(rhs, moveSide);
       }
       rhs = rhs->owningBlock()->owningNode();
     }
@@ -896,6 +896,15 @@ bool Node::isAfter(const Node * n) const {
   }
   // should never reach here, since both nodes are ultimately in the same graph
   JIT_ASSERT(false);
+
+}
+
+bool Node::isBefore(const Node * n) const {
+  return isBeforeOrAfter(n, MoveSide::BEFORE);
+}
+
+bool Node::isAfter(const Node * n) const {
+  return isBeforeOrAfter(n, MoveSide::AFTER);
 }
 
 Node* Node::insertBefore(Node * n) {
index f4cf0d4..d1fb402 100644 (file)
@@ -599,6 +599,7 @@ public:
   enum class MoveSide { BEFORE, AFTER };
   bool tryMove(Node* movePoint, MoveSide moveSide, const AliasDb& aliasDb, bool dryRun);
   void move(Node* movePoint, MoveSide moveSide);
+  bool isBeforeOrAfter(const Node* n, MoveSide moveSide) const;
 
   std::pair<Value*, const Argument&> findInput(Symbol name);
   void findSchema() const;
@@ -808,6 +809,12 @@ public:
     const auto & block = *block_;
     return block.nodes();
   }
+  Node * param_node() {
+    return block_->param_node();
+  }
+  const Node * param_node() const {
+    return block_->param_node();
+  }
   Node * return_node() {
     return block_->return_node();
   }
index f30cc63..5953aac 100644 (file)
@@ -20,7 +20,43 @@ bool shouldAnnotate(const Value* v) {
 }
 } // namespace
 
+AliasDb::AliasDb(std::shared_ptr<Graph> graph) : graph_(graph) {
+  analyze(graph_);
+
+  // Build helper indices
+  // NOTE: that these assume that AliasDb is immutable once constructed.
+  // - Alias set -> value mapping
+  for (const auto& pr : valueToAlias_) {
+    const auto value = pr.first;
+    const auto& aliasInfo = pr.second;
+    // We don't support composite types yet
+    JIT_ASSERT(aliasInfo.containedTypes().size() == 0);
+    for (const auto aliasSet : aliasInfo.sets()) {
+      aliasToValue_[aliasSet].insert(value);
+    }
+  }
+  // - Set of all nodes with a wildcard
+  buildWildcardIndex(graph->block());
+}
+
+void AliasDb::buildWildcardIndex(const Block* b) {
+  for (const auto node : b->nodes()) {
+    for (const auto block : node->blocks()) {
+      buildWildcardIndex(block);
+    }
+
+    if (hasWildcardImpl(node)) {
+      wildcardNodes_.insert(node);
+    }
+  }
+}
+
 bool AliasDb::hasWildcard(const Node* n) const {
+  return wildcardNodes_.count(n) != 0;
+}
+
+// Does `n` use or write to any wildcard aliases?
+bool AliasDb::hasWildcardImpl(const Node* n) const {
   for (const auto input : n->inputs()) {
     if (valueToAlias_.count(input) != 0 &&
         valueToAlias_.at(input).isWildcard()) {
@@ -72,6 +108,32 @@ bool AliasDb::hasWrites(Node* n) const {
   return false;
 }
 
+bool AliasDb::writesToInputAlias(Node* n) const {
+  std::vector<const Value*> writes;
+  for (const auto input : n->inputs()) {
+    if (writesTo(n, input)) {
+      writes.push_back(input);
+    }
+  }
+  for (const auto output : n->outputs()) {
+    if (writesTo(n, output)) {
+      writes.push_back(output);
+    }
+  }
+
+  // For all writes, check if the written value may alias a graph input
+  return std::any_of(writes.cbegin(), writes.cend(), [&](const Value* v) {
+    const auto& aliasInfo = valueToAlias_.at(v);
+    const auto& aliasSets = aliasInfo.sets();
+
+    // Check every distinct alias set this value belongs to
+    return std::any_of(
+        aliasSets.cbegin(), aliasSets.cend(), [&](const Symbol aliasSet) {
+          return graphInputAliases_.count(aliasSet) != 0;
+        });
+  });
+}
+
 std::unordered_set<Node*> AliasDb::getWritersForNode(const Node* n) const {
   // Get all alias sets of this node
   // ... check the inputs
@@ -133,12 +195,17 @@ void AliasDb::dump() const {
     }
     std::cout << "\n";
   }
+
+  std::cout << "\n===3. WILDCARD INDEX===\n";
+  for (const auto node : wildcardNodes_) {
+    node->dump();
+  }
 }
 
 void AliasDb::analyze(std::shared_ptr<Graph> graph) {
   // Assign aliases to the graph's inputs, assuming that all inputs of a given
   // type may alias to each other.
-  const auto tensorAlias = getFreshAlias();
+  const auto tensorAlias = getFreshAlias(/*isGraphInput=*/true);
   // Create a separate alias set for each list type
   std::map<TypeKind, Symbol> listTypeAliases;
   // Create a separate alias set for each tuple type
@@ -162,14 +229,15 @@ void AliasDb::analyze(std::shared_ptr<Graph> graph) {
         containedType = DynamicType::get();
       }
       if (listTypeAliases.count(containedType->kind()) == 0) {
-        listTypeAliases[containedType->kind()] = getFreshAlias();
+        listTypeAliases[containedType->kind()] =
+            getFreshAlias(/*isGraphInput=*/true);
       }
 
       addAlias(input, listTypeAliases.at(containedType->kind()));
     } else if (inputType->kind() == TypeKind::TupleType) {
       auto tupleType = inputType->cast<TupleType>();
       if (tupleTypeAliases.count(tupleType) == 0) {
-        tupleTypeAliases[tupleType] = getFreshAlias();
+        tupleTypeAliases[tupleType] = getFreshAlias(/*isGraphInput=*/true);
       }
       addAlias(input, tupleTypeAliases.at(tupleType));
     } else {
@@ -380,7 +448,13 @@ void AliasDb::analyzeSubgraph(Node* node) {
 
   analyze(subgraphBlock);
 
-  mapAliases(node->outputs(), subgraphBlock->outputs());
+  // TODO(suo): the subgraph outputs and node outputs are NOT NECESSARILY the
+  // same length. Autodifferentiation maybe capture additional outputs in the
+  // subgraph block.
+  JIT_ASSERT(subgraphBlock->outputs().size() >= node->outputs().size());
+  for (size_t i = 0; i < node->outputs().size(); i++) {
+    addAlias(node->outputs()[i], subgraphBlock->outputs()[i]);
+  }
 }
 
 // For nodes that generate a fresh value from nothing
@@ -423,9 +497,12 @@ void AliasDb::analyzeBroadcastingChunk(Node* node) {
   }
 }
 
-Symbol AliasDb::getFreshAlias() const {
+Symbol AliasDb::getFreshAlias(bool isGraphInput) {
   auto num = std::stoll(latestSymbol_.toUnqualString());
   latestSymbol_ = Symbol::fromQualString("alias::" + std::to_string(++num));
+  if (isGraphInput) {
+    graphInputAliases_.insert(latestSymbol_);
+  }
   return latestSymbol_;
 }
 
index 1bfd40e..47a60a1 100644 (file)
@@ -26,16 +26,21 @@ namespace jit {
  */
 class AliasDb {
  public:
-  AliasDb(std::shared_ptr<Graph> graph) : graph_(graph) {
-    analyze(graph_);
-  }
+  explicit AliasDb(std::shared_ptr<Graph> graph);
 
-  // Does `n` contain any wildcard aliases?
+  // Does `n` use or write to any wildcard aliases?
   bool hasWildcard(const Node* n) const;
 
+  const std::unordered_set<const Node*>& getWildcardNodes() const {
+    return wildcardNodes_;
+  }
+
   // Does `n` write to any alias sets?
   bool hasWrites(Node* n) const;
 
+  // Does `n` write to a value that may alias one of the graph inputs?
+  bool writesToInputAlias(Node* n) const;
+
   // Get all nodes that write to any alias set inputed/outputed by `n`
   std::unordered_set<Node*> getWritersForNode(const Node* n) const;
 
@@ -60,19 +65,25 @@ class AliasDb {
   void analyzeChunk(Node* node);
   void analyzeBroadcastingChunk(Node* node);
 
-  Symbol getFreshAlias() const;
+  Symbol getFreshAlias(bool isGraphInput = false);
   void addAlias(const Value* value, AliasInfo alias);
   void addAlias(const Value* value, Symbol alias);
   void addAlias(const Value* value, const Value* from);
   void mapAliases(at::ArrayRef<Value*> to, at::ArrayRef<Value*> from);
   void giveFreshAlias(const Value* value);
 
+  bool hasUsesAfter(Symbol alias, const Node* n) const;
+  void buildWildcardIndex(const Block* b);
+  bool hasWildcardImpl(const Node* n) const;
   bool writesTo(Node* n, const Value* v) const;
 
   std::shared_ptr<Graph> graph_;
-  mutable Symbol latestSymbol_ = Symbol::fromQualString("alias::0");
+  Symbol latestSymbol_ = Symbol::fromQualString("alias::0");
   std::unordered_map<const Value*, AliasInfo> valueToAlias_;
+  std::unordered_map<Symbol, std::unordered_set<const Value*>> aliasToValue_;
   std::unordered_map<Symbol, std::unordered_set<Node*>> aliasToWrites_;
+  std::unordered_set<const Node*> wildcardNodes_;
+  std::unordered_set<Symbol> graphInputAliases_;
 };
 
 inline TORCH_API AliasDb AliasAnalysis(std::shared_ptr<Graph> graph) {
index 17b8e77..a826dc3 100644 (file)
@@ -24,7 +24,7 @@ void EliminateCommonSubexpression(
   for (auto it = block->nodes().begin(); it != block->nodes().end(); ++ it) {
     auto node = *it;
     if (node->kind() == prim::PythonOp || node->kind() == prim::Print ||
-        aliasDb.hasWriters(node)) {
+        aliasDb.hasWriters(node) || aliasDb.hasWildcard(node)) {
       // Do NOT have enough information to do CSE on these nodes.
       continue;
     }
index 95f4d3c..5c25f2d 100644 (file)
@@ -125,7 +125,7 @@ void ConstantPropagation(Node* n, const AliasDb& aliasDb, bool recurse) {
       });
   bool supported_node = !n->kind().is_onnx() &&
       skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
-      !aliasDb.hasWriters(n);
+      !aliasDb.hasWriters(n) && !aliasDb.hasWildcard(n);
   auto run_blocks = [&]() {
     if (recurse) {
       for (Block * block : n->blocks()) {
index 4d6b22b..ddcdd10 100644 (file)
-#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "dead_code_elimination.h"
+
+#include "torch/csrc/jit/passes/alias_analysis.h"
 
 #include <unordered_map>
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
-using bool_memo_type = std::unordered_map<Node*, bool>;
+class DeadCodeEliminator {
+ public:
+  explicit DeadCodeEliminator(std::shared_ptr<Graph> graph)
+      : aliasDb_(AliasAnalysis(graph)) {}
+  DeadCodeEliminator(){};
 
-bool isMutable(Node* node) {
-  if(!node->kind().is_aten())
-    return false;
-  // onnx export calls EliminateDeadCode but sometimes passes invalid
-  // aten operators. So we call maybeSchema so we handle the cases when
-  // there is no valid schema for a node
-  auto schema = node->maybeSchema();
-  return schema && schema->is_mutable();
-}
+  // The algorithm is an inverse mark-and-sweep. Starting from the return node,
+  // we mark "live" nodes that are necessary for the output. Nodes that have
+  // side effects are also marked.
+  void run(Block* block, bool recurse) {
+    // Find the last wildcard in the block. We cannot eliminate any mutable ops
+    // that precede the last wildcard.
+    setLastWildcard();
 
-bool hasSideEffects(Node * node, bool_memo_type& memo) {
-  // FIXME: PythonOp should be treated as having side effects as well!
-  //        Unfortunately ONNX depends on it getting removed in this pass, so it's not
-  //        a simple change.
-  auto it = memo.find(node);
-  if (it != memo.end())
-    return it->second;
-  bool has_side_effects =
-      node->kind() == prim::Print ||
-        node->kind() == prim::RaiseException ||
-      std::any_of(node->blocks().begin(), node->blocks().end(), [&](Block* b) {
-        return std::any_of(b->nodes().begin(), b->nodes().end(), [&](Node* n) {
-          return hasSideEffects(n, memo);
-        });
-      }) || isMutable(node);
-
-  memo.emplace(node, has_side_effects);
-  return has_side_effects;
-}
+    // Initialize by adding the return node to work list
+    markAndEnqueue(block->return_node());
 
-void removeDeadIfOutputs(Node* node) {
-  if (node->kind() != prim::If)
-    return;
-  for(size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
-    size_t i = i_1 - 1;
-    if (!node->outputs().at(i)->hasUses()) {
-      node->eraseOutput(i);
-      for (Block* b : node->blocks()) {
-        b->eraseOutput(i);
+    mark(block);
+    sweep(block, recurse);
+  }
+
+ private:
+  void setLastWildcard() {
+    if (!aliasDb_) {
+      return;
+    }
+
+    const auto& wildcards = aliasDb_->getWildcardNodes();
+    if (wildcards.empty()) {
+      return;
+    }
+
+    lastWildcard_ = *wildcards.begin();
+    for (const auto wildcard : wildcards) {
+      if (wildcard->isAfter(*lastWildcard_)) {
+        lastWildcard_ = wildcard;
       }
     }
   }
-}
 
-void removeDeadLoopOutputs(Node* node) {
-  if (node->kind() != prim::Loop)
-    return;
-  auto loop_body = node->blocks().at(0);
-  auto loop_input_offset = 2; // offset of loop carried deps in input list
-  auto loop_body_offset = 1; // offset to the loop carried dependencies in block inputs/outputs
-
-  for(size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
-    size_t i = i_1 - 1;
-    if (!node->outputs().at(i)->hasUses() &&
-        !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
-      node->eraseOutput(i);
-      node->removeInput(loop_input_offset + i);
-      loop_body->eraseInput(loop_body_offset + i);
-      loop_body->eraseOutput(loop_body_offset + i);
+  void mark(Block* block) {
+    // Mark all nodes with side effects.
+    for (auto node : block->nodes()) {
+      if (hasSideEffects(node)) {
+        markAndEnqueue(node);
+      }
+    }
+
+    while (!workQueue_.empty()) {
+      auto node = workQueue_.front();
+      workQueue_.pop_front();
+
+      for (auto subBlock : node->blocks()) {
+        mark(subBlock);
+      }
+
+      // Mark all nodes in this node's blockchain (since owning nodes are
+      // considered live if they contain a live node)
+      if (node->owningBlock() != block) {
+        auto curNode = node;
+        while (curNode) {
+          if (!curNode->owningBlock()) {
+            break;
+          }
+
+          markAndEnqueue(curNode);
+          curNode = curNode->owningBlock()->owningNode();
+        }
+      }
+
+      // Find preceding writers for node, add to work list
+      if (aliasDb_) {
+        for (auto writer : aliasDb_->getWritersForNode(node)) {
+          if (writer->isBefore(node)) {
+            markAndEnqueue(writer);
+          }
+        }
+      }
+
+      // Find producers for all inputs, add to work list
+      for (auto input : node->inputs()) {
+        markAndEnqueue(input->node());
+      }
     }
   }
-}
 
-void EliminateDeadCode(Block *block, bool recurse, bool_memo_type& memo) {
-  auto nodes = block->nodes().reverse();
-  for (auto it = nodes.begin(); it != nodes.end(); it++) {
-    auto node = *it;
-    // note these occur before the recursion because we want to uncover
-    // dead code in the blocks used to calculate the output
-    removeDeadIfOutputs(node);
-    removeDeadLoopOutputs(node);
-    if (recurse) {
-      for (Block * block : node->blocks())
-        EliminateDeadCode(block, true, memo);
+  // Delete all unmarked nodes.
+  void sweep(Block* block, bool recurse) {
+    auto nodes = block->nodes().reverse();
+    for (auto it = nodes.begin(); it != nodes.end(); it++) {
+      auto node = *it;
+      // note these occur before the recursion because we want to uncover
+      // dead code in the blocks used to calculate the output
+      removeDeadIfOutputs(node);
+      removeDeadLoopOutputs(node);
+      if (recurse) {
+        for (Block* block : node->blocks()) {
+          sweep(block, true);
+        }
+      }
+      // TODO(suo): We shouldn't really have to check whether a node has uses,
+      // since the mark algorithm should do that. But currently, the marking
+      // doesn't reach loop counters in certain cases (see TestScript.test_pass)
+      if (!marked_.count(node) && !node->hasUses()) {
+        it.destroyCurrent();
+      }
     }
-    if (!node->hasUses() && !hasSideEffects(node, memo))
-      it.destroyCurrent();
   }
-}
+
+  void markAndEnqueue(Node* n) {
+    if (!marked_.count(n)) {
+      marked_.insert(n);
+      workQueue_.push_back(n);
+    }
+  }
+
+  bool hasUntrackedMutation(Node* node) {
+    if (!aliasDb_) {
+      // If we don't have alias information, all mutable ops have unknown
+      // effects and can't be considered for elimination.
+      if (!node->kind().is_aten()) {
+        return false;
+      }
+      // onnx export calls EliminateDeadCode but sometimes passes invalid
+      // aten operators. So we call maybeSchema so we handle the cases when
+      // there is no valid schema for a node
+      auto schema = node->maybeSchema();
+      return schema && schema->is_mutable();
+    } else {
+      // Otherwise, there are two kinds of nodes with untracked effects:
+      // 1. Nodes that write to a value that may alias the graph inputs (since
+      //    the inputs can be used outside the graph).
+      // 2. Anything that could clobber a wildcard value.
+      bool touchesWildcard = false;
+      if (lastWildcard_) {
+        touchesWildcard = aliasDb_->hasWrites(node) &&
+            (node->isBefore(*lastWildcard_) || node == *lastWildcard_);
+      }
+      return aliasDb_->writesToInputAlias(node) || touchesWildcard;
+    }
+  }
+
+  bool hasSideEffects(Node* node) {
+    // FIXME: PythonOp should be treated as having side effects as well!
+    //        Unfortunately ONNX depends on it getting removed in this pass, so
+    //        it's not a simple change.
+    auto it = memo_.find(node);
+    if (it != memo_.end())
+      return it->second;
+    bool has_side_effects = node->kind() == prim::Print ||
+        node->kind() == prim::RaiseException ||
+        std::any_of(node->blocks().begin(),
+                    node->blocks().end(),
+                    [&](Block* b) {
+                      return std::any_of(
+                          b->nodes().begin(), b->nodes().end(), [&](Node* n) {
+                            return hasSideEffects(n);
+                          });
+                    }) ||
+        hasUntrackedMutation(node);
+
+    memo_.emplace(node, has_side_effects);
+    return has_side_effects;
+  }
+
+  void removeDeadIfOutputs(Node* node) {
+    if (node->kind() != prim::If)
+      return;
+
+    for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
+      size_t i = i_1 - 1;
+      if (!node->outputs().at(i)->hasUses()) {
+        node->eraseOutput(i);
+        for (Block* b : node->blocks()) {
+          b->eraseOutput(i);
+        }
+      }
+    }
+  }
+
+  void removeDeadLoopOutputs(Node* node) {
+    if (node->kind() != prim::Loop)
+      return;
+    auto loop_body = node->blocks().at(0);
+    auto loop_input_offset = 2; // offset of loop carried deps in input list
+    auto loop_body_offset =
+        1; // offset to the loop carried dependencies in block inputs/outputs
+
+    for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
+      size_t i = i_1 - 1;
+      if (!node->outputs().at(i)->hasUses() &&
+          !loop_body->inputs().at(loop_body_offset + i)->hasUses()) {
+        node->eraseOutput(i);
+        node->removeInput(loop_input_offset + i);
+        loop_body->eraseInput(loop_body_offset + i);
+        loop_body->eraseOutput(loop_body_offset + i);
+      }
+    }
+  }
+
+  c10::optional<AliasDb> aliasDb_;
+  std::unordered_map<Node*, bool> memo_;
+  std::unordered_set<Node*> marked_;
+  std::list<Node*> workQueue_;
+  c10::optional<const Node*> lastWildcard_;
+
+};
 
 void EliminateDeadCode(const std::shared_ptr<Graph>& graph) {
-  bool_memo_type side_effect_memo;
-  EliminateDeadCode(graph->block(), true, side_effect_memo);
+  DeadCodeEliminator(graph).run(graph->block(), true);
 }
 
-void EliminateDeadCode(Block *block, bool recurse) {
-  bool_memo_type side_effect_memo;
-  EliminateDeadCode(block, recurse, side_effect_memo);
+void EliminateDeadCode(Block* block, bool recurse) {
+  DeadCodeEliminator().run(block, recurse);
 }
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 9ae89f9..39aec2f 100644 (file)
@@ -4,6 +4,12 @@
 
 namespace torch { namespace jit {
 
+// If given a top-level graph, DCE will construct do alias analysis that allows
+// for "smarter" dead code elimination (we will eliminate mutable ops if we can
+// prove the mutated values are not used). Otherwise, we will not allow DCE to
+// eliminate mutable ops.
+//
+// So, prefer to use the graph version if you can.
 TORCH_API void EliminateDeadCode(const std::shared_ptr<Graph>& graph);
 TORCH_API void EliminateDeadCode(Block *block, bool recurse=true);
 
index eeb87ff..bddf1d9 100644 (file)
@@ -2,53 +2,51 @@
 
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
 
-namespace torch { namespace jit {
+namespace torch {
+namespace jit {
 
 namespace {
 
-bool canRunWithAutograd(Node *node) {
+bool canRunWithAutograd(Nodenode) {
   return node->kind() != prim::FusionGroup;
 }
 
-void inlineNode(Node *node) {
-  WithInsertPoint insert_guard { node };
-  Graph * graph = node->owningGraph();
-  auto subgraph = node->g(attr::Subgraph);
-  std::unordered_map<Value*, Value*> input_map;
+void InlineAutodiffSubgraphs(Block* block, size_t threshold);
+
+graph_node_list::iterator scanNode(Node* node, size_t threshold) {
+  auto next_node = ++node->iterator();
+
+  for (Block* block : node->blocks()) {
+    InlineAutodiffSubgraphs(block, threshold);
+  }
 
-  size_t num_inputs = node->inputs().size();
-  JIT_ASSERT(num_inputs == subgraph->inputs().size());
-  for (size_t i = 0; i < num_inputs; ++i) {
-    input_map[subgraph->inputs()[i]] = node->inputs()[i];
+  if (node->kind() != prim::DifferentiableGraph) {
+    return next_node;
   }
 
-  for (Node * subnode : subgraph->nodes()) {
-    Node * new_node = graph->insertNode(graph->createClone(subnode, [&](Value * v) { return input_map.at(v); }));
-    for (size_t i = 0; i < subnode->outputs().size(); ++i) {
-      input_map[subnode->output(i)] = new_node->output(i);
-    }
+  auto subgraph = node->g(attr::Subgraph);
+  int64_t subgraph_size =
+      std::distance(subgraph->nodes().begin(), subgraph->nodes().end());
+  if (subgraph_size >= static_cast<int64_t>(threshold)) {
+    return next_node;
   }
 
-  size_t num_outputs = node->outputs().size();
-  JIT_ASSERT(num_outputs <= subgraph->outputs().size() &&
-             num_outputs == static_cast<size_t>(node->i(attr::f_real_outputs)));
-  for (size_t i = 0; i < num_outputs; ++i) {
-    node->output(i)->replaceAllUsesWith(input_map.at(subgraph->outputs()[i]));
+  if (!std::all_of(
+          subgraph->nodes().begin(),
+          subgraph->nodes().end(),
+          canRunWithAutograd)) {
+    return next_node;
   }
+
+  SubgraphUtils::unmergeSubgraph(node);
+  return next_node;
 }
 
-void InlineAutodiffSubgraphs(Block *block, size_t threshold) {
-  for (Node * node : block->nodes()) {
-    for (Block * block : node->blocks()) {
-      InlineAutodiffSubgraphs(block, threshold);
-    }
-    if (node->kind() != prim::DifferentiableGraph) continue;
-    auto subgraph = node->g(attr::Subgraph);
-    int64_t subgraph_size = std::distance(subgraph->nodes().begin(), subgraph->nodes().end());
-    if (subgraph_size >= static_cast<int64_t>(threshold)) continue;
-    if (!std::all_of(subgraph->nodes().begin(), subgraph->nodes().end(), canRunWithAutograd)) continue;
-    inlineNode(node);
+void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+    it = scanNode(*it, threshold);
   }
 }
 
@@ -59,4 +57,5 @@ void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold) {
   EliminateDeadCode(graph);
 }
 
-}} // namespace torch::jit
+} // namespace jit
+} // namespace torch
index 3e1eebc..c441f96 100644 (file)
@@ -157,7 +157,7 @@ static void EnsureNoTuples(Block* block) {
 
 void LowerAllTuples(std::shared_ptr<Graph>& graph) {
   LowerAllTuples(graph->block());
-  EliminateDeadCode(graph);
+  EliminateDeadCode(graph->block());
   EnsureNoTuples(graph->block());
 }
 
index 80b0641..e6a0d66 100644 (file)
@@ -206,7 +206,7 @@ class ShapePropagator {
         std::any_of(writers.cbegin(), writers.cend(), [&](const Node* writer) {
           return writer->isBefore(node);
         });
-    if (hasWritersBefore) {
+    if (hasWritersBefore || aliasDb_.hasWildcard(node)) {
       // If something could have written to a value used by this node, we can't
       // guarantee the result is the same when running it in isolation.
       dependsOnMutationMemo_[node] = true;
@@ -702,7 +702,6 @@ class ShapePropagator {
           return {};
         }};
 
     static const auto any_tensor_type = [](Node* node) -> TensorTypePtr {
       for (Value* input : node->inputs()) {
         if (auto type = input->type()->cast<TensorType>()) {
index 0a7885e..691829b 100644 (file)
@@ -57,7 +57,8 @@ std::vector<Node*> unmergeSubgraph(Node* subgraphNode) {
 
   // Replace uses of group outputs and destroy the group
   const auto subgraphOutputs = subgraph->outputs();
-  for (size_t i = 0; i < subgraphOutputs.size(); ++i) {
+  JIT_ASSERT(subgraphOutputs.size() >= subgraphNode->outputs().size());
+  for (size_t i = 0; i < subgraphNode->outputs().size(); ++i) {
     const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
     subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
   }